package com.omega.engine.nn.layer;

import com.omega.common.data.Tensor;
import com.omega.engine.active.ActiveType;
import com.omega.engine.ad.op.TensorOP;
import com.omega.engine.gpu.BaseKernel;
import com.omega.engine.nn.layer.active.ActiveFunctionLayer;
import com.omega.engine.nn.layer.active.LeakyReluLayer;
import com.omega.engine.nn.layer.active.ReluLayer;
import com.omega.engine.nn.layer.active.SigmodLayer;
import com.omega.engine.nn.layer.active.TanhLayer;
import com.omega.engine.nn.network.Network;
import com.omega.engine.nn.network.RNN;

/* loaded from: input_file:com/omega/engine/nn/layer/LSTMLayer.class */
/**
 * LSTM layer unrolled over {@code time} steps.
 *
 * <p>Each of the four gates owns one projection of the layer input ({@code *xl}) and one
 * projection of the previous hidden state ({@code *hl}). As implemented by {@link #output()}:
 *
 * <pre>
 *   f_t = sigmoid(fxl(x_t) + fhl(h_{t-1}))
 *   i_t = sigmoid(ixl(x_t) + ihl(h_{t-1}))
 *   g_t = tanh   (gxl(x_t) + ghl(h_{t-1}))
 *   o_t = sigmoid(oxl(x_t) + ohl(h_{t-1}))
 *   c_t = f_t * c_{t-1} + i_t * g_t
 *   h_t = o_t * tanh(c_t)
 * </pre>
 *
 * <p>All per-step tensors are allocated with {@code number} rows and are addressed in slices of
 * {@code (number / time) * width} elements, one slice per time step.
 * NOTE(review): the offset/length arguments of {@code TensorOP} are assumed to be element
 * offsets into the tensors' data, as their use in {@link #output()} suggests - confirm against
 * {@code TensorOP}.
 */
public class LSTMLayer extends Layer {
    /** Number of unrolled time steps; re-read from the owning network in {@link #init()}. */
    private int time;
    /** Width of one input row. */
    private int inputSize;
    /** Width of the hidden and cell state. */
    private int hiddenSize;
    /** Whether the input projections are created with a bias term. */
    private boolean bias;

    // Input -> gate projections.
    private FullyLayer fxl;
    private FullyLayer ixl;
    private FullyLayer gxl;
    private FullyLayer oxl;

    // Hidden -> gate projections (always created without bias).
    private FullyLayer fhl;
    private FullyLayer ihl;
    private FullyLayer ghl;
    private FullyLayer ohl;

    // Gate activations (fa/ia/oa sigmoid, ga tanh) and the tanh applied to the cell state (ha).
    private ActiveFunctionLayer fa;
    private ActiveFunctionLayer ia;
    private ActiveFunctionLayer ga;
    private ActiveFunctionLayer oa;
    private ActiveFunctionLayer ha;

    // Pre-activation gate sums, cell state, hidden state and a scratch buffer.
    private Tensor f;
    private Tensor i;
    private Tensor g;
    private Tensor c;
    private Tensor o;
    private Tensor h;
    private Tensor temp;

    // Backward buffers: gradients carried to the previous step's hidden / cell state.
    private Tensor h_diff;
    private Tensor c_diff;
    /** One-step scratch buffer; after the tanh derivative is applied it holds dL/dc_t. */
    private Tensor detlaXo;
    /** One-step scratch buffer holding 1 - tanh(c_t)^2. */
    private Tensor d_tanhc;
    private BaseKernel baseKernel;

    /**
     * Creates an LSTM layer that is not yet attached to a network.
     *
     * @param inputSize  width of one input row
     * @param hiddenSize width of the hidden state
     * @param time       number of time steps
     * @param bias       whether the input projections use a bias
     */
    public LSTMLayer(int inputSize, int hiddenSize, int time, boolean bias) {
        this.time = time;
        this.inputSize = inputSize;
        this.hiddenSize = hiddenSize;
        this.bias = bias;
        initLayers();
    }

