package com.omega.engine.nn.layer;

import com.omega.common.data.Tensor;
import com.omega.common.utils.MatrixUtils;
import com.omega.common.utils.RandomUtils;
import com.omega.engine.ad.op.TensorOP;
import com.omega.engine.gpu.BaseKernel;
import com.omega.engine.nn.layer.normalization.LNLayer;
import com.omega.engine.nn.network.CNN;
import com.omega.engine.nn.network.Network;
import com.omega.engine.updater.UpdaterFactory;

/* loaded from: input_file:com/omega/engine/nn/layer/TransformerDecoderLayer.class */
/**
 * Decoder block composed of a {@link MultiHeadAttentionLayer}, a
 * {@link PoswiseFeedForwardLinearLayer} and two {@link LNLayer} instances.
 *
 * <p>Two forward paths exist:
 * <ul>
 *   <li>{@link #output()} chains attention and feed-forward directly, with no residual add and
 *       without going through {@code ln1}/{@code ln2};</li>
 *   <li>{@link #output(Tensor)} adds the residual after each sub-layer and feeds the sum through
 *       {@code ln1} / {@code ln2}.</li>
 * </ul>
 * {@link #diff()} back-propagates through {@code ln2} and {@code ln1}, i.e. it mirrors the
 * {@link #output(Tensor)} path only.
 */
public class TransformerDecoderLayer extends Layer {
    private int time;
    private int embedDim;
    private int nChannel;
    private boolean bias;
    private boolean layer_norm;
    // NOTE(review): hard-coded head count; presumably embedDim has to be divisible by it
    // (main() below uses embedDim = 8) -- confirm against MultiHeadAttentionLayer.
    private int headNum = 12;
    private MultiHeadAttentionLayer attn;
    private PoswiseFeedForwardLinearLayer feed_forward;
    private LNLayer ln1;
    private LNLayer ln2;
    private BaseKernel baseKernel;
    /** Input buffer of {@code ln1}: attention output + layer input. Reused as scratch in {@link #diff()}. */
    private Tensor ln1i;
    /** Input buffer of {@code ln2}: feed-forward output + {@code ln1} output. Reused as scratch in {@link #diff()}. */
    private Tensor ln2i;

    /**
     * Creates the layer without an owning network; {@code this.network} is left as inherited
     * and is passed unchanged to the sub-layers.
     *
     * @param time      forwarded to the attention sub-layer
     * @param embedDim  embedding width; also used as this layer's output width
     * @param nChannel  forwarded to the feed-forward sub-layer
     * @param bias      forwarded to both sub-layers
     * @param layerNorm forwarded to both sub-layers
     */
    public TransformerDecoderLayer(int time, int embedDim, int nChannel, boolean bias, boolean layerNorm) {
        configure(time, embedDim, nChannel, bias, layerNorm);
    }

    /**
     * Creates the layer bound to {@code network}. If no updater is set yet, one is created from
     * the network's updater settings.
     *
     * @param time      forwarded to the attention sub-layer
     * @param embedDim  embedding width; also used as this layer's output width
     * @param nChannel  forwarded to the feed-forward sub-layer
     * @param bias      forwarded to both sub-layers
     * @param layerNorm forwarded to both sub-layers
     * @param network   owning network; must not be {@code null}
     */
    public TransformerDecoderLayer(int time, int embedDim, int nChannel, boolean bias, boolean layerNorm, Network network) {
        this.network = network;
        if (this.updater == null) {
            setUpdater(UpdaterFactory.create(network.updater, network.updaterParams));
        }
        configure(time, embedDim, nChannel, bias, layerNorm);
    }

    /** Stores the hyper-parameters, sets the output shape to 1 x 1 x embedDim and builds the sub-layers. */
    private void configure(int time, int embedDim, int nChannel, boolean bias, boolean layerNorm) {
        this.time = time;
        this.embedDim = embedDim;
        this.nChannel = nChannel;
        this.bias = bias;
        this.layer_norm = layerNorm;
        this.oChannel = 1;
        this.oHeight = 1;
        this.oWidth = embedDim;
        initLayers();
    }

