package com.omega.engine.nn.layer;

import com.omega.common.data.Tensor;
import com.omega.engine.gpu.BaseKernel;
import com.omega.engine.nn.layer.active.ActiveFunctionLayer;
import com.omega.engine.nn.layer.active.ReluLayer;
import com.omega.engine.nn.layer.gpu.BasicBlockKernel;
import com.omega.engine.nn.layer.normalization.BNLayer;
import com.omega.engine.nn.network.Network;
import com.omega.engine.updater.UpdaterFactory;

/* loaded from: input_file:com/omega/engine/nn/layer/BasicBlockLayer.class */
/**
 * Residual "basic block": two 3x3 convolutions, each followed by batch normalization with a ReLU
 * between them, whose result is summed with a shortcut branch.
 *
 * <p>The shortcut is the block input itself when the input and output shapes match. Otherwise it is
 * a 1x1 convolution (using the first-layer stride) followed by batch normalization. No activation is
 * applied to the summed output inside this class.
 */
public class BasicBlockLayer extends Layer {
    /** Kernel used for the element-wise sums of the residual and shortcut branches. */
    private final BasicBlockKernel kernel;
    /** Shortcut projection convolution; only built when {@link #downsample} is true. */
    private ConvolutionLayer identityConv;
    /** Batch norm after the shortcut projection; only built when {@link #downsample} is true. */
    private BNLayer identityBN;
    private ConvolutionLayer conv1;
    private BNLayer bn1;
    private ActiveFunctionLayer a1;
    private ConvolutionLayer conv2;
    private BNLayer bn2;
    private final int kHeight = 3;
    private final int kWidth = 3;
    private final int padding = 1;
    /** Stride of conv1 and of the shortcut projection; conv2 always uses stride 1. */
    private final int firstLayerStride;
    /** True when the shortcut needs a projection because the input and output shapes differ. */
    private final boolean downsample;
    private final BaseKernel baseKernel;
    /** Copy of the incoming gradient, consumed by the shortcut branch in {@link #diff()}. */
    private Tensor cache_delta;

    /**
     * Creates a basic block and its sub-layers.
     *
     * @param channel  number of input channels
     * @param oChannel number of output channels
     * @param height   input height
     * @param width    input width
     * @param stride   stride of the first 3x3 convolution and of the shortcut projection
     * @param network  owning network; its updater settings are used for every convolution
     * @throws IllegalArgumentException if {@code stride < 1}
     */
    public BasicBlockLayer(int channel, int oChannel, int height, int width, int stride, Network network) {
        if (stride < 1) {
            throw new IllegalArgumentException("stride must be >= 1, got " + stride);
        }
        this.network = network;
        this.channel = channel;
        this.oChannel = oChannel;
        this.height = height;
        this.width = width;
        this.firstLayerStride = stride;
        // Standard convolution output size. With a 3x3 kernel and padding 1 this equals the input
        // size when stride == 1, so no special case is needed.
        this.oHeight = (height + 2 * this.padding - this.kHeight) / stride + 1;
        this.oWidth = (width + 2 * this.padding - this.kWidth) / stride + 1;
        // The identity shortcut only matches the output when neither the channel count nor the
        // spatial size changes. A stride other than 1 shrinks the spatial size, so it also needs
        // the projection shortcut.
        this.downsample = channel != oChannel || stride != 1;
        this.kernel = new BasicBlockKernel();
        this.baseKernel = new BaseKernel();
        initLayers();
    }

    /**
     * Builds the sub-layers of both branches.
     *
     * <p>The residual branch is conv1 (block stride) -> bn1 -> ReLU -> conv2 (stride 1) -> bn2. The
     * shortcut projection (1x1 conv + BN) is only built when {@link #downsample} is true.
     */
    public void initLayers() {
        this.conv1 = newBlockConv(this.channel, this.oChannel, this.width, this.height,
                this.kWidth, this.kHeight, this.padding, this.firstLayerStride);
        this.bn1 = new BNLayer(this.conv1);
        this.a1 = new ReluLayer(this.bn1);
        this.conv2 = newBlockConv(this.conv1.oChannel, this.oChannel, this.conv1.oWidth, this.conv1.oHeight,
                this.kWidth, this.kHeight, this.padding, 1);
        this.bn2 = new BNLayer(this.conv2);
        if (this.downsample) {
            this.identityConv = newBlockConv(this.channel, this.oChannel, this.width, this.height,
                    1, 1, 0, this.firstLayerStride);
            this.identityBN = new BNLayer(this.identityConv);
        }
    }

