package com.omega.engine.nn.layer;

import com.omega.common.data.Tensor;
import com.omega.common.utils.MatrixUtils;
import com.omega.common.utils.RandomUtils;
import com.omega.engine.ad.op.TensorOP;
import com.omega.engine.gpu.BaseKernel;
import com.omega.engine.gpu.GPUOP;
import com.omega.engine.gpu.SoftmaxKernel;
import com.omega.engine.nn.layer.normalization.LNLayer;
import com.omega.engine.nn.network.Network;
import com.omega.engine.nn.network.Transformer;
import com.omega.engine.updater.UpdaterFactory;
import com.omega.example.transformer.utils.ENTokenizer;

/* loaded from: input_file:com/omega/engine/nn/layer/MultiHeadAttentionLayer.class */
/**
 * Multi-head self-attention layer with a residual connection and an optional
 * trailing layer normalisation.
 *
 * <p>Forward pass, as implemented below: the input is projected by three
 * {@link FullyLayer}s (query / key / value), each projection is viewed as
 * {@code (batchSize, time, headNum, dk)} and permuted to
 * {@code (batchSize, headNum, time, dk)}, attention is evaluated, the result is
 * permuted back, flattened to {@code (batchSize * time, 1, 1, headNum * dk)},
 * passed through the output projection and added to the layer input. When
 * {@code layer_norm} is enabled the sum is additionally fed to an
 * {@link LNLayer}.
 *
 * <p>{@code batchSize} is derived as {@code network.number / time} and
 * {@code dk} as {@code embedDim / headNum}; both are integer divisions.
 */
public class MultiHeadAttentionLayer extends Layer {
    /** Sequence length used to split {@code number} into batch and time. */
    private int time;
    /** Number of attention heads. */
    private int headNum = 1;
    /** Width of the input / output embedding. */
    private int embedDim = 0;
    /** Whether the four linear projections are created with a bias. */
    private boolean bias = false;
    /** Whether the residual sum is passed through {@link #lnLayer}. */
    private boolean layer_norm = false;
    private FullyLayer qLinerLayer;
    private FullyLayer kLinerLayer;
    private FullyLayer vLinerLayer;
    private FullyLayer oLinerLayer;
    /** Only instantiated when {@link #layer_norm} is {@code true}; otherwise {@code null}. */
    private LNLayer lnLayer;
    private BaseKernel baseKernel;
    /** Permuted query / key / value buffers, allocated as (batchSize, headNum, time, dk). */
    private Tensor qt;
    private Tensor kt;
    private Tensor vt;
    /** Attention buffers, allocated as (batchSize, headNum, time, time). */
    private Tensor scores;
    private Tensor weights;
    /** Per-head attention result, allocated as (batchSize, headNum, time, dk). */
    private Tensor attn_outputs;
    /** Attention result permuted back, allocated as (batchSize, time, headNum, dk). */
    private Tensor ot;
    /** Residual sum buffer, allocated as (batchSize * time, 1, 1, embedDim). */
    private Tensor ro;
    private SoftmaxKernel softmax;
    private int dk = 0;
    private int batchSize = 1;

    /**
     * Creates the layer without attaching it to a network.
     *
     * @param embedDim  embedding width; also used as the layer's output width
     * @param headNum   number of attention heads
     * @param time      sequence length
     * @param bias      whether the linear projections use a bias
     * @param layerNorm whether to apply layer normalisation after the residual add
     */
    public MultiHeadAttentionLayer(int embedDim, int headNum, int time, boolean bias, boolean layerNorm) {
        this.time = time;
        this.layer_norm = layerNorm;
        this.embedDim = embedDim;
        this.headNum = headNum;
        this.bias = bias;
        this.oChannel = 1;
        this.oHeight = 1;
        this.oWidth = embedDim;
        initLayers();
    }

