package com.omega.engine.nn.layer;

import com.omega.common.data.Tensor;
import com.omega.engine.active.ActiveType;
import com.omega.engine.gpu.BaseKernel;
import com.omega.engine.nn.layer.active.ActiveFunctionLayer;
import com.omega.engine.nn.layer.active.LeakyReluLayer;
import com.omega.engine.nn.layer.active.ReluLayer;
import com.omega.engine.nn.layer.active.SiLULayer;
import com.omega.engine.nn.layer.active.SigmodLayer;
import com.omega.engine.nn.layer.active.TanhLayer;
import com.omega.engine.nn.layer.normalization.BNLayer;
import com.omega.engine.nn.network.Network;
import com.omega.engine.updater.UpdaterFactory;

/* loaded from: input_file:com/omega/engine/nn/layer/CBLLayer.class */
/**
 * Composite layer that runs a {@link ConvolutionLayer}, a {@link BNLayer} and an
 * {@link ActiveFunctionLayer} back to back.
 *
 * <p>The forward pass feeds this layer's input through the convolution, then the BN layer,
 * then the activation; the backward pass walks the same three sub-layers in reverse order.
 * Only the convolution and BN sub-layers are asked to update in {@link #update()}.
 */
public class CBLLayer extends Layer {
    /** Kernel height handed to the convolution sub-layer. */
    private final int kHeight;
    /** Kernel width handed to the convolution sub-layer. */
    private final int kWidth;
    /** Padding used in the output-size formula and handed to the convolution sub-layer. */
    private final int padding;
    /** Stride used in the output-size formula and handed to the convolution sub-layer. */
    private final int stride;
    private ConvolutionLayer convLayer;
    private BNLayer bnLayer;
    private ActiveFunctionLayer activeLayer;
    /** Selects which activation sub-layer {@link #initLayers()} builds. */
    private final ActiveType activeType;
    /** Used by {@link #diff()} to copy the incoming delta into {@code org_delta}. */
    private BaseKernel baseKernel;

    /**
     * Creates the layer, computes its output size and builds the three sub-layers.
     *
     * <p>The output size is {@code (size + 2 * padding - kernel) / stride + 1} using integer
     * division, computed separately for height and width.
     *
     * @param channel    number of input channels
     * @param oChannel   number of output channels
     * @param height     input height
     * @param width      input width
     * @param kHeight    kernel height
     * @param kWidth     kernel width
     * @param stride     stride; must be positive because it is used as a divisor
     * @param padding    padding added on each side in the output-size formula
     * @param activeType activation to append after the BN layer
     * @param network    owning network; its updater settings are read by {@link #initLayers()}
     * @throws IllegalArgumentException if {@code stride} is zero or negative
     * @throws RuntimeException         if {@code activeType} is not one of the supported activations
     */
    public CBLLayer(int channel, int oChannel, int height, int width, int kHeight, int kWidth, int stride, int padding, ActiveType activeType, Network network) {
        // The stride is a divisor below; reject values that would divide by zero or
        // silently yield meaningless output dimensions.
        if (stride <= 0) {
            throw new IllegalArgumentException("stride must be positive, got " + stride);
        }
        this.network = network;
        this.activeType = activeType;
        this.channel = channel;
        this.oChannel = oChannel;
        this.height = height;
        this.width = width;
        this.kHeight = kHeight;
        this.kWidth = kWidth;
        this.stride = stride;
        this.padding = padding;
        this.oHeight = (((height + (padding * 2)) - kHeight) / stride) + 1;
        this.oWidth = (((width + (padding * 2)) - kWidth) / stride) + 1;
        initLayers();
    }

    /**
     * Same as the {@link ActiveType} overload, with the activation given by its enum constant name.
     *
     * @param channel    number of input channels
     * @param oChannel   number of output channels
     * @param height     input height
     * @param width      input width
     * @param kHeight    kernel height
     * @param kWidth     kernel width
     * @param stride     stride; must be positive because it is used as a divisor
     * @param padding    padding added on each side in the output-size formula
     * @param activeType name resolved through {@code ActiveType.valueOf}
     * @param network    owning network; its updater settings are read by {@link #initLayers()}
     * @throws IllegalArgumentException if the name matches no {@link ActiveType} constant, or if
     *                                  {@code stride} is zero or negative
     * @throws RuntimeException         if the resolved activation is not supported by this layer
     */
    public CBLLayer(int channel, int oChannel, int height, int width, int kHeight, int kWidth, int stride, int padding, String activeType, Network network) {
        this(channel, oChannel, height, width, kHeight, kWidth, stride, padding, ActiveType.valueOf(activeType), network);
    }

