package com.omega.engine.nn.layer;

import com.omega.common.data.Tensor;
import com.omega.engine.ad.op.TensorOP;
import com.omega.engine.gpu.BaseKernel;
import com.omega.engine.nn.layer.normalization.LNLayer;
import com.omega.engine.nn.network.Network;
import com.omega.engine.updater.UpdaterFactory;

/* loaded from: input_file:com/omega/engine/nn/layer/TransformerBlock.class */
/**
 * Transformer block assembled from two layer-normalization layers, a causal
 * self-attention layer and an MLP, joined by two residual additions.
 *
 * <p>Forward data flow, as implemented by {@link #output()}:
 * <pre>
 *   tmp1   = attn(ln1(input)) + input
 *   tmp2   = mlp(ln2(tmp1))   + tmp1
 *   output = tmp2
 * </pre>
 * Normalization is therefore applied before each sub-layer, and the residual
 * branch carries the un-normalized tensor.
 *
 * <p>NOTE(review): {@code TensorOP.add(a, b, c)} is used throughout as if it
 * stores {@code a + b} into {@code c}; that is inferred from the call sites
 * here, not from its definition - confirm against {@code TensorOP}.
 */
public class TransformerBlock extends Layer {
    /** Sequence length; refreshed from {@code network.time} in {@link #init()} when a network is attached. */
    private int time;
    /** Number of attention heads handed to the attention layer. */
    private int headNum;
    /** Embedding width; also used as this layer's {@code oWidth}. */
    private int embedDim;
    /** Bias flag forwarded to every sub-layer. */
    private boolean bias;
    /** Dropout flag forwarded to the attention layer. */
    private boolean dropout;
    private FastCausalSelfAttentionLayer attn;
    private LNLayer ln1;
    private MLPLayer mlp;
    private LNLayer ln2;
    private BaseKernel baseKernel;
    /** Holds the first residual sum (attention output + block input). */
    private Tensor tmp1;
    /** Holds the second residual sum; doubles as {@code output} and, after backward, as {@code diff}. */
    private Tensor tmp2;

    /**
     * Creates a block that is not attached to a {@link Network}.
     *
     * <p>The {@code network} field stays {@code null}, and that {@code null} is
     * also what the sub-layers receive in {@link #initLayers()}.
     * NOTE(review): whether the sub-layers tolerate a {@code null} network
     * cannot be determined from this class - confirm before relying on it.
     *
     * @param headNum  number of attention heads
     * @param time     sequence length
     * @param embedDim embedding width
     * @param bias     whether sub-layers use a bias term
     * @param dropout  whether the attention layer uses dropout
     */
    public TransformerBlock(int headNum, int time, int embedDim, boolean bias, boolean dropout) {
        applyBlockConfig(headNum, time, embedDim, bias, dropout);
        initLayers();
    }

    /**
     * Creates a block attached to {@code network}. If no updater has been set
     * yet, one is created from the network's updater settings.
     *
     * @param headNum  number of attention heads
     * @param time     sequence length
     * @param embedDim embedding width
     * @param bias     whether sub-layers use a bias term
     * @param dropout  whether the attention layer uses dropout
     * @param network  owning network; dereferenced here, so it must not be {@code null}
     */
    public TransformerBlock(int headNum, int time, int embedDim, boolean bias, boolean dropout, Network network) {
        this.network = network;
        if (this.updater == null) {
            setUpdater(UpdaterFactory.create(network.updater, network.updaterParams));
        }
        applyBlockConfig(headNum, time, embedDim, bias, dropout);
        initLayers();
    }

    /** Stores the hyper-parameters and the output geometry (1 x 1 x embedDim). */
    private void applyBlockConfig(int headNum, int time, int embedDim, boolean bias, boolean dropout) {
        this.headNum = headNum;
        this.time = time;
        this.embedDim = embedDim;
        this.bias = bias;
        this.dropout = dropout;
        this.oChannel = 1;
        this.oHeight = 1;
        this.oWidth = embedDim;
    }

    /**
     * Builds the four sub-layers and, if absent, the {@link BaseKernel}.
     *
     * <p>{@code ln1} is constructed with this block as its first argument and
     * {@code ln2} with the attention layer; the MLP is created with a hidden
     * width of {@code 4 * embedDim}.
     *
     * <p>NOTE(review): this public, overridable method is invoked from the
     * constructors; a subclass override would run before the subclass's own
     * fields are initialized.
     */
    public void initLayers() {
        this.ln1 = new LNLayer(this, this.bias);
        this.attn = new FastCausalSelfAttentionLayer(this.embedDim, this.headNum, this.time, this.bias, this.dropout, this.network);
        this.ln2 = new LNLayer(this.attn, this.bias);
        this.mlp = new MLPLayer(this.embedDim, this.embedDim * 4, this.bias, this.network);
        if (this.baseKernel == null) {
            this.baseKernel = new BaseKernel();
        }
    }

    /**
     * Prepares per-batch state: copies the batch size from the input, refreshes
     * {@code time} from the network when one is attached, and (re)creates the
     * two residual buffers whenever the batch size changed.
     */
    @Override // com.omega.engine.nn.layer.Layer
    public void init() {
        this.number = this.input.number;
        // The network-less constructor leaves `network` null; keep the
        // constructor-supplied `time` in that case instead of throwing an NPE.
        if (this.network != null) {
            this.time = this.network.time;
        }
        // Both buffers are always allocated together, so checking tmp1 is enough.
        if (this.tmp1 == null || this.tmp1.number != this.number) {
            this.tmp1 = Tensor.createTensor(this.tmp1, this.number, 1, 1, this.embedDim, true);
            this.tmp2 = Tensor.createTensor(this.tmp2, this.number, 1, 1, this.embedDim, true);
        }
    }