    /**
     * Creates an LSTM layer attached to {@code network}.
     *
     * @param inputSize  width of one input row
     * @param hiddenSize width of the hidden state
     * @param time       number of time steps
     * @param bias       whether the input projections use a bias
     * @param network    owning network, passed on to every sub-layer
     */
    public LSTMLayer(int inputSize, int hiddenSize, int time, boolean bias, Network network) {
        this.network = network;
        this.time = time;
        this.inputSize = inputSize;
        this.hiddenSize = hiddenSize;
        this.bias = bias;
        initLayers();
    }

    /** Builds the eight projection layers, the five activation layers and the GPU kernel helper. */
    public void initLayers() {
        this.fxl = FullyLayer.createRNNCell(this.inputSize, this.hiddenSize, this.time, this.bias, this.network);
        this.ixl = FullyLayer.createRNNCell(this.inputSize, this.hiddenSize, this.time, this.bias, this.network);
        this.gxl = FullyLayer.createRNNCell(this.inputSize, this.hiddenSize, this.time, this.bias, this.network);
        this.oxl = FullyLayer.createRNNCell(this.inputSize, this.hiddenSize, this.time, this.bias, this.network);

        this.fhl = FullyLayer.createRNNCell(this.hiddenSize, this.hiddenSize, this.time, false, this.network);
        this.ihl = FullyLayer.createRNNCell(this.hiddenSize, this.hiddenSize, this.time, false, this.network);
        this.ghl = FullyLayer.createRNNCell(this.hiddenSize, this.hiddenSize, this.time, false, this.network);
        this.ohl = FullyLayer.createRNNCell(this.hiddenSize, this.hiddenSize, this.time, false, this.network);

        this.fa = createActiveLayer(ActiveType.sigmoid, this.fhl);
        this.ia = createActiveLayer(ActiveType.sigmoid, this.ihl);
        this.ga = createActiveLayer(ActiveType.tanh, this.ghl);
        this.oa = createActiveLayer(ActiveType.sigmoid, this.ohl);
        // The cell-state tanh is built on top of fhl as well.
        this.ha = createActiveLayer(ActiveType.tanh, this.fhl);

        if (this.baseKernel == null) {
            this.baseKernel = new BaseKernel();
        }
    }

    /**
     * Creates the activation layer matching {@code activeType} on top of {@code layer}.
     *
     * @param activeType requested activation
     * @param layer      layer handed to the activation layer's constructor
     * @return the new activation layer
     * @throws RuntimeException if {@code activeType} is not sigmoid, relu, leaky_relu or tanh
     */
    public ActiveFunctionLayer createActiveLayer(ActiveType activeType, Layer layer) {
        switch (activeType) {
            case sigmoid:
                return new SigmodLayer(layer);
            case relu:
                return new ReluLayer(layer);
            case leaky_relu:
                return new LeakyReluLayer(layer);
            case tanh:
                return new TanhLayer(layer);
            default:
                throw new RuntimeException("The rnn layer is not support the [" + activeType + "] active function.");
        }
    }

    /**
     * Reads the row count and time-step count from the network and (re)allocates the forward
     * buffers whenever the row count changed. The network must be an {@link RNN}.
     */
    @Override // com.omega.engine.nn.layer.Layer
    public void init() {
        this.number = this.network.number;
        this.time = ((RNN) this.network).time;
        if (this.h == null || this.h.number != this.number) {
            this.f = Tensor.createTensor(this.f, this.number, 1, 1, this.hiddenSize, true);
            this.i = Tensor.createTensor(this.i, this.number, 1, 1, this.hiddenSize, true);
            this.g = Tensor.createTensor(this.g, this.number, 1, 1, this.hiddenSize, true);
            this.c = Tensor.createTensor(this.c, this.number, 1, 1, this.hiddenSize, true);
            this.o = Tensor.createTensor(this.o, this.number, 1, 1, this.hiddenSize, true);
            this.h = Tensor.createTensor(this.h, this.number, 1, 1, this.hiddenSize, true);
            this.temp = Tensor.createTensor(this.temp, this.number, 1, 1, this.hiddenSize, true);
        }
    }

