package com.omega.engine.nn.layer;

import com.omega.common.data.Tensor;
import com.omega.common.utils.MatrixUtils;
import com.omega.common.utils.RandomUtils;
import com.omega.engine.gpu.BaseKernel;

/* loaded from: input_file:com/omega/engine/nn/layer/RouteLayer.class */
/**
 * Route layer: joins the outputs of several earlier layers into one tensor.
 *
 * <p>For every sample in the batch, the per-sample data of each source layer is copied back to back
 * into this layer's output, in the order the layers were given. Assuming tensors are stored
 * sample-major as (number, channel, height, width) — as the {@code Tensor(number, channel, height,
 * width, ...)} constructors used below suggest — this is a concatenation along the channel axis.
 *
 * <p>With {@code groups > 1}, only the {@code groupId}-th of {@code groups} equal slices of each
 * source's per-sample data is routed, so each source contributes {@code oChannel / groups} channels.
 */
public class RouteLayer extends Layer {
    /** Source layers, in concatenation order. */
    private final Layer[] layers;
    /** Lazily created GPU helper used for the copy / axpy kernels. */
    private BaseKernel kernel;
    /** Number of equal slices each source's per-sample data is split into (1 = no grouping). */
    private final int groups;
    /** Index of the slice that is routed, in {@code [0, groups)}. */
    private final int groupId;

    /**
     * Creates an ungrouped route layer that concatenates the full outputs of {@code layerArr}.
     *
     * @param layerArr source layers; all must share the same output height and width
     * @throws IllegalArgumentException if {@code layerArr} is null or empty
     * @throws RuntimeException if the source layers' spatial sizes differ
     */
    public RouteLayer(Layer[] layerArr) {
        this(layerArr, 1, 0);
    }

    /**
     * Creates a route layer that forwards slice {@code groupId} of {@code groups} from every source.
     *
     * @param layerArr source layers; all must share the same output height and width
     * @param groups number of equal slices per source (must be positive and divide every source's
     *     channel count)
     * @param groupId which slice to route, in {@code [0, groups)}
     * @throws IllegalArgumentException on an empty layer list or invalid {@code groups}/{@code groupId}
     * @throws RuntimeException if the source layers' spatial sizes differ
     */
    public RouteLayer(Layer[] layerArr, int groups, int groupId) {
        if (layerArr == null || layerArr.length == 0) {
            throw new IllegalArgumentException("route layer needs at least one input layer.");
        }
        if (groups <= 0) {
            throw new IllegalArgumentException("route layer groups must be positive: " + groups);
        }
        if (groupId < 0 || groupId >= groups) {
            throw new IllegalArgumentException(
                    "route layer groupId " + groupId + " out of range [0, " + groups + ")");
        }
        this.layers = layerArr;
        this.groups = groups;
        this.groupId = groupId;
        this.oHeight = layerArr[0].oHeight;
        this.oWidth = layerArr[0].oWidth;
        for (Layer source : layerArr) {
            if (source.oHeight != this.oHeight || source.oWidth != this.oWidth) {
                throw new RuntimeException("input size must be all same in the route layer.");
            }
            if (source.oChannel % groups != 0) {
                throw new IllegalArgumentException(
                        "input channels " + source.oChannel + " not divisible by groups " + groups);
            }
            // Only one of `groups` slices is routed, so each source adds oChannel / groups channels;
            // output() / diff() move exactly getOnceSize() / groups elements per source and sample.
            this.oChannel += source.oChannel / groups;
        }
    }

    /** Sizes the output tensor for the current batch and creates the GPU kernel helper on first use. */
    @Override // com.omega.engine.nn.layer.Layer
    public void init() {
        this.number = this.network.number;
        if (this.output == null || this.output.number != this.number) {
            this.output = Tensor.createTensor(this.output, this.number, this.oChannel, this.oHeight, this.oWidth, true);
        }
        if (this.kernel == null) {
            this.kernel = new BaseKernel();
        }
    }

    /**
     * Ensures every source layer has a {@code cache_delta} buffer shaped like its full output for
     * the current batch size.
     */
    @Override // com.omega.engine.nn.layer.Layer
    public void initBack() {
        for (Layer source : this.layers) {
            if (source.cache_delta == null || source.cache_delta.number != this.number) {
                source.cache_delta = new Tensor(this.number, source.oChannel, this.oHeight, this.oWidth, true);
            }
        }
    }

    /** No trainable parameters. */
    @Override // com.omega.engine.nn.layer.Layer
    public void initParam() {
    }

    /**
     * Copies the selected slice of each source layer's output into consecutive positions of this
     * layer's output, sample by sample.
     */
    @Override // com.omega.engine.nn.layer.Layer
    public void output() {
        int dstOffset = 0;
        for (Layer source : this.layers) {
            Tensor src = source.output;
            int sliceSize = src.getOnceSize() / this.groups;
            int srcSliceStart = sliceSize * this.groupId;
            for (int b = 0; b < this.number; b++) {
                this.kernel.copy_gpu(src, this.output, sliceSize, b * src.getOnceSize() + srcSliceStart, 1,
                        dstOffset + b * this.output.getOnceSize(), 1);
            }
            dstOffset += sliceSize;
        }
    }