    /**
     * Creates the layer attached to {@code network}. If no updater is set yet,
     * one is created from the network's updater settings.
     *
     * @param embedDim  embedding width; also used as the layer's output width
     * @param headNum   number of attention heads
     * @param time      sequence length
     * @param bias      whether the linear projections use a bias
     * @param layerNorm whether to apply layer normalisation after the residual add
     * @param network   owning network, handed on to the inner linear layers
     */
    public MultiHeadAttentionLayer(int embedDim, int headNum, int time, boolean bias, boolean layerNorm, Network network) {
        this.network = network;
        if (this.updater == null) {
            setUpdater(UpdaterFactory.create(network.updater, network.updaterParams));
        }
        this.time = time;
        this.embedDim = embedDim;
        this.headNum = headNum;
        this.bias = bias;
        this.layer_norm = layerNorm;
        this.oChannel = 1;
        this.oHeight = 1;
        this.oWidth = embedDim;
        initLayers();
    }

    /**
     * Builds the four {@code embedDim -> embedDim} projections, the optional
     * layer-norm layer, and the kernels (kernels are created only if absent).
     */
    public void initLayers() {
        this.qLinerLayer = new FullyLayer(this.embedDim, this.embedDim, this.bias, this.network);
        this.kLinerLayer = new FullyLayer(this.embedDim, this.embedDim, this.bias, this.network);
        this.vLinerLayer = new FullyLayer(this.embedDim, this.embedDim, this.bias, this.network);
        this.oLinerLayer = new FullyLayer(this.embedDim, this.embedDim, this.bias, this.network);
        if (this.layer_norm) {
            this.lnLayer = new LNLayer(this.oLinerLayer);
        }
        if (this.baseKernel == null) {
            this.baseKernel = new BaseKernel();
        }
        if (this.softmax == null) {
            this.softmax = new SoftmaxKernel();
        }
    }

    /**
     * Refreshes {@code number}, {@code dk} and {@code batchSize} from the
     * network and (re)allocates the work buffers when they do not exist yet or
     * the batch size changed, then restores the original views.
     */
    @Override // com.omega.engine.nn.layer.Layer
    public void init() {
        this.number = this.network.number;
        this.dk = this.embedDim / this.headNum;
        this.batchSize = this.number / this.time;
        boolean needsAllocation = this.qt == null || this.qt.number != this.batchSize;
        if (needsAllocation) {
            this.qt = Tensor.createTensor(this.qt, this.batchSize, this.headNum, this.time, this.dk, true);
            this.kt = Tensor.createTensor(this.kt, this.batchSize, this.headNum, this.time, this.dk, true);
            this.vt = Tensor.createTensor(this.vt, this.batchSize, this.headNum, this.time, this.dk, true);
            this.scores = Tensor.createTensor(this.scores, this.batchSize, this.headNum, this.time, this.time, true);
            this.weights = Tensor.createTensor(this.weights, this.batchSize, this.headNum, this.time, this.time, true);
            this.attn_outputs = Tensor.createTensor(this.attn_outputs, this.batchSize, this.headNum, this.time, this.dk, true);
            this.ot = Tensor.createTensor(this.ot, this.batchSize, this.time, this.headNum, this.dk, true);
            this.ro = Tensor.createTensor(this.ro, this.batchSize * this.time, 1, 1, this.embedDim, true);
        }
        resize();
    }

    /**
     * Calls {@code viewOrg()} on every buffer that the forward / backward pass
     * re-views, and on the q/k/v projection outputs once they exist.
     */
    public void resize() {
        this.qt.viewOrg();
        this.kt.viewOrg();
        this.vt.viewOrg();
        this.scores.viewOrg();
        this.weights.viewOrg();
        this.attn_outputs.viewOrg();
        this.ot.viewOrg();
        if (this.qLinerLayer.getOutput() != null) {
            this.qLinerLayer.getOutput().viewOrg();
            this.kLinerLayer.getOutput().viewOrg();
            this.vLinerLayer.getOutput().viewOrg();
        }
    }

    /** No backward-specific state to prepare. */
    @Override // com.omega.engine.nn.layer.Layer
    public void initBack() {
    }