    /**
     * Builds the convolution, BN and activation sub-layers and lazily creates the kernel helper.
     *
     * <p>The convolution receives an updater created from the network's {@code updater} and
     * {@code updaterParams}. The BN layer is constructed from the convolution, and the
     * activation layer is constructed from the BN layer.
     *
     * @throws RuntimeException if the configured activation type is not supported
     */
    public void initLayers() {
        // NOTE(review): argument order (width before height, kHeight before kWidth, padding
        // before stride) is kept exactly as found; confirm against ConvolutionLayer's signature.
        this.convLayer = new ConvolutionLayer(this.channel, this.oChannel, this.width, this.height, this.kHeight, this.kWidth, this.padding, this.stride, false, this.network, this.activeType);
        this.convLayer.setUpdater(UpdaterFactory.create(this.network.updater, this.network.updaterParams));
        this.bnLayer = new BNLayer(this.convLayer);
        this.activeLayer = createActiveLayer();
        if (this.baseKernel == null) {
            this.baseKernel = new BaseKernel();
        }
    }

    /** Instantiates the activation sub-layer matching {@code activeType}, chained after the BN layer. */
    private ActiveFunctionLayer createActiveLayer() {
        switch (this.activeType) {
            case sigmoid:
                return new SigmodLayer(this.bnLayer);
            case relu:
                return new ReluLayer(this.bnLayer);
            case leaky_relu:
                return new LeakyReluLayer(this.bnLayer);
            case tanh:
                return new TanhLayer(this.bnLayer);
            case silu:
                return new SiLULayer(this.bnLayer);
            default:
                throw new RuntimeException("The cbl layer is not support the [" + this.activeType + "] active function.");
        }
    }

    /** Refreshes this layer's {@code number} from the owning network before a forward pass. */
    @Override
    public void init() {
        this.number = this.network.number;
    }

    /**
     * Ensures {@code org_delta} exists and matches the current output's {@code number}.
     *
     * <p>Reads {@code this.output}, so it expects a forward pass to have populated it first.
     */
    @Override
    public void initBack() {
        if (this.org_delta == null || this.output.number != this.org_delta.number) {
            // NOTE(review): the meaning of the trailing 'true' flag is defined by Tensor.createTensor.
            this.org_delta = Tensor.createTensor(this.org_delta, this.number, this.output.channel, this.output.height, this.output.width, true);
        }
    }

    /** No parameters are initialised at this level; nothing to do here. */
    @Override
    public void initParam() {
    }

    /** Runs convolution, then BN, then activation, and exposes the activation's output as this layer's output. */
    @Override
    public void output() {
        this.convLayer.forward(this.input);
        this.bnLayer.forward(this.convLayer.output);
        this.activeLayer.forward(this.bnLayer.output);
        this.output = this.activeLayer.output;
    }

    /** @return the output tensor set by the most recent {@link #output()} call */
    @Override
    public Tensor getOutput() {
        return this.output;
    }

    /**
     * Back-propagates through activation, BN and convolution, in that order.
     *
     * <p>The incoming delta is first copied into {@code org_delta}, and the copy is what the
     * activation layer receives. The convolution's {@code diff} becomes this layer's {@code diff}.
     */
    @Override
    public void diff() {
        // NOTE(review): the trailing (1, 1) arguments are passed through unchanged; their
        // meaning is defined by BaseKernel.copy_gpu.
        this.baseKernel.copy_gpu(this.delta, this.org_delta, this.delta.getDataLength(), 1, 1);
        this.activeLayer.back(this.org_delta);
        this.bnLayer.back(this.activeLayer.diff);
        this.convLayer.back(this.bnLayer.diff);
        this.diff = this.convLayer.diff;
    }

    /** Forward pass using the input resolved by the inherited {@code setInput()}. */
    @Override
    public void forward() {
        init();
        setInput();
        output();
    }

    /** Backward pass using the delta resolved by the inherited {@code setDelta()}. */
    @Override
    public void back() {
        initBack();
        setDelta();
        diff();
    }

    /** Intentionally empty. */
    @Override
    public void backTemp() {
    }

    /**
     * Forward pass with an explicitly supplied input.
     *
     * @param tensor input passed to the inherited {@code setInput(Tensor)}
     */
    @Override
    public void forward(Tensor tensor) {
        init();
        setInput(tensor);
        output();
    }

    /**
     * Backward pass with an explicitly supplied delta.
     *
     * @param tensor delta passed to the inherited {@code setDelta(Tensor)}
     */
    @Override
    public void back(Tensor tensor) {
        initBack();
        setDelta(tensor);
        diff();
    }

    /** Delegates the update to the convolution and BN sub-layers; the activation layer is not updated. */
    @Override
    public void update() {
        this.convLayer.update();
        this.bnLayer.update();
    }

    /** Intentionally empty. */
    @Override
    public void showDiff() {
    }

    /** @return {@link LayerType#cbl} */
    @Override
    public LayerType getLayerType() {
        return LayerType.cbl;
    }

    /**
     * Array-based forward path; not implemented for this layer.
     *
     * @param fArr ignored
     * @return always {@code null}
     */
    @Override
    public float[][][][] output(float[][][][] fArr) {
        return null;
    }

    /** Intentionally empty. */
    @Override
    public void initCache() {
    }

    /** @return the convolution sub-layer */
    public ConvolutionLayer getConvLayer() {
        return this.convLayer;
    }

    /** @return the BN sub-layer */
    public BNLayer getBnLayer() {
        return this.bnLayer;
    }
}