    /**
     * Creates one of this block's convolutions. Each convolution gets its own updater from the
     * network's updater settings and {@code ParamsInit.relu} initialisation. The boolean constructor
     * flag is passed as {@code false}, as the original code did for every convolution here.
     */
    private ConvolutionLayer newBlockConv(int inChannel, int outChannel, int inWidth, int inHeight,
            int kw, int kh, int pad, int stride) {
        ConvolutionLayer conv = new ConvolutionLayer(inChannel, outChannel, inWidth, inHeight,
                kw, kh, pad, stride, false, this.network);
        conv.setUpdater(UpdaterFactory.create(this.network.updater, this.network.updaterParams));
        conv.paramsInit = ParamsInit.relu;
        return conv;
    }

    /** Syncs the batch size with the network and (re)allocates the output when it changes. */
    @Override
    public void init() {
        this.number = this.network.number;
        if (this.output == null || this.output.number != this.network.number) {
            this.output = Tensor.createTensor(this.output, this.number, this.oChannel, this.oHeight, this.oWidth, true);
        }
    }

    /**
     * Prepares backward buffers. This block's input gradient reuses conv1's gradient tensor.
     * {@code cache_delta} is allocated with the output shape. Both are refreshed on first use and
     * whenever conv1's batch size no longer matches its gradient tensor (presumably a batch-size
     * change).
     */
    @Override
    public void initBack() {
        if (this.diff == null || this.conv1.number != this.conv1.diff.number) {
            this.conv1.initBack();
            this.diff = this.conv1.diff;
            this.cache_delta = new Tensor(this.number, this.oChannel, this.oHeight, this.oWidth, true);
        }
    }

    @Override
    public void initParam() {
    }

    /** Runs the residual branch and the shortcut branch, then sums them into {@code output}. */
    @Override
    public void output() {
        // Residual branch.
        this.conv1.forward(this.input);
        this.bn1.forward(this.conv1.output);
        this.a1.forward(this.bn1.output);
        this.conv2.forward(this.a1.output);
        this.bn2.forward(this.conv2.output);
        // Shortcut branch: identity, or 1x1 projection + BN when shapes differ.
        Tensor shortcut = this.input;
        if (this.downsample) {
            this.identityConv.forward(this.input);
            this.identityBN.forward(this.identityConv.output);
            shortcut = this.identityBN.output;
        }
        this.kernel.add(shortcut, this.bn2.output, this.output);
    }

    @Override
    public Tensor getOutput() {
        return this.output;
    }

    /**
     * Back-propagates {@code delta} through both branches and sums their input gradients into
     * {@code diff}.
     */
    @Override
    public void diff() {
        // Snapshot the incoming gradient for the shortcut branch before the residual branch uses
        // delta. NOTE(review): presumably needed because bn2.back may modify delta in place — confirm.
        this.baseKernel.copy_gpu(this.delta, this.cache_delta, this.delta.getDataLength(), 1, 1);
        // Residual branch, in reverse order of the forward pass.
        this.bn2.back(this.delta);
        this.conv2.back(this.bn2.diff);
        this.a1.back(this.conv2.diff);
        this.bn1.back(this.a1.diff);
        this.conv1.back(this.bn1.diff);
        // Shortcut branch: identity passes the gradient through unchanged.
        Tensor shortcutDiff = this.cache_delta;
        if (this.downsample) {
            this.identityBN.back(this.cache_delta);
            this.identityConv.back(this.identityBN.diff);
            shortcutDiff = this.identityConv.diff;
        }
        // initBack() set this.diff to conv1.diff, so the sum is normally accumulated in place.
        this.kernel.add(this.conv1.diff, shortcutDiff, this.diff);
    }

    @Override
    public void forward() {
        init();
        setInput();
        output();
    }

    @Override
    public void back() {
        initBack();
        setDelta();
        diff();
    }

    /** Updates the parameters of every sub-layer, including the shortcut projection when present. */
    @Override
    public void update() {
        this.conv1.update();
        this.bn1.update();
        this.conv2.update();
        this.bn2.update();
        if (this.downsample) {
            this.identityBN.update();
            this.identityConv.update();
        }
    }

    @Override
    public void showDiff() {
    }

    @Override
    public LayerType getLayerType() {
        return LayerType.block;
    }

    /** Array-based forward pass is not supported by this layer; always returns {@code null}. */
    @Override
    public float[][][][] output(float[][][][] fArr) {
        return null;
    }

    @Override
    public void initCache() {
    }

    @Override
    public void forward(Tensor tensor) {
        init();
        setInput(tensor);
        output();
    }

    @Override
    public void back(Tensor tensor) {
        initBack();
        setDelta(tensor);
        diff();
    }

    @Override
    public void backTemp() {
    }
}