    /** Parameters live in the inner layers; nothing to initialise here. */
    @Override // com.omega.engine.nn.layer.Layer
    public void initParam() {
    }

    /** Unmasked forward computation on the current {@code input}. */
    @Override // com.omega.engine.nn.layer.Layer
    public void output() {
        computeOutput(null);
    }

    /**
     * Forward computation on the current {@code input} with an attention mask.
     *
     * @param tensor mask handed to the masked softmax; {@code null} selects the
     *               plain softmax
     */
    public void output(Tensor tensor) {
        computeOutput(tensor);
    }

    /**
     * Shared forward implementation for both {@code output} variants.
     *
     * @param mask attention mask, or {@code null} for unmasked attention
     */
    private void computeOutput(Tensor mask) {
        this.qLinerLayer.forward(this.input);
        this.kLinerLayer.forward(this.input);
        this.vLinerLayer.forward(this.input);

        // Split the embedding axis into (headNum, dk) ...
        Tensor query = this.qLinerLayer.getOutput().view(this.batchSize, this.time, this.headNum, this.dk);
        Tensor key = this.kLinerLayer.getOutput().view(this.batchSize, this.time, this.headNum, this.dk);
        Tensor value = this.vLinerLayer.getOutput().view(this.batchSize, this.time, this.headNum, this.dk);

        // ... and swap the time and head axes.
        TensorOP.permute(query, this.qt, new int[]{0, 2, 1, 3});
        TensorOP.permute(key, this.kt, new int[]{0, 2, 1, 3});
        TensorOP.permute(value, this.vt, new int[]{0, 2, 1, 3});

        scaledDotProductAttention(this.qt, this.kt, this.vt, mask);

        // Swap back and flatten the heads for the output projection.
        TensorOP.permute(this.attn_outputs, this.ot, new int[]{0, 2, 1, 3});
        this.ot.view(this.batchSize * this.time, 1, 1, this.headNum * this.dk);
        this.oLinerLayer.forward(this.ot);

        // Residual connection: ro = oLinear(ot) + input.
        TensorOP.add(this.oLinerLayer.getOutput(), this.input, this.ro);

        if (this.layer_norm) {
            this.lnLayer.forward(this.ro);
            this.output = this.lnLayer.getOutput();
        } else {
            this.output = this.ro;
        }
    }

    /**
     * Evaluates attention into {@code attn_outputs}, using {@code scores} and
     * {@code weights} as intermediates.
     *
     * <p>NOTE(review): the bmm argument order is assumed to be
     * (A, B, C, batch, m, n, k, transA, transB, alpha, beta) — confirm against
     * {@code GPUOP.bmm}. Under that reading the first call forms
     * {@code Q * K^T / sqrt(dk)} and the second {@code weights * V}.
     *
     * @param tensor  query
     * @param tensor2 key
     * @param tensor3 value
     * @param tensor4 optional mask; when non-null the masked softmax is used
     *                with {@code -1.0E9f} passed as its last argument
     */
    public void scaledDotProductAttention(Tensor tensor, Tensor tensor2, Tensor tensor3, Tensor tensor4) {
        float scale = (float) (1.0d / Math.sqrt(this.dk));
        GPUOP.getInstance().bmm(tensor.getGpuData(), tensor2.getGpuData(), this.scores.getGpuData(), tensor.number * tensor.channel, tensor.height, tensor2.height, tensor.width, 0, 1, scale, 0.0f);
        if (tensor4 == null) {
            this.softmax.softmax(this.scores, this.weights);
        } else {
            this.softmax.softmaxMask(this.scores, tensor4, this.weights, -1.0E9f);
        }
        GPUOP.getInstance().bmm(this.weights.getGpuData(), tensor3.getGpuData(), this.attn_outputs.getGpuData(), this.weights.number * this.weights.channel, this.weights.height, tensor3.width, this.weights.width, 0, 0, 1.0f, 0.0f);
    }