    /** (Re)allocates the backward buffers whenever the row count changed. */
    @Override // com.omega.engine.nn.layer.Layer
    public void initBack() {
        int batch = this.number / this.time;
        // One-step scratch buffers.
        if (this.detlaXo == null || this.detlaXo.number != batch) {
            this.detlaXo = Tensor.createTensor(this.detlaXo, batch, 1, 1, this.hiddenSize, true);
            this.d_tanhc = Tensor.createTensor(this.d_tanhc, batch, 1, 1, this.hiddenSize, true);
        }
        // Full-sequence buffers for the gradients carried between steps.
        if (this.h_diff == null || this.h_diff.number != this.number) {
            this.h_diff = Tensor.createTensor(this.h_diff, this.number, 1, 1, this.hiddenSize, true);
            this.c_diff = Tensor.createTensor(this.c_diff, this.number, 1, 1, this.hiddenSize, true);
        }
        // Gradient with respect to the layer input: inputSize wide, not hiddenSize.
        if (this.diff == null || this.diff.number != this.number) {
            this.diff = Tensor.createTensor(this.diff, this.number, 1, 1, this.inputSize, true);
        }
    }

    @Override // com.omega.engine.nn.layer.Layer
    public void initParam() {
    }

    /**
     * Forward pass over all time steps (see the class comment for the equations). The cell state
     * is cleared first. When there is no input the loop is skipped and the previous contents of
     * {@code h} are published as the output.
     */
    @Override // com.omega.engine.nn.layer.Layer
    public void output() {
        int batch = this.number / this.time;
        int stepSize = batch * this.h.getOnceSize();
        if (this.input != null) {
            this.c.clearGPU();
            for (int t = 0; t < this.time; t++) {
                int offset = t * stepSize;

                // Input projections for step t.
                this.fxl.forward(this.input, batch, t);
                this.ixl.forward(this.input, batch, t);
                this.gxl.forward(this.input, batch, t);
                this.oxl.forward(this.input, batch, t);

                // Hidden projections read step t - 1 of h and write step t.
                // NOTE(review): at t == 0 the source step is -1; FullyLayer is presumably
                // expected to handle that - confirm.
                this.fhl.forward(this.h, batch, t - 1, t);
                this.ihl.forward(this.h, batch, t - 1, t);
                this.ghl.forward(this.h, batch, t - 1, t);
                this.ohl.forward(this.h, batch, t - 1, t);

                // Pre-activation gate values.
                TensorOP.add(this.fxl.getOutput(), this.fhl.getOutput(), this.f, offset, stepSize);
                TensorOP.add(this.ixl.getOutput(), this.ihl.getOutput(), this.i, offset, stepSize);
                TensorOP.add(this.gxl.getOutput(), this.ghl.getOutput(), this.g, offset, stepSize);
                TensorOP.add(this.oxl.getOutput(), this.ohl.getOutput(), this.o, offset, stepSize);

                this.fa.forward(this.f, batch, t);
                this.ia.forward(this.i, batch, t);
                this.ga.forward(this.g, batch, t);
                this.oa.forward(this.o, batch, t);

                // c_t = f_t * c_{t-1} + i_t * g_t; at t == 0 the slice is still zero from clearGPU().
                TensorOP.mul(this.ia.getOutput(), this.ga.getOutput(), this.temp, offset, stepSize);
                if (t > 0) {
                    TensorOP.mul(this.c, this.fa.getOutput(), this.c, (t - 1) * stepSize, offset, offset, stepSize);
                }
                TensorOP.add(this.temp, this.c, this.c, offset, stepSize);

                // h_t = o_t * tanh(c_t)
                this.ha.forward(this.c, batch, t);
                TensorOP.mul(this.oa.getOutput(), this.ha.getOutput(), this.h, offset, stepSize);
            }
        }
        this.output = this.h;
    }

    @Override // com.omega.engine.nn.layer.Layer
    public Tensor getOutput() {
        return this.output;
    }

