package com.omega.engine.nn.layer;

import com.omega.common.data.Tensor;
import com.omega.engine.gpu.cudnn.RNNCudnnKernelV8;
import com.omega.engine.nn.layer.gpu.RNNBaseKernel;
import com.omega.engine.nn.network.Network;
import com.omega.engine.nn.network.RNN;

/* loaded from: input_file:com/omega/engine/nn/layer/RNNBlockLayer.class */
/**
 * Recurrent block layer whose forward and backward passes are delegated to a
 * {@link RNNCudnnKernelV8} kernel.
 *
 * <p>Parameters live in one flat {@link #weight} tensor sized from {@code kernel.weightSize()};
 * this class never allocates a separate {@code bias} tensor. NOTE(review): presumably the kernel
 * packs biases into the flat weight buffer when {@code hasBias} is set — confirm against
 * {@link RNNCudnnKernelV8}.
 *
 * <p>NOTE(review): hidden/cell state tensors are sized {@code batch * layerNum} rows and the
 * output width is {@code hiddenSize}, regardless of {@code bidirectional}. Bidirectional cuDNN
 * RNNs usually need {@code 2 * layerNum} state rows and {@code 2 * hiddenSize} output columns —
 * confirm the kernel accounts for this before enabling {@code bidirectional}.
 */
public class RNNBlockLayer extends BaseRNNLayer {
    /** Sequence length; mirrored into {@code kernel.seqLength} by {@link #init(int, int)}. */
    private int time = 0;
    private int inputSize;
    private int hiddenSize;
    private int layerNum = 1;
    private RNNBaseKernel kernel;
    /** Mode flag forwarded to the kernel (presumably selects the cell type — confirm). */
    private int rnnMode = 0;
    private boolean bidirectional = false;
    private float dropout = 0.0f;
    private Tensor hx;
    private Tensor cx;
    private Tensor hy;
    private Tensor cy;
    private Tensor dhx;
    private Tensor dcx;
    private Tensor dhy;
    private Tensor dcy;
    /** Row count of every hidden/cell state tensor: {@code (number / time) * layerNum}. */
    private int hiddenLen = 0;

    /**
     * Creates a single-layer block.
     *
     * @param time          sequence length
     * @param inputSize     width of each input row
     * @param hiddenSize    width of each hidden state / output row
     * @param rnnMode       mode flag forwarded to the kernel
     * @param bidirectional forwarded to the kernel
     * @param hasBias       forwarded to the kernel
     * @param dropout       forwarded to the kernel
     */
    public RNNBlockLayer(int time, int inputSize, int hiddenSize, int rnnMode,
                         boolean bidirectional, boolean hasBias, float dropout) {
        this(time, 1, inputSize, hiddenSize, rnnMode, bidirectional, hasBias, dropout);
    }

    /**
     * Creates a (possibly stacked) block.
     *
     * @param time          sequence length
     * @param layerNum      number of stacked recurrent layers
     * @param inputSize     width of each input row
     * @param hiddenSize    width of each hidden state / output row
     * @param rnnMode       mode flag forwarded to the kernel
     * @param bidirectional forwarded to the kernel
     * @param hasBias       forwarded to the kernel
     * @param dropout       forwarded to the kernel
     */
    public RNNBlockLayer(int time, int layerNum, int inputSize, int hiddenSize, int rnnMode,
                         boolean bidirectional, boolean hasBias, float dropout) {
        this.hasBias = hasBias;
        this.layerNum = layerNum;
        this.time = time;
        this.inputSize = inputSize;
        this.hiddenSize = hiddenSize;
        this.rnnMode = rnnMode;
        this.bidirectional = bidirectional;
        this.dropout = dropout;
        this.oChannel = 1;
        this.oHeight = 1;
        this.oWidth = hiddenSize;
        initKernel();
    }

    /**
     * Creates a (possibly stacked) block attached to {@code network}.
     *
     * @param network owning network; read by {@link #init()} and the forward/back passes
     */
    public RNNBlockLayer(int time, int layerNum, int inputSize, int hiddenSize, int rnnMode,
                         boolean bidirectional, boolean hasBias, float dropout, Network network) {
        this(time, layerNum, inputSize, hiddenSize, rnnMode, bidirectional, hasBias, dropout);
        this.network = network;
    }

    /** Lazily constructs the kernel from the current configuration fields. */
    public void initKernel() {
        if (this.kernel == null) {
            this.kernel = new RNNCudnnKernelV8(this.time, this.layerNum, this.inputSize, this.hiddenSize,
                    this.bidirectional, this.rnnMode, this.dropout, this.hasBias);
        }
    }

    /**
     * Sizes buffers from the owning network's row count and sequence length.
     * The owning network must be an {@link RNN}.
     */
    @Override // com.omega.engine.nn.layer.Layer
    public void init() {
        init(((RNN) this.network).time, this.network.number);
    }

