package com.omega.engine.nn.layer;

import com.omega.common.data.Tensor;
import com.omega.common.utils.RandomUtils;
import com.omega.engine.ad.op.TensorOP;
import com.omega.engine.gpu.BaseKernel;
import com.omega.engine.nn.layer.normalization.LNLayer;
import com.omega.engine.nn.network.Network;
import com.omega.engine.updater.UpdaterFactory;
import java.util.ArrayList;
import java.util.List;

/* loaded from: input_file:com/omega/engine/nn/layer/TransformerNanoDecoder.class */
/**
 * Decoder-only ("nano") transformer stack.
 *
 * <p>The forward pass is: token embedding ({@code src_emb}) plus learned position embedding
 * ({@code pos_emb}), optional dropout (p = 0.1), then {@code n_layers} {@link TransformerBlock}s.
 * {@link #output()} and {@link #output(Tensor, Tensor)} additionally apply a final {@link LNLayer};
 * {@link #output(Tensor)} does not. The backward pass and {@link #update()} mirror whichever
 * variant ran most recently, so the final norm is only back-propagated / updated when it was
 * actually applied in the forward pass.
 */
public class TransformerNanoDecoder extends Layer {
    /** Number of rows of the position-embedding table; re-read from {@code network.time} in {@link #init()}. */
    private int time;
    private int vocab_size;
    private int embedDim = 0;
    private boolean bias = false;
    private boolean dropout = false;
    private int headNum = 8;
    private int n_layers = 6;
    private EmbeddingIDLayer src_emb;
    private EmbeddingIDLayer pos_emb;
    private List<TransformerBlock> decoderLayers;
    private LNLayer ln;
    private DropoutLayer dropoutLayer;
    private BaseKernel baseKernel;
    /** Position-id tensor used by {@link #output()}; set via {@link #setPositions} or a positional forward. */
    private Tensor positions;
    /** True when the most recent forward pass ended with the final {@link LNLayer}. */
    private boolean lnApplied;

    /**
     * Creates a decoder that is not attached to a {@link Network}.
     *
     * @param vocabSize number of token ids the token embedding accepts
     * @param nLayers   number of transformer blocks; must be at least 1
     * @param headNum   attention head count passed to every block
     * @param time      number of positions in the position embedding
     * @param embedDim  embedding / model width
     * @param bias      whether blocks and the final norm use bias
     * @param dropout   whether dropout (p = 0.1) follows the embeddings and is enabled in blocks
     * @throws IllegalArgumentException if {@code nLayers < 1}
     */
    public TransformerNanoDecoder(int vocabSize, int nLayers, int headNum, int time, int embedDim,
            boolean bias, boolean dropout) {
        configure(vocabSize, nLayers, headNum, time, embedDim, bias, dropout);
        initLayers();
    }

    /**
     * Creates a decoder attached to {@code network}, creating an updater from the network's
     * updater settings when none is set yet.
     *
     * @param network owning network; must not be {@code null}
     * @throws IllegalArgumentException if {@code nLayers < 1}
     * @see #TransformerNanoDecoder(int, int, int, int, int, boolean, boolean)
     */
    public TransformerNanoDecoder(int vocabSize, int nLayers, int headNum, int time, int embedDim,
            boolean bias, boolean dropout, Network network) {
        configure(vocabSize, nLayers, headNum, time, embedDim, bias, dropout);
        this.network = network;
        if (this.updater == null) {
            setUpdater(UpdaterFactory.create(network.updater, network.updaterParams));
        }
        initLayers();
    }

    /** Validates and stores the hyper-parameters and the (1 x 1 x embedDim) in/out shape. */
    private void configure(int vocabSize, int nLayers, int headNum, int time, int embedDim,
            boolean bias, boolean dropout) {
        if (nLayers < 1) {
            // initLayers() attaches the final norm to the last block, so at least one is required.
            throw new IllegalArgumentException("n_layers must be >= 1, got " + nLayers);
        }
        this.vocab_size = vocabSize;
        this.n_layers = nLayers;
        this.headNum = headNum;
        this.time = time;
        this.embedDim = embedDim;
        this.bias = bias;
        this.dropout = dropout;
        this.channel = 1;
        this.height = 1;
        this.width = embedDim;
        this.oChannel = 1;
        this.oHeight = 1;
        this.oWidth = embedDim;
    }

    /** Builds the embeddings, the transformer blocks, the final norm and (optionally) dropout. */
    public void initLayers() {
        this.src_emb = new EmbeddingIDLayer(this.vocab_size, this.embedDim, this.network);
        initEmbeddingWeight(this.src_emb);
        this.pos_emb = new EmbeddingIDLayer(this.time, this.embedDim, this.network);
        initEmbeddingWeight(this.pos_emb);
        this.decoderLayers = new ArrayList<>(this.n_layers);
        for (int i = 0; i < this.n_layers; i++) {
            this.decoderLayers.add(
                    new TransformerBlock(this.headNum, this.time, this.embedDim, this.bias, this.dropout, this.network));
        }
        this.ln = new LNLayer(this.decoderLayers.get(this.n_layers - 1), this.bias);
        if (this.dropout) {
            this.dropoutLayer = new DropoutLayer(0.1f, this.src_emb);
        }
        if (this.baseKernel == null) {
            this.baseKernel = new BaseKernel();
        }
    }

    /** Replaces {@code emb.weight} with values from {@code RandomUtils.uniform(count, 0.0f, 0.02f)}. */
    private static void initEmbeddingWeight(EmbeddingIDLayer emb) {
        emb.weight = new Tensor(1, 1, emb.width, emb.oWidth,
                RandomUtils.uniform(emb.width * emb.oWidth, 0.0f, 0.02f), true);
    }