    /**
     * Backward pass through time, from the last step to the first.
     *
     * <p>For each step t, with dh_t = delta_t (+ the hidden gradient carried from step t + 1):
     *
     * <pre>
     *   dc_t = dh_t * o_t * (1 - tanh(c_t)^2) + dc_{t+1} * f_{t+1}
     *   do_t = dh_t * tanh(c_t)
     *   df_t = dc_t * c_{t-1}      (zero at t == 0)
     *   di_t = dc_t * g_t
     *   dg_t = dc_t * i_t
     * </pre>
     *
     * Each gate gradient is passed through its activation layer and then through both of that
     * gate's projections. The hidden-projection gradients are summed into {@code h_diff} for
     * step t - 1, and the input-projection gradients are summed into {@code diff} for step t.
     * {@code delta} is modified in place while the carried hidden gradient is added.
     */
    @Override // com.omega.engine.nn.layer.Layer
    public void diff() {
        final int batch = this.number / this.time;
        final int hiddenStep = batch * this.hiddenSize;
        // The input gradient is inputSize wide, so its slices need their own step size.
        final int inputStep = batch * this.inputSize;
        final int lastStep = this.time - 1;

        this.fxl.clear();
        this.ixl.clear();
        this.gxl.clear();
        this.oxl.clear();
        this.fhl.clear();
        this.ihl.clear();
        this.ghl.clear();
        this.ohl.clear();
        this.h_diff.clearGPU();
        this.c_diff.clearGPU();

        for (int t = lastStep; t >= 0; t--) {
            final boolean hasNext = t < lastStep;
            final boolean hasPrev = t > 0;
            final int cur = t * hiddenStep;
            final int prev = (t - 1) * hiddenStep; // only used when hasPrev

            // dh_t: add the hidden gradient carried from step t + 1 into delta.
            // NOTE(review): assumes axpy_gpu follows the BLAS convention y += alpha * x with
            // x = h_diff and y = delta - confirm against BaseKernel.
            if (hasNext) {
                this.baseKernel.axpy_gpu(this.h_diff, this.delta, hiddenStep, 1.0f, cur, 1, cur, 1);
            }

            // detlaXo = dh_t * o_t * (1 - tanh(c_t)^2)
            TensorOP.mul(this.delta, this.oa.getOutput(), this.detlaXo, cur, cur, 0, hiddenStep);
            TensorOP.mul(this.ha.getOutput(), this.ha.getOutput(), this.d_tanhc, cur, cur, 0, hiddenStep);
            TensorOP.sub(1.0f, this.d_tanhc, this.d_tanhc, 0, hiddenStep);
            TensorOP.mul(this.detlaXo, this.d_tanhc, this.detlaXo, 0, hiddenStep);

            // Add the cell gradient carried from step t + 1 BEFORE passing it on, so that the
            // carry reaches every earlier step.
            if (hasNext) {
                TensorOP.add(this.detlaXo, this.c_diff, this.detlaXo, 0, cur, 0, hiddenStep);
            }
            // Carry for step t - 1: dc_t * f_t. There is no step -1 to write to.
            if (hasPrev) {
                TensorOP.mul(this.detlaXo, this.fa.getOutput(), this.c_diff, 0, cur, prev, hiddenStep);
            }

            // Output gate: do_t = dh_t * tanh(c_t)
            TensorOP.mul(this.delta, this.ha.getOutput(), this.temp, cur, hiddenStep);
            this.oa.back(this.temp, batch, t);

            // Forget gate: df_t = dc_t * c_{t-1}. At t == 0 the previous cell state is zero, so
            // the gate receives a zero gradient instead of a read at a negative offset. temp is
            // pure scratch here and t == 0 is the final iteration, so clearing all of it is safe.
            if (hasPrev) {
                TensorOP.mul(this.detlaXo, this.c, this.temp, 0, prev, cur, hiddenStep);
            } else {
                this.temp.clearGPU();
            }
            this.fa.back(this.temp, batch, t);

            // Input gate: di_t = dc_t * g_t
            TensorOP.mul(this.detlaXo, this.ga.getOutput(), this.temp, 0, cur, cur, hiddenStep);
            this.ia.back(this.temp, batch, t);

            // Candidate: dg_t = dc_t * i_t
            TensorOP.mul(this.detlaXo, this.ia.getOutput(), this.temp, 0, cur, cur, hiddenStep);
            this.ga.back(this.temp, batch, t);

            // Back through the input projections.
            this.fxl.back(this.fa.diff, batch, t);
            this.ixl.back(this.ia.diff, batch, t);
            this.gxl.back(this.ga.diff, batch, t);
            this.oxl.back(this.oa.diff, batch, t);

            // Back through the hidden projections, whose source is step t - 1 of h.
            // NOTE(review): at t == 0 the source step is -1; FullyLayer is presumably expected
            // to handle that - confirm.
            this.fhl.back(this.fa.diff, batch, t, t, t - 1);
            this.ihl.back(this.ia.diff, batch, t, t, t - 1);
            this.ghl.back(this.ga.diff, batch, t, t, t - 1);
            this.ohl.back(this.oa.diff, batch, t, t, t - 1);

            // Hidden gradient for step t - 1; skipped at t == 0, where the slice would start
            // before the buffer.
            if (hasPrev) {
                TensorOP.add(this.fhl.diff, this.ihl.diff, this.h_diff, prev, hiddenStep);
                TensorOP.add(this.h_diff, this.ghl.diff, this.h_diff, prev, hiddenStep);
                TensorOP.add(this.h_diff, this.ohl.diff, this.h_diff, prev, hiddenStep);
            }

            // Gradient with respect to the layer input for step t.
            final int inputOffset = t * inputStep;
            TensorOP.add(this.fxl.diff, this.ixl.diff, this.diff, inputOffset, inputStep);
            TensorOP.add(this.diff, this.gxl.diff, this.diff, inputOffset, inputStep);
            TensorOP.add(this.diff, this.oxl.diff, this.diff, inputOffset, inputStep);
        }
    }