    /** Instantiates the kernel helper, the attention and feed-forward sub-layers and both LN layers. */
    public void initLayers() {
        this.baseKernel = new BaseKernel();
        this.attn = new MultiHeadAttentionLayer(this.embedDim, this.headNum, this.time, this.bias, this.layer_norm, this.network);
        this.feed_forward = new PoswiseFeedForwardLinearLayer(this.embedDim, this.nChannel, this.bias, this.layer_norm, this.network);
        this.ln1 = new LNLayer(this.attn);
        this.ln2 = new LNLayer(this.feed_forward);
    }

    /**
     * Syncs {@code number} with the current input and (re)creates the two residual buffers
     * whenever they are missing or their {@code number} differs from the input's.
     */
    @Override // com.omega.engine.nn.layer.Layer
    public void init() {
        this.number = this.input.number;
        if (this.ln1i == null || this.number != this.ln1i.number) {
            this.ln1i = Tensor.createTensor(this.ln1i, this.number, this.input.channel, this.input.height, this.input.width, true);
        }
        if (this.ln2i == null || this.number != this.ln2i.number) {
            this.ln2i = Tensor.createTensor(this.ln2i, this.number, this.input.channel, this.input.height, this.input.width, true);
        }
    }

    /** No backward-pass buffers are prepared at this level. */
    @Override // com.omega.engine.nn.layer.Layer
    public void initBack() {
    }

    /** No parameters are owned directly by this layer; the sub-layers hold them. */
    @Override // com.omega.engine.nn.layer.Layer
    public void initParam() {
    }

    /**
     * Forward pass without the extra attention argument: attention followed by feed-forward.
     *
     * <p>NOTE(review): this path performs no residual add and does not run {@code ln1}/{@code ln2},
     * whereas {@link #diff()} back-propagates through both. Calling {@code back()} after this
     * forward variant therefore looks inconsistent -- confirm intended usage.
     */
    @Override // com.omega.engine.nn.layer.Layer
    public void output() {
        this.attn.forward(this.input);
        this.feed_forward.forward(this.attn.getOutput());
        this.output = this.feed_forward.getOutput();
    }

    /**
     * Forward pass with residual connections and layer normalisation after each sub-layer.
     *
     * @param tensor second argument handed to the attention layer (presumably the attention
     *               mask -- confirm against MultiHeadAttentionLayer)
     */
    public void output(Tensor tensor) {
        this.attn.forward(this.input, tensor);
        // ln1i = attention output + input, then normalise.
        TensorOP.add(this.attn.getOutput(), this.input, this.ln1i);
        this.ln1.forward(this.ln1i);
        this.feed_forward.forward(this.ln1.getOutput());
        // ln2i = feed-forward output + ln1 output, then normalise.
        TensorOP.add(this.feed_forward.getOutput(), this.ln1.getOutput(), this.ln2i);
        this.ln2.forward(this.ln2i);
        this.output = this.ln2.getOutput();
    }

    /** @return the tensor produced by the most recent forward pass */
    @Override // com.omega.engine.nn.layer.Layer
    public Tensor getOutput() {
        return this.output;
    }

    /**
     * Backward pass matching {@link #output(Tensor)}. The forward buffers {@code ln2i},
     * {@code ln1i} and {@code ln2}'s output tensor are overwritten here as scratch space, so the
     * statement order must not be changed.
     */
    @Override // com.omega.engine.nn.layer.Layer
    public void diff() {
        this.ln2.back(this.delta);
        // Keep a copy of ln2's gradient for the residual branch (presumably src -> dst; confirm
        // copy_gpu's argument order) before the feed-forward backward runs on ln2.diff.
        this.baseKernel.copy_gpu(this.ln2.diff, this.ln2i, this.ln2.diff.getDataLength(), 1, 1);
        this.feed_forward.back(this.ln2.diff);
        // Sum feed-forward gradient and residual gradient; ln2's output tensor is the destination.
        TensorOP.add(this.feed_forward.diff, this.ln2i, this.ln2.getOutput());
        this.ln1.back(this.ln2.getOutput());
        // Same scheme for the attention sub-layer, using ln1i as the scratch buffer.
        this.baseKernel.copy_gpu(this.ln1.diff, this.ln1i, this.ln1.diff.getDataLength(), 1, 1);
        this.attn.back(this.ln1.diff);
        TensorOP.add(this.attn.diff, this.ln1i, this.ln1i);
        this.diff = this.ln1i;
    }