    /** Sets the position-id tensor used by {@link #output()} / {@link #forward(Tensor)}. */
    public void setPositions(Tensor positions) {
        this.positions = positions;
    }

    @Override
    public void init() {
        this.number = this.input.number;
        // The first constructor leaves network unset; keep the configured time in that case.
        if (this.network != null) {
            this.time = this.network.time;
        }
    }

    @Override
    public void initBack() {
    }

    @Override
    public void initParam() {
    }

    /**
     * Full forward pass with the final LayerNorm, using the stored {@link #positions}.
     *
     * @throws IllegalStateException if no positions have been provided yet
     */
    @Override
    public void output() {
        if (this.positions == null) {
            throw new IllegalStateException(
                    "positions not set: call setPositions(...) or a forward overload that takes positions");
        }
        Tensor x = runBlocks(embed(this.positions));
        applyFinalNorm(x);
    }

    /**
     * Forward pass WITHOUT the final LayerNorm; the last block's output becomes this layer's output.
     * The given positions are remembered for later {@link #output()} calls.
     *
     * @param positions position-id tensor fed to the position embedding
     */
    public void output(Tensor positions) {
        this.positions = positions;
        this.output = runBlocks(embed(positions));
        this.lnApplied = false;
    }

    /**
     * Forward pass with the final LayerNorm, passing {@code aux} as the second argument of every
     * block's {@code forward} (presumably an attention mask — confirm against TransformerBlock).
     * The given positions are remembered for later {@link #output()} calls.
     *
     * @param aux       extra tensor forwarded to each block
     * @param positions position-id tensor fed to the position embedding
     */
    public void output(Tensor aux, Tensor positions) {
        this.positions = positions;
        Tensor x = embed(positions);
        for (TransformerBlock block : this.decoderLayers) {
            block.forward(x, aux);
            x = block.getOutput();
        }
        applyFinalNorm(x);
    }

    /** Token embedding + position embedding (summed in place into src_emb's output), then optional dropout. */
    private Tensor embed(Tensor positionIds) {
        this.src_emb.forward(this.input);
        this.pos_emb.forward(positionIds);
        TensorOP.add(this.src_emb.getOutput(), this.pos_emb.getOutput(), this.src_emb.getOutput());
        Tensor x = this.src_emb.getOutput();
        if (this.dropout) {
            this.dropoutLayer.forward(x);
            x = this.dropoutLayer.getOutput();
        }
        return x;
    }

    /** Runs {@code x} through every block in order and returns the last block's output. */
    private Tensor runBlocks(Tensor x) {
        for (TransformerBlock block : this.decoderLayers) {
            block.forward(x);
            x = block.getOutput();
        }
        return x;
    }

    /** Applies the final LayerNorm, publishes its output and records that it must be back-propagated. */
    private void applyFinalNorm(Tensor x) {
        this.ln.forward(x);
        this.output = this.ln.getOutput();
        this.lnApplied = true;
    }

    @Override
    public Tensor getOutput() {
        return this.output;
    }

    /**
     * Back-propagates {@code delta} through the same layers the last forward pass used:
     * the final norm (if it was applied), the blocks in reverse, dropout, and both embeddings.
     */
    @Override
    public void diff() {
        Tensor grad = this.delta;
        if (this.lnApplied) {
            // delta is the gradient w.r.t. the norm's output; the blocks need it w.r.t. its input.
            this.ln.back(grad);
            grad = this.ln.diff;
        }
        for (int i = this.n_layers - 1; i >= 0; i--) {
            TransformerBlock block = this.decoderLayers.get(i);
            block.back(grad);
            grad = block.diff;
        }
        if (this.dropout) {
            this.dropoutLayer.back(grad);
            grad = this.dropoutLayer.diff;
        }
        // Both embeddings were summed in the forward pass, so they receive the same gradient.
        this.src_emb.back(grad);
        this.pos_emb.back(grad);
        this.diff = this.src_emb.diff;
    }

    @Override
    public void forward() {
        setInput();
        init();
        output();
    }

    @Override
    public void back() {
        initBack();
        setDelta();
        diff();
        if (this.network.GRADIENT_CHECK) {
            gradientCheck();
        }
    }

    @Override
    public void forward(Tensor tensor) {
        setInput(tensor);
        init();
        output();
    }

    /** Forward with explicit positions and no final norm; see {@link #output(Tensor)}. */
    public void forward(Tensor tensor, Tensor positions) {
        setInput(tensor);
        init();
        output(positions);
    }

    /** Forward with an extra block argument and explicit positions; see {@link #output(Tensor, Tensor)}. */
    public void forward(Tensor tensor, Tensor aux, Tensor positions) {
        setInput(tensor);
        init();
        output(aux, positions);
    }

    @Override
    public void back(Tensor tensor) {
        initBack();
        setDelta(tensor);
        diff();
        if (this.network.GRADIENT_CHECK) {
            gradientCheck();
        }
    }

    /** Updates the embeddings, every block and, when it took part in the last pass, the final norm. */
    @Override
    public void update() {
        this.src_emb.update();
        this.pos_emb.update();
        for (TransformerBlock block : this.decoderLayers) {
            block.update();
        }
        if (this.lnApplied) {
            this.ln.update();
        }
    }

    @Override
    public void showDiff() {
    }

    @Override
    public LayerType getLayerType() {
        return LayerType.transformer_decoder;
    }

    /** Array-based forward is not supported by this layer; always returns {@code null}. */
    @Override
    public float[][][][] output(float[][][][] fArr) {
        return null;
    }

    @Override
    public void initCache() {
    }

    @Override
    public void backTemp() {
    }
}