    /** No backward-specific initialization is required. */
    @Override // com.omega.engine.nn.layer.Layer
    public void initBack() {
    }

    /** This block owns no parameters of its own; nothing is initialized here. */
    @Override // com.omega.engine.nn.layer.Layer
    public void initParam() {
    }

    /** Runs the forward computation on the current {@code input}. */
    @Override // com.omega.engine.nn.layer.Layer
    public void output() {
        runForwardPass();
    }

    /**
     * Runs the same forward computation as {@link #output()}.
     *
     * @param tensor currently ignored - the computation does not read it.
     *               NOTE(review): possibly intended as an extra attention input
     *               (for example a mask or positions); confirm with the callers
     *               of {@link #forward(Tensor, Tensor)}.
     */
    public void output(Tensor tensor) {
        runForwardPass();
    }

    /**
     * Shared forward pipeline: ln1, attention, residual add into {@code tmp1},
     * ln2, MLP, residual add into {@code tmp2}, which then becomes {@code output}.
     */
    private void runForwardPass() {
        this.ln1.forward(this.input);
        this.attn.forward(this.ln1.getOutput());
        TensorOP.add(this.attn.getOutput(), this.input, this.tmp1);
        this.ln2.forward(this.tmp1);
        this.mlp.forward(this.ln2.getOutput());
        TensorOP.add(this.mlp.getOutput(), this.tmp1, this.tmp2);
        this.output = this.tmp2;
    }

    /** @return the tensor produced by the most recent forward pass */
    @Override // com.omega.engine.nn.layer.Layer
    public Tensor getOutput() {
        return this.output;
    }

    /**
     * Back-propagates {@code delta} through the block in reverse order of the
     * forward pass, adding the residual-branch gradient after each sub-layer pair.
     *
     * <p>The result is written into {@code tmp2} - the same tensor that backs
     * {@code output} - so the forward output is overwritten once this method
     * has run. {@code ln2.diff} is also updated in place.
     */
    @Override // com.omega.engine.nn.layer.Layer
    public void diff() {
        this.mlp.back(this.delta);
        this.ln2.back(this.mlp.diff);
        // Gradient w.r.t. tmp1: MLP branch plus the second residual connection.
        TensorOP.add(this.ln2.diff, this.delta, this.ln2.diff);
        this.attn.back(this.ln2.diff);
        this.ln1.back(this.attn.diff);
        // Gradient w.r.t. input: attention branch plus the first residual connection.
        TensorOP.add(this.ln1.diff, this.ln2.diff, this.tmp2);
        this.diff = this.tmp2;
    }

    /** Forward pass using the input resolved by {@code setInput()}. */
    @Override // com.omega.engine.nn.layer.Layer
    public void forward() {
        setInput();
        init();
        output();
    }

    /** Backward pass using the delta resolved by {@code setDelta()}. */
    @Override // com.omega.engine.nn.layer.Layer
    public void back() {
        initBack();
        setDelta();
        runBackwardPass();
    }

    /**
     * Forward pass on an explicit input.
     *
     * @param tensor input to this block
     */
    @Override // com.omega.engine.nn.layer.Layer
    public void forward(Tensor tensor) {
        setInput(tensor);
        init();
        output();
    }

    /**
     * Forward pass on an explicit input with an additional tensor.
     *
     * @param tensor  input to this block
     * @param tensor2 passed on to {@link #output(Tensor)}, which currently ignores it
     */
    public void forward(Tensor tensor, Tensor tensor2) {
        setInput(tensor);
        init();
        output(tensor2);
    }

    /**
     * Backward pass on an explicit delta.
     *
     * @param tensor gradient with respect to this block's output
     */
    @Override // com.omega.engine.nn.layer.Layer
    public void back(Tensor tensor) {
        initBack();
        setDelta(tensor);
        runBackwardPass();
    }

    /**
     * Shared tail of both {@code back} variants: computes the gradients and runs
     * the gradient check when the attached network requests it. Without a
     * network the check is skipped rather than failing on a null dereference.
     */
    private void runBackwardPass() {
        diff();
        if (this.network != null && this.network.GRADIENT_CHECK) {
            gradientCheck();
        }
    }

    /** Applies the parameter update of every sub-layer, in forward order. */
    @Override // com.omega.engine.nn.layer.Layer
    public void update() {
        this.ln1.update();
        this.attn.update();
        this.ln2.update();
        this.mlp.update();
    }

    /** Intentionally empty. */
    @Override // com.omega.engine.nn.layer.Layer
    public void showDiff() {
    }

    /** @return {@link LayerType#transformer_decoder} */
    @Override // com.omega.engine.nn.layer.Layer
    public LayerType getLayerType() {
        return LayerType.transformer_decoder;
    }

    /**
     * Array-based forward path; not implemented for this layer.
     *
     * @param fArr ignored
     * @return always {@code null}
     */
    @Override // com.omega.engine.nn.layer.Layer
    public float[][][][] output(float[][][][] fArr) {
        return null;
    }

    /** Intentionally empty. */
    @Override // com.omega.engine.nn.layer.Layer
    public void initCache() {
    }

    /** Intentionally empty. */
    @Override // com.omega.engine.nn.layer.Layer
    public void backTemp() {
    }
}
