package com.omega.engine.nn.layer;

import com.omega.common.data.Tensor;
import com.omega.common.utils.RandomUtils;
import com.omega.engine.gpu.cudnn.ConvTransposeCudnnKernel;
import com.omega.engine.nn.layer.gpu.BiasKernel;
import com.omega.engine.nn.layer.gpu.ConvBaseKernel;
import com.omega.engine.nn.model.ConvLayerInit;
import com.omega.engine.nn.model.LayerInit;
import com.omega.engine.nn.network.Network;

/* loaded from: input_file:com/omega/engine/nn/layer/ConvolutionTransposeLayer.class */
/**
 * Transposed 2-D convolution ("deconvolution") layer.
 *
 * <p>For an input of {@code channel x height x width} the layer produces an output of
 * {@code kernelNum x oHeight x oWidth}, where
 *
 * <pre>
 *   oWidth  = (width  - 1) * stride - 2 * padding + dilation * (kWidth  - 1) + output_padding + 1
 *   oHeight = (height - 1) * stride - 2 * padding + dilation * (kHeight - 1) + output_padding + 1
 * </pre>
 *
 * <p>The weight is allocated as {@code Tensor(kernelNum, channel, kHeight, kWidth)} and the bias
 * as {@code Tensor(1, 1, 1, kernelNum)}. The convolution is delegated to
 * {@link ConvTransposeCudnnKernel}. That kernel is only created when the owning network has
 * {@code CUDNN} enabled, because there is no other implementation. Calling {@link #output()} or
 * {@link #diff()} without it fails with an {@link IllegalStateException}.
 */
public class ConvolutionTransposeLayer extends Layer {
    /** Number of output channels (also used as {@code oChannel}). */
    public int kernelNum = 0;
    /** Kernel width. */
    public int kWidth = 0;
    /** Kernel height. */
    public int kHeight = 0;
    /** Convolution stride (same in both spatial directions). */
    public int stride = 1;
    /** Padding (same in both spatial directions; subtracted twice in the output-size formula). */
    public int padding = 0;
    /** Kernel dilation. */
    public int dilation = 1;
    /** Extra size added to one side of the output, as in the output-size formula. */
    public int output_padding = 0;

    /** Convolution implementation; null until {@link #init()} runs on a CUDNN network. */
    private ConvBaseKernel kernel;
    /** Bias add / bias gradient kernel; only created when {@code hasBias} is true. */
    private BiasKernel biasKernel;

    /**
     * Creates the layer without an owning network. {@code hasBias} keeps the value inherited
     * from {@link Layer}.
     *
     * <p>Note the argument order: {@code padding} comes before {@code stride}.
     *
     * @param channel       number of input channels
     * @param kernelNum     number of output channels
     * @param width         input width
     * @param height        input height
     * @param kWidth        kernel width
     * @param kHeight       kernel height
     * @param padding       padding
     * @param stride        stride
     * @param dilation      dilation
     * @param outputPadding output padding
     */
    public ConvolutionTransposeLayer(int channel, int kernelNum, int width, int height, int kWidth, int kHeight,
                                     int padding, int stride, int dilation, int outputPadding) {
        setGeometry(channel, kernelNum, width, height, kWidth, kHeight, padding, stride, dilation, outputPadding);
        this.hasParams = true;
        initParam();
    }

    /**
     * Creates the layer without an owning network, with an explicit bias flag.
     * Arguments are as in the 10-argument constructor.
     *
     * @param hasBias whether a bias is added in the forward pass and its gradient computed
     */
    public ConvolutionTransposeLayer(int channel, int kernelNum, int width, int height, int kWidth, int kHeight,
                                     int padding, int stride, int dilation, int outputPadding, boolean hasBias) {
        setGeometry(channel, kernelNum, width, height, kWidth, kHeight, padding, stride, dilation, outputPadding);
        this.hasBias = hasBias;
        this.hasParams = true;
        initParam();
    }