    /**
     * Attention backward that writes the query / key / value gradients into
     * caller-supplied tensors, which are first re-viewed to the shapes of the
     * corresponding forward tensors. {@code scores} is overwritten and reused
     * as the gradient buffer of the softmax.
     *
     * @param tensor  query used in the forward pass
     * @param tensor2 key used in the forward pass
     * @param tensor3 value used in the forward pass
     * @param tensor4 gradient with respect to the attention output
     * @param tensor5 receives the query gradient
     * @param tensor6 receives the key gradient
     * @param tensor7 receives the value gradient
     */
    public void scaledDotProductAttentionBackward(Tensor tensor, Tensor tensor2, Tensor tensor3, Tensor tensor4, Tensor tensor5, Tensor tensor6, Tensor tensor7) {
        // Value gradient (first operand flagged as transposed).
        tensor7.view(tensor3.shape());
        GPUOP.getInstance().bmm(this.weights.getGpuData(), tensor4.getGpuData(), tensor7.getGpuData(), this.weights.number * this.weights.channel, this.weights.width, tensor4.width, this.weights.height, 1, 0, 1.0f, 0.0f);

        // Gradient w.r.t. the softmax output, then through the softmax, in place in scores.
        GPUOP.getInstance().bmm(tensor4.getGpuData(), tensor3.getGpuData(), this.scores.getGpuData(), tensor4.number * tensor4.channel, tensor4.height, tensor3.height, tensor3.width, 0, 1, 1.0f, 0.0f);
        this.softmax.backward_noloss(this.weights, this.scores, this.scores);

        // Same scale factor as in the forward pass.
        float scale = (float) (1.0d / Math.sqrt(this.dk));

        // Key gradient.
        tensor6.view(tensor2.shape());
        GPUOP.getInstance().bmm(this.scores.getGpuData(), tensor.getGpuData(), tensor6.getGpuData(), this.scores.number * this.scores.channel, this.scores.width, tensor.width, this.scores.height, 1, 0, scale, 0.0f);

        // Query gradient.
        tensor5.view(tensor.shape());
        GPUOP.getInstance().bmm(this.scores.getGpuData(), tensor2.getGpuData(), tensor5.getGpuData(), this.scores.number * this.scores.channel, this.scores.height, tensor2.width, this.scores.width, 0, 0, scale, 0.0f);
    }

    /**
     * Attention backward that writes the gradients into the {@code getGrad()}
     * tensors of the query / key / value arguments. {@code scores} is
     * overwritten and reused as the gradient buffer of the softmax.
     *
     * <p>NOTE(review): the final bmm call passes transpose flags {@code 0, 1},
     * whereas the seven-argument overload passes {@code 0, 0} for the
     * corresponding query-gradient product. One of the two is presumably wrong;
     * behaviour is left untouched here — confirm against {@code GPUOP.bmm}.
     *
     * @param tensor  query used in the forward pass
     * @param tensor2 key used in the forward pass
     * @param tensor3 value used in the forward pass
     * @param tensor4 gradient with respect to the attention output
     */
    public void scaledDotProductAttentionBackward(Tensor tensor, Tensor tensor2, Tensor tensor3, Tensor tensor4) {
        GPUOP.getInstance().bmm(this.weights.getGpuData(), tensor4.getGpuData(), tensor3.getGrad().getGpuData(), this.weights.number * this.weights.channel, this.weights.width, tensor4.width, this.weights.height, 1, 0, 1.0f, 0.0f);
        GPUOP.getInstance().bmm(tensor4.getGpuData(), tensor3.getGpuData(), this.scores.getGpuData(), tensor4.number * tensor4.channel, tensor4.height, this.weights.width, tensor4.width, 0, 1, 1.0f, 0.0f);
        this.softmax.backward_noloss(this.weights, this.scores, this.scores);
        float scale = (float) (1.0d / Math.sqrt(this.dk));
        GPUOP.getInstance().bmm(this.scores.getGpuData(), tensor.getGpuData(), tensor2.getGrad().getGpuData(), this.scores.number * this.scores.channel, this.scores.width, tensor.width, this.scores.height, 1, 0, scale, 0.0f);
        GPUOP.getInstance().bmm(this.scores.getGpuData(), tensor2.getGpuData(), tensor.getGrad().getGpuData(), this.scores.number * this.scores.channel, this.scores.height, tensor2.width, this.scores.width, 0, 1, scale, 0.0f);
    }