    /** Forward pass on the input obtained through {@code setInput()}, using {@link #output()}. */
    @Override // com.omega.engine.nn.layer.Layer
    public void forward() {
        setInput();
        init();
        output();
    }

    /** Backward pass on the delta obtained through {@code setDelta()}. */
    @Override // com.omega.engine.nn.layer.Layer
    public void back() {
        initBack();
        setDelta();
        diff();
        gradientCheckIfEnabled();
    }

    /**
     * Forward pass on an explicit input, using {@link #output()}.
     *
     * @param tensor layer input
     */
    @Override // com.omega.engine.nn.layer.Layer
    public void forward(Tensor tensor) {
        setInput(tensor);
        init();
        output();
    }

    /**
     * Forward pass on an explicit input, using {@link #output(Tensor)}.
     *
     * @param tensor  layer input
     * @param tensor2 forwarded to {@link #output(Tensor)}
     */
    public void forward(Tensor tensor, Tensor tensor2) {
        setInput(tensor);
        init();
        output(tensor2);
    }

    /**
     * Backward pass on an explicit delta.
     *
     * @param tensor gradient with respect to this layer's output
     */
    @Override // com.omega.engine.nn.layer.Layer
    public void back(Tensor tensor) {
        initBack();
        setDelta(tensor);
        diff();
        gradientCheckIfEnabled();
    }

    /**
     * Runs {@code gradientCheck()} when the owning network requests it. The network may be unset
     * when the layer was built through the constructor that takes no network, so it is
     * null-checked instead of dereferenced blindly.
     */
    private void gradientCheckIfEnabled() {
        if (this.network != null && this.network.GRADIENT_CHECK) {
            gradientCheck();
        }
    }

    /**
     * Applies the parameter update of every sub-layer. The two LN layers are included because
     * {@link #diff()} back-propagates through them and nothing else in this class updates them.
     */
    @Override // com.omega.engine.nn.layer.Layer
    public void update() {
        this.attn.update();
        this.ln1.update();
        this.feed_forward.update();
        this.ln2.update();
    }

    /** Intentionally empty: nothing is printed for this layer. */
    @Override // com.omega.engine.nn.layer.Layer
    public void showDiff() {
    }

    /** @return {@link LayerType#transformer_decoder} */
    @Override // com.omega.engine.nn.layer.Layer
    public LayerType getLayerType() {
        return LayerType.transformer_decoder;
    }

    /**
     * Array-based forward pass; not implemented for this layer.
     *
     * @param fArr ignored
     * @return always {@code null}
     */
    @Override // com.omega.engine.nn.layer.Layer
    public float[][][][] output(float[][][][] fArr) {
        return null;
    }

    /** Intentionally empty: this layer keeps no cache. */
    @Override // com.omega.engine.nn.layer.Layer
    public void initCache() {
    }

    /** Intentionally empty. */
    @Override // com.omega.engine.nn.layer.Layer
    public void backTemp() {
    }

    /**
     * Manual smoke run: builds a 5 x 10 x 1 x 8 input, runs one forward and one backward pass
     * and prints the results.
     *
     * <p>NOTE(review): this uses the {@link #forward(Tensor)} variant, whose forward path does
     * not run {@code ln1}/{@code ln2}, and then calls {@code back}, which does -- confirm this
     * is still a meaningful check.
     */
    public static void main(String[] strArr) {
        int batch = 5;
        int steps = 10;
        int width = 8;
        int elements = batch * steps * width;

        CNN cnn = new CNN(null);
        cnn.CUDNN = true;
        cnn.number = batch;

        Tensor input = new Tensor(batch, steps, 1, width, RandomUtils.order(elements, 0.1f, 0.1f), true);
        input.showShape();
        input.showDM();
        Tensor delta = new Tensor(batch, steps, 1, width, MatrixUtils.val(elements, 1.0f), true);

        TransformerDecoderLayer layer = new TransformerDecoderLayer(steps, width, 4, false, true, cnn);
        layer.forward(input);
        layer.getOutput().showDM();
        layer.back(delta);
        layer.diff.showDM();
    }
}