    /**
     * Creates the layer attached to {@code network}.
     * Arguments are as in the 10-argument constructor.
     *
     * @param hasBias whether a bias is added in the forward pass and its gradient computed
     * @param network owning network. When non-null, the gradient tensors are obtained from
     *                {@code network.createParamterGrad} instead of being allocated directly.
     */
    public ConvolutionTransposeLayer(int channel, int kernelNum, int width, int height, int kWidth, int kHeight,
                                     int padding, int stride, int dilation, int outputPadding, boolean hasBias,
                                     Network network) {
        setGeometry(channel, kernelNum, width, height, kWidth, kHeight, padding, stride, dilation, outputPadding);
        this.hasBias = hasBias;
        this.network = network;
        this.hasParams = true;
        initParam();
    }

    /** Stores the input geometry and convolution hyper-parameters shared by all constructors. */
    private void setGeometry(int channel, int kernelNum, int width, int height, int kWidth, int kHeight,
                             int padding, int stride, int dilation, int outputPadding) {
        this.kernelNum = kernelNum;
        this.channel = channel;
        this.width = width;
        this.height = height;
        this.kWidth = kWidth;
        this.kHeight = kHeight;
        this.padding = padding;
        this.stride = stride;
        this.dilation = dilation;
        this.output_padding = outputPadding;
    }

    /**
     * Computes the output shape and allocates the parameters and their gradients.
     *
     * <p>The weight is initialised with {@code RandomUtils.kaiming_uniform}, using
     * {@code channel * kHeight * kWidth} as the fan argument. The bias is initialised with
     * {@code RandomUtils.kaimingUniformBias}. The bias tensor is allocated even when
     * {@code hasBias} is false.
     */
    @Override
    public void initParam() {
        int weightSize = this.kernelNum * this.channel * this.kHeight * this.kWidth;
        int fan = this.channel * this.kHeight * this.kWidth;
        this.oChannel = this.kernelNum;
        this.oWidth = (this.width - 1) * this.stride - 2 * this.padding
                + this.dilation * (this.kWidth - 1) + this.output_padding + 1;
        this.oHeight = (this.height - 1) * this.stride - 2 * this.padding
                + this.dilation * (this.kHeight - 1) + this.output_padding + 1;
        this.weight = new Tensor(this.kernelNum, this.channel, this.kHeight, this.kWidth,
                RandomUtils.kaiming_uniform(weightSize, fan, this.paramsInit), true);
        this.bias = new Tensor(1, 1, 1, this.kernelNum, RandomUtils.kaimingUniformBias(this.kernelNum, fan), true);
        if (this.network != null) {
            this.diffB = this.network.createParamterGrad(1, 1, 1, this.kernelNum, true);
            this.diffW = this.network.createParamterGrad(this.kernelNum, this.channel, this.kHeight, this.kWidth, true);
        } else {
            this.diffB = new Tensor(1, 1, 1, this.kernelNum, true);
            this.diffW = new Tensor(this.kernelNum, this.channel, this.kHeight, this.kWidth, true);
        }
    }

    /**
     * Prepares the layer for a forward pass. It does the following:
     * <ul>
     *   <li>takes the batch size from the network;</li>
     *   <li>(re)creates the output tensor when it is missing or the batch size changed;</li>
     *   <li>on the first call, creates the kernels. The convolution kernel exists only when
     *       {@code network.CUDNN} is true, and the bias kernel only when {@code hasBias} is true.</li>
     * </ul>
     */
    @Override
    public void init() {
        this.number = this.network.number;
        if (this.output == null || this.number != this.output.number) {
            this.output = Tensor.createTensor(this.output, this.number, this.oChannel, this.oHeight, this.oWidth, true);
        }
        if (this.kernel == null) {
            if (this.network.CUDNN) {
                this.kernel = new ConvTransposeCudnnKernel(this.network, this.channel, this.height, this.width,
                        this.kernelNum, this.kHeight, this.kWidth, this.stride, this.padding, this.dilation,
                        this.output_padding);
            }
            if (this.hasBias) {
                this.biasKernel = new BiasKernel();
            }
        }
    }

    /**
     * Allocates the input-gradient tensor with the input geometry
     * ({@code number x channel x height x width}). It is reallocated only when the batch size
     * changed.
     */
    @Override
    public void initBack() {
        if (this.diff == null || this.number != this.diff.number) {
            this.diff = new Tensor(this.number, this.channel, this.height, this.width, true);
        }
    }