    /** Runs {@link #init()}, {@code setInput()} and {@link #output()}. */
    @Override // com.omega.engine.nn.layer.Layer
    public void forward() {
        init();
        setInput();
        output();
    }

    /** Runs the backward pass, then the gradient check when the network enables it. */
    @Override // com.omega.engine.nn.layer.Layer
    public void back() {
        initBack();
        setDelta();
        diff();
        if (this.network.GRADIENT_CHECK) {
            gradientCheck();
        }
    }

    /**
     * Runs the forward pass on an explicitly supplied input.
     *
     * @param tensor input passed to {@code setInput(Tensor)}
     */
    @Override // com.omega.engine.nn.layer.Layer
    public void forward(Tensor tensor) {
        init();
        setInput(tensor);
        output();
    }

    /**
     * Runs the backward pass on an explicitly supplied output gradient.
     *
     * @param tensor gradient passed to {@code setDelta(Tensor)}; modified in place by
     *               {@link #diff()}
     */
    @Override // com.omega.engine.nn.layer.Layer
    public void back(Tensor tensor) {
        initBack();
        setDelta(tensor);
        diff();
        if (this.network.GRADIENT_CHECK) {
            gradientCheck();
        }
    }

    /** Updates all eight projection layers, passing the per-step batch size. */
    @Override // com.omega.engine.nn.layer.Layer
    public void update() {
        final int batch = this.number / this.time;
        this.fxl.update(batch);
        this.ixl.update(batch);
        this.gxl.update(batch);
        this.oxl.update(batch);
        this.fhl.update(batch);
        this.ihl.update(batch);
        this.ghl.update(batch);
        this.ohl.update(batch);
    }

    @Override // com.omega.engine.nn.layer.Layer
    public void showDiff() {
    }

    @Override // com.omega.engine.nn.layer.Layer
    public LayerType getLayerType() {
        return LayerType.rnn;
    }

    /** Array-based forward pass is not implemented for this layer; always returns {@code null}. */
    @Override // com.omega.engine.nn.layer.Layer
    public float[][][][] output(float[][][][] fArr) {
        return null;
    }

    @Override // com.omega.engine.nn.layer.Layer
    public void initCache() {
    }

    @Override // com.omega.engine.nn.layer.Layer
    public void backTemp() {
    }
}