    @Override // com.omega.engine.nn.layer.Layer
    public Tensor getOutput() {
        return this.output;
    }

    /**
     * Scatters this layer's delta back to the sources: for each source and sample, the matching
     * region of {@code delta} is added (axpy, alpha = 1) into the routed slice of the source's
     * {@code cache_delta}.
     *
     * <p>NOTE(review): because this accumulates, confirm that {@code cache_delta} is cleared
     * elsewhere between iterations; {@link #initBack()} only allocates it when missing or resized.
     */
    @Override // com.omega.engine.nn.layer.Layer
    public void diff() {
        int srcOffset = 0;
        for (Layer source : this.layers) {
            Tensor target = source.cache_delta;
            int sliceSize = target.getOnceSize() / this.groups;
            int targetSliceStart = sliceSize * this.groupId;
            for (int b = 0; b < this.number; b++) {
                this.kernel.axpy_gpu(this.delta, target, sliceSize, 1.0f, srcOffset + b * this.delta.getOnceSize(), 1,
                        b * target.getOnceSize() + targetSliceStart, 1);
            }
            srcOffset += sliceSize;
        }
    }

    /** Manual smoke test: routes two small tensors forward and copies a delta back, printing results. */
    public static void main(String[] strArr) {
        int channels = 3 + 2;
        float[] first = MatrixUtils.order(2 * 3 * 4 * 4, 1, 1);
        float[] second = MatrixUtils.order(2 * 2 * 4 * 4, 1, 1);
        float[] deltaData = RandomUtils.order(2 * channels * 4 * 4, 1.0f, 1.0f);
        Tensor[] inputs = {new Tensor(2, 3, 4, 4, first, true), new Tensor(2, 2, 4, 4, second, true)};
        Tensor routed = new Tensor(2, channels, 4, 4, true);
        Tensor delta = new Tensor(2, channels, 4, 4, deltaData, true);
        Tensor[] inputDeltas = {new Tensor(2, 3, 4, 4, true), new Tensor(2, 2, 4, 4, true)};
        BaseKernel baseKernel = new BaseKernel();
        testForward(inputs, routed, baseKernel);
        routed.showDM();
        testBackward(inputDeltas, delta, baseKernel);
        delta.showDM();
        for (Tensor inputDelta : inputDeltas) {
            inputDelta.showDM();
        }
    }

    /** Test helper: concatenates each tensor of {@code tensorArr} into {@code tensor} per sample. */
    public static void testForward(Tensor[] tensorArr, Tensor tensor, BaseKernel baseKernel) {
        int dstOffset = 0;
        for (Tensor src : tensorArr) {
            for (int b = 0; b < tensor.number; b++) {
                baseKernel.copy_gpu(src, tensor, src.getOnceSize(), b * src.getOnceSize(), 1,
                        dstOffset + b * tensor.getOnceSize(), 1);
            }
            dstOffset += src.getOnceSize();
        }
    }

    /** Test helper: splits {@code tensor} back into the tensors of {@code tensorArr} per sample. */
    public static void testBackward(Tensor[] tensorArr, Tensor tensor, BaseKernel baseKernel) {
        int srcOffset = 0;
        for (Tensor dst : tensorArr) {
            for (int b = 0; b < tensor.number; b++) {
                baseKernel.copy_gpu(tensor, dst, dst.getOnceSize(), srcOffset + b * tensor.getOnceSize(), 1,
                        b * dst.getOnceSize(), 1);
            }
            srcOffset += dst.getOnceSize();
        }
    }

    @Override // com.omega.engine.nn.layer.Layer
    public void forward() {
        init();
        output();
    }

    @Override // com.omega.engine.nn.layer.Layer
    public void back() {
        initBack();
        setDelta();
        diff();
        if (this.network.GRADIENT_CHECK) {
            gradientCheck();
        }
    }

    /** Intentionally a no-op: this layer reads its inputs from {@link #layers}, not from an argument. */
    @Override // com.omega.engine.nn.layer.Layer
    public void forward(Tensor tensor) {
    }

    /** Intentionally a no-op: the backward pass uses {@link #back()}. */
    @Override // com.omega.engine.nn.layer.Layer
    public void back(Tensor tensor) {
    }

    /** No parameters to update. */
    @Override // com.omega.engine.nn.layer.Layer
    public void update() {
    }

    @Override // com.omega.engine.nn.layer.Layer
    public void showDiff() {
    }

    @Override // com.omega.engine.nn.layer.Layer
    public LayerType getLayerType() {
        return LayerType.route;
    }

    /** Array-based CPU path is not supported by this layer; always returns {@code null}. */
    @Override // com.omega.engine.nn.layer.Layer
    public float[][][][] output(float[][][][] fArr) {
        return null;
    }

    @Override // com.omega.engine.nn.layer.Layer
    public void initCache() {
    }

    @Override // com.omega.engine.nn.layer.Layer
    public void backTemp() {
    }
}