    /** @return the tensor produced by the most recent forward pass */
    @Override // com.omega.engine.nn.layer.Layer
    public Tensor getOutput() {
        return this.output;
    }

    /**
     * Backward pass: propagates {@code delta} through the optional layer norm,
     * the output projection, the attention and the q/k/v projections, and
     * stores the summed input gradient in {@code diff}.
     *
     * <p>The residual branch contributes the gradient that arrived at
     * {@code ro}: the layer-norm input gradient when layer norm is enabled,
     * otherwise {@code delta} itself. The previous code always read
     * {@code lnLayer.diff}, which is a null dereference when the layer was
     * built without layer normalisation.
     */
    @Override // com.omega.engine.nn.layer.Layer
    public void diff() {
        // Gradient with respect to ro, shared by the projection and the residual branch.
        Tensor residualGrad;
        if (this.layer_norm) {
            this.lnLayer.back(this.delta);
            residualGrad = this.lnLayer.diff;
        } else {
            residualGrad = this.delta;
        }
        this.oLinerLayer.back(residualGrad, this.ot);

        // Undo the flatten / permute that preceded the output projection.
        this.ot.view(this.batchSize, this.time, this.headNum, this.dk);
        TensorOP.permute(this.ot, this.attn_outputs, new int[]{0, 2, 1, 3});

        // Remember the current shapes of the projection outputs: the backward call
        // below re-views those tensors and reuses them as gradient buffers.
        int[] qShape = this.qLinerLayer.getOutput().shape();
        int[] kShape = this.kLinerLayer.getOutput().shape();
        int[] vShape = this.vLinerLayer.getOutput().shape();
        scaledDotProductAttentionBackward(this.qt, this.kt, this.vt, this.attn_outputs, this.qLinerLayer.getOutput(), this.kLinerLayer.getOutput(), this.vLinerLayer.getOutput());

        // Permute the gradients back into qt / kt / vt using the remembered shapes.
        this.qt.view(qShape);
        this.kt.view(kShape);
        this.vt.view(vShape);
        TensorOP.permute(this.qLinerLayer.getOutput(), this.qt, new int[]{0, 2, 1, 3});
        TensorOP.permute(this.kLinerLayer.getOutput(), this.kt, new int[]{0, 2, 1, 3});
        TensorOP.permute(this.vLinerLayer.getOutput(), this.vt, new int[]{0, 2, 1, 3});

        Tensor qDelta = this.qt.view(this.batchSize * this.time, 1, 1, this.headNum * this.dk);
        Tensor kDelta = this.kt.view(this.batchSize * this.time, 1, 1, this.headNum * this.dk);
        Tensor vDelta = this.vt.view(this.batchSize * this.time, 1, 1, this.headNum * this.dk);
        this.qLinerLayer.back(qDelta);
        this.kLinerLayer.back(kDelta);
        this.vLinerLayer.back(vDelta);

        // All three projections read the same input, so their input gradients add
        // up; the result is accumulated in place in qLinerLayer.diff.
        TensorOP.add(this.qLinerLayer.diff, this.kLinerLayer.diff, this.qLinerLayer.diff);
        TensorOP.add(this.qLinerLayer.diff, this.vLinerLayer.diff, this.qLinerLayer.diff);
        // Residual branch.
        TensorOP.add(this.qLinerLayer.diff, residualGrad, this.qLinerLayer.diff);
        this.diff = this.qLinerLayer.diff;
    }

    /** Runs {@code init()}, {@code setInput()} and the unmasked {@code output()}. */
    @Override // com.omega.engine.nn.layer.Layer
    public void forward() {
        init();
        setInput();
        output();
    }