    /**
     * Syncs the kernel's sequence length and (re)allocates the hidden/cell state and output
     * tensors when the row count changed.
     *
     * @param time   sequence length; must be non-zero (it is used as a divisor)
     * @param number total input rows (presumably {@code batch * time} — confirm with caller)
     */
    public void init(int time, int number) {
        this.number = number;
        this.time = time;
        if (this.kernel.seqLength != time) {
            this.kernel.seqLength = time;
        }
        this.hiddenLen = (number / time) * this.layerNum;
        setHx(ensureStateShape(getHx()));
        setCx(ensureStateShape(getCx()));
        setHy(ensureStateShape(getHy()));
        setCy(ensureStateShape(getCy()));
        if (this.output == null || this.output.number != number) {
            this.output = Tensor.createTensor(this.output, number, 1, 1, this.hiddenSize, true);
        }
    }

    /**
     * Returns {@code t} if it already has {@link #hiddenLen} rows, otherwise a tensor of shape
     * {@code (hiddenLen, 1, 1, hiddenSize)} obtained from {@link Tensor#createTensor}.
     */
    private Tensor ensureStateShape(Tensor t) {
        if (t == null || t.number != this.hiddenLen) {
            return Tensor.createTensor(t, this.hiddenLen, 1, 1, this.hiddenSize, true);
        }
        return t;
    }

    /** (Re)allocates the state-gradient tensors and the input-gradient {@link #diff}. */
    @Override // com.omega.engine.nn.layer.Layer
    public void initBack() {
        this.dhx = ensureStateShape(this.dhx);
        this.dhy = ensureStateShape(this.dhy);
        this.dcx = ensureStateShape(this.dcx);
        this.dcy = ensureStateShape(this.dcy);
        if (this.diff == null || this.diff.number != this.number) {
            // Go through the factory like every other buffer here, rather than orphaning the old tensor.
            this.diff = Tensor.createTensor(this.diff, this.number, 1, 1, this.inputSize, true);
        }
    }

    /**
     * Allocates the flat weight and weight-gradient tensors once and lets the kernel initialise
     * the weights. Assumes {@code kernel.weightSize()} is a byte count of 4-byte floats — confirm.
     */
    @Override // com.omega.engine.nn.layer.Layer
    public void initParam() {
        if (this.weight == null) {
            int weightSize = (int) (this.kernel.weightSize() / 4);
            this.weight = new Tensor(1, 1, 1, weightSize, true);
            this.diffW = new Tensor(1, 1, 1, weightSize, true);
            this.kernel.initWeights(this.weight);
        }
    }

    /** Runs the kernel forward pass using this layer's own initial hidden/cell state. */
    @Override // com.omega.engine.nn.layer.Layer
    public void output() {
        output(getHx(), getCx());
    }

    /**
     * Runs the kernel forward pass with caller-supplied initial states.
     *
     * @param hx initial hidden state passed to the kernel
     * @param cx initial cell state passed to the kernel
     */
    public void output(Tensor hx, Tensor cx) {
        this.kernel.init(this.input.number, this.time);
        initParam();
        this.kernel.forward(this.network.RUN_MODEL, this.input, hx, cx, this.weight, this.output, getHy(), getCy());
    }

    @Override // com.omega.engine.nn.layer.Layer
    public Tensor getOutput() {
        return this.output;
    }

    /** Computes input/state gradients and accumulates the weight gradient using the layer's own states. */
    @Override // com.omega.engine.nn.layer.Layer
    public void diff() {
        this.kernel.dx(this.delta, this.dhy, this.dcy, this.output, this.hx, this.cx, this.weight, this.diff, this.dhx, this.dcx);
        this.kernel.dw(this.delta, this.output, this.input, this.hx, this.diffW);
    }

    /**
     * Backward pass with caller-supplied states.
     *
     * @param hx  initial hidden state used in the forward pass
     * @param cx  initial cell state used in the forward pass
     * @param dhy incoming gradient for the final hidden state
     * @param dcy incoming gradient for the final cell state
     */
    public void diff(Tensor hx, Tensor cx, Tensor dhy, Tensor dcy) {
        this.kernel.dx(this.delta, dhy, dcy, this.output, hx, cx, this.weight, this.diff, this.dhx, this.dcx);
        this.kernel.dw(this.delta, this.output, this.input, hx, this.diffW);
    }

    @Override // com.omega.engine.nn.layer.Layer
    public void forward() {
        init();
        setInput();
        output();
    }

    @Override // com.omega.engine.nn.layer.Layer
    public void back() {
        initBack();
        setDelta();
        diff();
        gradientCheckIfEnabled();
    }

    @Override // com.omega.engine.nn.layer.Layer
    public void forward(Tensor input) {
        init();
        setInput(input);
        output();
    }