    /** Runs the transposed convolution into {@code output} and, if enabled, adds the bias. */
    @Override
    public void output() {
        requireKernel().convTranspose(this.input, this.weight, this.output);
        if (this.hasBias) {
            this.biasKernel.addConvBias(this.output, this.bias);
        }
    }

    /**
     * Computes the gradients from {@code delta}. It does the following:
     * <ul>
     *   <li>computes the weight gradient into {@code diffW};</li>
     *   <li>computes the bias gradient into {@code diffB}, when {@code hasBias} is true;</li>
     *   <li>computes the input gradient into {@code diff}, only when this layer or the network
     *       requests propagation down.</li>
     * </ul>
     */
    @Override
    public void diff() {
        ConvBaseKernel k = requireKernel();
        k.dw(this.input, this.delta, this.diffW);
        if (this.hasBias) {
            this.biasKernel.backwardConvBias(this.diffB, this.delta);
        }
        if (this.PROPAGATE_DOWN || this.network.PROPAGATE_DOWN) {
            k.dx(this.delta, this.weight, this.diff);
        }
    }

    /**
     * Returns the convolution kernel. If it was never created, this throws a descriptive error
     * instead of a bare NullPointerException.
     *
     * @throws IllegalStateException if {@link #init()} has not created a kernel
     *                               (e.g. the network does not have CUDNN enabled)
     */
    private ConvBaseKernel requireKernel() {
        if (this.kernel == null) {
            throw new IllegalStateException(
                    "ConvolutionTransposeLayer has no convolution kernel: init() must run on a network with CUDNN"
                            + " enabled (only a CUDNN implementation is available)");
        }
        return this.kernel;
    }

    /** Forward pass using the input supplied by {@code setInput()}. */
    @Override
    public void forward() {
        init();
        setInput();
        output();
    }

    /** Backward pass using the delta supplied by {@code setDelta()}; runs a gradient check if enabled. */
    @Override
    public void back() {
        initBack();
        setDelta();
        diff();
        if (this.network.GRADIENT_CHECK) {
            gradientCheck();
        }
    }

    /**
     * Applies the parameter update. It does nothing when the layer is frozen. If an updater is
     * configured, it delegates to that updater. Otherwise it performs plain SGD in place:
     * {@code w -= learnRate * dW} and {@code b -= learnRate * dB}. The bias is updated regardless
     * of {@code hasBias}.
     *
     * <p>NOTE(review): the fallback path works on {@code Tensor.data}. If parameters live on
     * the GPU, confirm that the host copies are current here.
     */
    @Override
    public void update() {
        if (this.freeze) {
            return;
        }
        if (this.updater != null) {
            this.updater.update(this);
            return;
        }
        float[] w = this.weight.data;
        float[] dw = this.diffW.data;
        for (int i = 0, n = this.weight.getDataLength(); i < n; i++) {
            w[i] -= this.learnRate * dw[i];
        }
        float[] b = this.bias.data;
        float[] db = this.diffB.data;
        for (int i = 0, n = this.bias.getDataLength(); i < n; i++) {
            b[i] -= this.learnRate * db[i];
        }
    }

    @Override
    public LayerType getLayerType() {
        return LayerType.conv_transpose;
    }

    @Override
    public Tensor getOutput() {
        return this.output;
    }

    /** No-op for this layer. */
    @Override
    public void showDiff() {
    }

    /** Serialises the layer configuration as a {@link ConvLayerInit}. */
    @Override
    public LayerInit save() {
        return new ConvLayerInit(this);
    }

    /** The array-based forward path is not implemented for this layer; always returns null. */
    @Override
    public float[][][][] output(float[][][][] input) {
        return null;
    }

    /** No-op for this layer. */
    @Override
    public void initCache() {
    }

    /** Forward pass on an explicitly supplied input tensor. */
    @Override
    public void forward(Tensor input) {
        init();
        setInput(input);
        output();
    }

    /** Backward pass on an explicitly supplied delta tensor; runs a gradient check if enabled. */
    @Override
    public void back(Tensor delta) {
        initBack();
        setDelta(delta);
        diff();
        if (this.network.GRADIENT_CHECK) {
            gradientCheck();
        }
    }

    /** No-op for this layer. */
    @Override
    public void backTemp() {
    }
}