    /** Runs {@code initBack()}, {@code setDelta()}, {@code diff()} and the optional gradient check. */
    @Override // com.omega.engine.nn.layer.Layer
    public void back() {
        initBack();
        setDelta();
        diff();
        if (this.network.GRADIENT_CHECK) {
            gradientCheck();
        }
    }

    /**
     * Unmasked forward pass on an explicit input.
     *
     * @param tensor layer input
     */
    @Override // com.omega.engine.nn.layer.Layer
    public void forward(Tensor tensor) {
        init();
        setInput(tensor);
        output();
    }

    /**
     * Masked forward pass on an explicit input.
     *
     * @param tensor  layer input
     * @param tensor2 attention mask handed to {@link #output(Tensor)}
     */
    public void forward(Tensor tensor, Tensor tensor2) {
        init();
        setInput(tensor);
        output(tensor2);
    }

    /**
     * Backward pass with an explicit incoming gradient.
     *
     * @param tensor gradient with respect to this layer's output
     */
    @Override // com.omega.engine.nn.layer.Layer
    public void back(Tensor tensor) {
        initBack();
        setDelta(tensor);
        diff();
        if (this.network.GRADIENT_CHECK) {
            gradientCheck();
        }
    }

    /** Delegates the update to every inner layer (layer norm only when enabled). */
    @Override // com.omega.engine.nn.layer.Layer
    public void update() {
        this.qLinerLayer.update();
        this.kLinerLayer.update();
        this.vLinerLayer.update();
        this.oLinerLayer.update();
        if (this.layer_norm) {
            this.lnLayer.update();
        }
    }

    /** Intentionally empty. */
    @Override // com.omega.engine.nn.layer.Layer
    public void showDiff() {
    }

    /** @return {@code LayerType.mutli_head_attention} (spelling as declared by the enum) */
    @Override // com.omega.engine.nn.layer.Layer
    public LayerType getLayerType() {
        return LayerType.mutli_head_attention;
    }

    /**
     * Array-based forward is not implemented for this layer.
     *
     * @return always {@code null}
     */
    @Override // com.omega.engine.nn.layer.Layer
    public float[][][][] output(float[][][][] fArr) {
        return (float[][][][]) null;
    }

    /** Intentionally empty. */
    @Override // com.omega.engine.nn.layer.Layer
    public void initCache() {
    }

    /** Intentionally empty. */
    @Override // com.omega.engine.nn.layer.Layer
    public void backTemp() {
    }

    /** @return the buffer the softmax writes the attention weights into */
    public Tensor getWeights() {
        return this.weights;
    }

    /**
     * Manual smoke test: ten masked forward / backward iterations on a small
     * layer (batch 3, time 5, embedding 4, 2 heads, layer norm enabled),
     * printing the output and the input gradient each round.
     */
    public static void main(String[] strArr) {
        int batch = 3;
        int seqLen = 5;
        int embed = 4;
        int heads = 2;
        int rows = batch * seqLen;

        Transformer transformer = new Transformer();
        transformer.number = rows;

        float[] inputData = RandomUtils.order(rows * embed, 0.1f, 0.1f);
        Tensor mask = ENTokenizer.triu(batch, heads, seqLen, seqLen, 1.0f);
        Tensor inputTensor = new Tensor(rows, 1, 1, embed, inputData, true);

        float[] deltaData = MatrixUtils.val(rows * embed, 1.0f);
        float[] deltaReset = MatrixUtils.val(rows * embed, 1.0f);
        Tensor deltaTensor = new Tensor(rows, 1, 1, embed, deltaData, true);

        MultiHeadAttentionLayer layer = new MultiHeadAttentionLayer(embed, heads, seqLen, false, true, transformer);
        for (int round = 0; round < 10; round++) {
            layer.forward(inputTensor, mask);
            layer.getOutput().showShape();
            layer.getOutput().showDM();
            layer.back(deltaTensor);
            layer.diff.showDM();
            // Restore the all-ones delta before the next iteration.
            deltaTensor.copyData(deltaReset);
        }
    }
}