    @Override // com.omega.engine.nn.layer.Layer
    public void back(Tensor delta) {
        initBack();
        setDelta(delta);
        diff();
        gradientCheckIfEnabled();
    }

    /**
     * Backward pass with an explicit output gradient and caller-supplied states.
     *
     * @param delta gradient w.r.t. this layer's output
     * @param hx    initial hidden state used in the forward pass
     * @param cx    initial cell state used in the forward pass
     * @param dhy   incoming gradient for the final hidden state
     * @param dcy   incoming gradient for the final cell state
     */
    @Override // com.omega.engine.nn.layer.BaseRNNLayer
    public void back(Tensor delta, Tensor hx, Tensor cx, Tensor dhy, Tensor dcy) {
        initBack();
        setDelta(delta);
        diff(hx, cx, dhy, dcy);
        gradientCheckIfEnabled();
    }

    /** Runs {@link #gradientCheck()} when the owning network has gradient checking enabled. */
    private void gradientCheckIfEnabled() {
        if (this.network.GRADIENT_CHECK) {
            gradientCheck();
        }
    }

    /**
     * Applies the parameter update: delegates to {@code updater} when set, otherwise performs
     * plain SGD on the host-side arrays.
     */
    @Override // com.omega.engine.nn.layer.Layer
    public void update() {
        if (this.freeze) {
            return;
        }
        if (this.updater != null) {
            this.updater.update(this);
            return;
        }
        float[] w = this.weight.data;
        float[] dw = this.diffW.data;
        for (int i = 0; i < this.weight.getDataLength(); i++) {
            w[i] -= this.learnRate * dw[i];
        }
        // This layer never allocates bias/diffB (see initParam), so guard against NPE
        // when hasBias is set.
        if (this.hasBias && this.bias != null && this.diffB != null) {
            float[] b = this.bias.data;
            float[] db = this.diffB.data;
            for (int i = 0; i < this.bias.getDataLength(); i++) {
                b[i] -= this.learnRate * db[i];
            }
        }
    }

    @Override // com.omega.engine.nn.layer.Layer
    public void showDiff() {
    }

    @Override // com.omega.engine.nn.layer.Layer
    public LayerType getLayerType() {
        return LayerType.rnn;
    }

    /** Array-based forward path is not supported by this layer; always returns {@code null}. */
    @Override // com.omega.engine.nn.layer.Layer
    public float[][][][] output(float[][][][] input) {
        return null;
    }

    @Override // com.omega.engine.nn.layer.Layer
    public void initCache() {
    }

    @Override // com.omega.engine.nn.layer.Layer
    public void backTemp() {
    }

    /**
     * Sizes buffers for the given layout, then runs the forward pass on the current input.
     *
     * @param time   sequence length
     * @param number total input rows
     */
    @Override // com.omega.engine.nn.layer.BaseRNNLayer
    public void forward(int time, int number) {
        init(time, number);
        setInput();
        output();
    }

    /**
     * Forward pass with explicit input and initial states.
     *
     * @param input input tensor; its {@code number} drives buffer sizing
     * @param hx    initial hidden state
     * @param cx    initial cell state
     * @param time  sequence length
     */
    @Override // com.omega.engine.nn.layer.BaseRNNLayer
    public void forward(Tensor input, Tensor hx, Tensor cx, int time) {
        init(time, input.number);
        setInput(input);
        output(hx, cx);
    }

    @Override // com.omega.engine.nn.layer.BaseRNNLayer
    public Tensor getHx() {
        return this.hx;
    }

    public void setHx(Tensor hx) {
        this.hx = hx;
    }

    @Override // com.omega.engine.nn.layer.BaseRNNLayer
    public Tensor getCx() {
        return this.cx;
    }

    public void setCx(Tensor cx) {
        this.cx = cx;
    }

    @Override // com.omega.engine.nn.layer.BaseRNNLayer
    public Tensor getHy() {
        return this.hy;
    }

    public void setHy(Tensor hy) {
        this.hy = hy;
    }

    @Override // com.omega.engine.nn.layer.BaseRNNLayer
    public Tensor getCy() {
        return this.cy;
    }

    public void setCy(Tensor cy) {
        this.cy = cy;
    }

    @Override // com.omega.engine.nn.layer.BaseRNNLayer
    public Tensor getDhx() {
        return this.dhx;
    }

    @Override // com.omega.engine.nn.layer.BaseRNNLayer
    public Tensor getDcx() {
        return this.dcx;
    }

    @Override // com.omega.engine.nn.layer.BaseRNNLayer
    public Tensor getDhy() {
        return this.dhy;
    }

    @Override // com.omega.engine.nn.layer.BaseRNNLayer
    public Tensor getDcy() {
        return this.dcy;
    }
}
