package com.omega.engine.nn.layer;

import com.omega.common.data.Tensor;
import com.omega.common.utils.MatrixUtils;
import com.omega.common.utils.RandomUtils;
import com.omega.engine.gpu.BaseKernel;
import com.omega.engine.gpu.GPUOP;
import com.omega.engine.nn.layer.gpu.AttentionKernel;
import com.omega.engine.nn.network.Network;
import com.omega.engine.nn.network.Transformer;
import com.omega.engine.updater.UpdaterFactory;

/**
 * Multi-head causal self-attention layer whose attention core runs on the GPU through
 * {@link AttentionKernel} and {@link GPUOP} batched matrix multiplies.
 *
 * <p>Forward pipeline (see {@link #output()}):
 * <ol>
 *   <li>{@code qkvLinerLayer}: fused linear projection from {@code embedDim} to
 *       {@code 3 * embedDim} features.</li>
 *   <li>{@code AttentionKernel.permute}: split the projection into per-head Q, K, V buffers
 *       allocated as {@code (batchSize, headNum, time, dk)} with {@code dk = embedDim / headNum}.</li>
 *   <li>{@link #scaledDotProductAttention}: scores scaled by {@code 1 / sqrt(dk)}, softmax,
 *       optional dropout, weighted sum of V.</li>
 *   <li>{@code AttentionKernel.unpermute}: merge heads into a {@code (batchSize * time, 1, 1, embedDim)}
 *       buffer.</li>
 *   <li>{@code oLinerLayer}: output projection, optionally followed by dropout.</li>
 * </ol>
 *
 * <p>The input is expected to hold {@code number = batchSize * time} rows; {@code batchSize} is
 * derived as {@code number / network.time}. Nothing in this class applies a causal mask
 * explicitly — NOTE(review): presumably {@code AttentionKernel.softmax_forward} masks future
 * positions; confirm in the kernel.
 */
public class FastCausalSelfAttentionLayer extends Layer {

    /** Sequence length; refreshed from {@code network.time} on every forward pass. */
    private int time;
    /** Number of attention heads. */
    private int headNum;
    /** Model width: input feature count and output width. */
    private int embedDim;
    /** Per-head width, {@code embedDim / headNum}. */
    private int dk;
    /** Whether the Q/K/V and output projections carry a bias term. */
    private boolean bias;
    /** Fused Q/K/V projection: {@code embedDim -> 3 * embedDim}. */
    private FullyLayer qkvLinerLayer;
    /** Output projection: {@code embedDim -> embedDim}. */
    private FullyLayer oLinerLayer;
    /** Dropout applied to the attention weights; only created when {@link #dropout} is set. */
    private DropoutLayer dropoutLayer;
    /** Dropout applied to the output projection; only created when {@link #dropout} is set. */
    private DropoutLayer dropoutLayer2;
    private BaseKernel baseKernel;
    private AttentionKernel attentionKernel;

    // Forward buffers: (batchSize, headNum, time, dk).
    private Tensor qt;
    private Tensor kt;
    private Tensor vt;
    private Tensor vaccum;
    // Forward buffers: (batchSize, headNum, time, time).
    private Tensor preatt;
    private Tensor attn;
    // Heads merged back together: (batchSize * time, 1, 1, embedDim).
    private Tensor oi;

    // Backward buffers, shaped like their forward counterparts.
    private Tensor dqt;
    private Tensor dkt;
    private Tensor dvt;
    private Tensor dvaccum;
    private Tensor dattn;
    private Tensor dpreatt;
    // Gradient w.r.t. the fused Q/K/V projection output: (number, 1, 1, 3 * embedDim).
    private Tensor dqkv;

    private int batchSize = 1;
    /** Enables both dropout layers. */
    private boolean dropout;

    /**
     * Creates the layer without attaching it to a network.
     *
     * @param embedDim model width (input and output feature count)
     * @param headNum  number of attention heads; must be positive and divide {@code embedDim}
     * @param time     initial sequence length (overwritten from the network on each forward pass)
     * @param bias     whether the linear projections use a bias
     * @param dropout  whether to apply dropout to the attention weights and to the output
     * @throws IllegalArgumentException if {@code headNum} is not positive or does not divide {@code embedDim}
     */
    public FastCausalSelfAttentionLayer(int embedDim, int headNum, int time, boolean bias, boolean dropout) {
        configure(embedDim, headNum, time, bias, dropout);
        initLayers();
    }

    /**
     * Creates the layer attached to {@code network}; an updater is created from the network's
     * updater settings if none is set yet.
     *
     * @param embedDim model width (input and output feature count)
     * @param headNum  number of attention heads; must be positive and divide {@code embedDim}
     * @param time     initial sequence length (overwritten from the network on each forward pass)
     * @param bias     whether the linear projections use a bias
     * @param dropout  whether to apply dropout to the attention weights and to the output
     * @param network  owning network
     * @throws IllegalArgumentException if {@code headNum} is not positive or does not divide {@code embedDim}
     */
    public FastCausalSelfAttentionLayer(int embedDim, int headNum, int time, boolean bias, boolean dropout, Network network) {
        this.network = network;
        configure(embedDim, headNum, time, bias, dropout);
        if (this.updater == null) {
            setUpdater(UpdaterFactory.create(network.updater, network.updaterParams));
        }
        initLayers();
    }

    /** Validates and stores the hyper-parameters shared by both constructors. */
    private void configure(int embedDim, int headNum, int time, boolean bias, boolean dropout) {
        // A non-divisible head count would silently truncate dk and break the permute/unpermute shapes.
        if (headNum <= 0 || embedDim % headNum != 0) {
            throw new IllegalArgumentException(
                    "embedDim (" + embedDim + ") must be divisible by a positive headNum (" + headNum + ")");
        }
        this.embedDim = embedDim;
        this.headNum = headNum;
        this.dk = embedDim / headNum;
        this.time = time;
        this.bias = bias;
        this.dropout = dropout;
        this.oChannel = 1;
        this.oHeight = 1;
        this.oWidth = embedDim;
    }

    /**
     * Builds the projection sub-layers, the optional dropout layers and the GPU kernels.
     * Projection weights are replaced with {@code RandomUtils.uniform(n, 0.0f, 0.02f)} values —
     * NOTE(review): confirm whether those arguments are (low, high) or (mean, scale).
     */
    public void initLayers() {
        this.qkvLinerLayer = new FullyLayer(this.embedDim, 3 * this.embedDim, this.bias, this.network);
        this.qkvLinerLayer.weight = new Tensor(1, 1, this.embedDim, 3 * this.embedDim, RandomUtils.uniform(this.embedDim * 3 * this.embedDim, 0.0f, 0.02f), true);
        this.oLinerLayer = new FullyLayer(this.embedDim, this.embedDim, this.bias, this.network);
        this.oLinerLayer.weight = new Tensor(1, 1, this.embedDim, this.embedDim, RandomUtils.uniform(this.embedDim * this.embedDim, 0.0f, 0.02f), true);
        if (this.dropout) {
            // 0.1f is presumably the drop probability — NOTE(review): confirm against DropoutLayer.
            this.dropoutLayer = new DropoutLayer(0.1f, this.network);
            this.dropoutLayer2 = new DropoutLayer(0.1f, this.oLinerLayer);
        }
        if (this.baseKernel == null) {
            this.baseKernel = new BaseKernel();
        }
        if (this.attentionKernel == null) {
            this.attentionKernel = new AttentionKernel();
        }
    }

    /** Sizes the forward buffers from {@code network.number} and {@code network.time}. */
    @Override // com.omega.engine.nn.layer.Layer
    public void init() {
        this.number = this.network.number;
        prepareForwardBuffers();
    }

    /** Sizes the forward buffers from {@code tensor.number} and {@code network.time}. */
    public void init(Tensor tensor) {
        this.number = tensor.number;
        prepareForwardBuffers();
    }

    /**
     * Derives {@code batchSize} from {@code number} and the network's sequence length, and
     * (re)allocates the forward buffers unless the existing ones already match.
     */
    private void prepareForwardBuffers() {
        this.time = this.network.time;
        this.batchSize = this.number / this.time;
        if (this.preatt != null && this.preatt.number == this.batchSize && this.preatt.width == this.time) {
            return;
        }
        this.qt = Tensor.createTensor(this.qt, this.batchSize, this.headNum, this.time, this.dk, true);
        this.kt = Tensor.createTensor(this.kt, this.batchSize, this.headNum, this.time, this.dk, true);
        this.vt = Tensor.createTensor(this.vt, this.batchSize, this.headNum, this.time, this.dk, true);
        this.preatt = Tensor.createTensor(this.preatt, this.batchSize, this.headNum, this.time, this.time, true);
        this.attn = Tensor.createTensor(this.attn, this.batchSize, this.headNum, this.time, this.time, true);
        this.vaccum = Tensor.createTensor(this.vaccum, this.batchSize, this.headNum, this.time, this.dk, true);
        // oi feeds oLinerLayer, so it must expose one row per token (batchSize * time rows).
        this.oi = Tensor.createTensor(this.oi, this.batchSize * this.time, 1, 1, this.embedDim, true);
    }

    /**
     * Prepares the backward buffers. Matching buffers are reused, with {@code dqkv} and
     * {@code dvaccum} cleared. They are reallocated whenever the batch size or sequence length
     * changed since they were created.
     */
    @Override // com.omega.engine.nn.layer.Layer
    public void initBack() {
        if (this.dattn != null && this.dattn.number == this.batchSize && this.dattn.width == this.time) {
            this.dqkv.clearGPU();
            this.dvaccum.clearGPU();
            return;
        }
        this.dvaccum = Tensor.createTensor(this.dvaccum, this.batchSize, this.headNum, this.time, this.dk, true);
        this.dqt = Tensor.createTensor(this.dqt, this.batchSize, this.headNum, this.time, this.dk, true);
        this.dkt = Tensor.createTensor(this.dkt, this.batchSize, this.headNum, this.time, this.dk, true);
        this.dvt = Tensor.createTensor(this.dvt, this.batchSize, this.headNum, this.time, this.dk, true);
        this.dattn = Tensor.createTensor(this.dattn, this.batchSize, this.headNum, this.time, this.time, true);
        this.dpreatt = Tensor.createTensor(this.dpreatt, this.batchSize, this.headNum, this.time, this.time, true);
        this.dqkv = Tensor.createTensor(this.dqkv, this.number, 1, 1, 3 * this.embedDim, true);
    }

    @Override // com.omega.engine.nn.layer.Layer
    public void initParam() {
    }

    /** Runs the full forward pipeline on {@code this.input} and stores the result in {@code this.output}. */
    @Override // com.omega.engine.nn.layer.Layer
    public void output() {
        this.qkvLinerLayer.forward(this.input);
        this.attentionKernel.permute(this.qkvLinerLayer.getOutput(), this.qt, this.kt, this.vt, this.batchSize, this.time, this.headNum, this.dk);
        scaledDotProductAttention(this.qt, this.kt, this.vt);
        this.attentionKernel.unpermute(this.vaccum, this.oi, this.batchSize, this.time, this.headNum, this.dk);
        this.oLinerLayer.forward(this.oi);
        this.output = this.oLinerLayer.getOutput();
        if (this.dropout) {
            this.dropoutLayer2.forward(this.oLinerLayer.getOutput());
            this.output = this.dropoutLayer2.getOutput();
        }
    }

    /**
     * Computes {@code vaccum = softmax(scale * scores) . V} for every (batch, head) pair, with
     * {@code scale = 1 / sqrt(dk)}. The raw scores go to {@code preatt} and the softmax to {@code attn}.
     *
     * <p>NOTE(review): the {@code bmm} arguments look like cuBLAS strided-batched GEMM
     * (transA, transB, m, n, k, alpha, A, lda, strideA, B, ldb, strideB, beta, C, ldc, strideC,
     * batchCount) in column-major order. Under that reading, {@code preatt[tq][tk] = q[tq] . k[tk]}
     * in row-major terms. Confirm against GPUOP.
     *
     * @param q per-head queries
     * @param k per-head keys
     * @param v per-head values
     */
    public void scaledDotProductAttention(Tensor q, Tensor k, Tensor v) {
        float scale = (float) (1.0d / Math.sqrt(this.dk));
        int batchCount = this.batchSize * this.headNum;
        int headStride = this.time * this.dk;
        int scoreStride = this.time * this.time;
        GPUOP.getInstance().bmm(1, 0, this.time, this.time, this.dk, 1.0f, k.getGpuData(), this.dk, headStride, q.getGpuData(), this.dk, headStride, 0.0f, this.preatt.getGpuData(), this.time, scoreStride, batchCount);
        this.attentionKernel.scale(this.preatt, scale, this.batchSize, this.headNum, this.time);
        this.attentionKernel.softmax_forward(this.preatt, this.attn, this.batchSize, this.headNum, this.time);
        Tensor weights = this.attn;
        if (this.dropout) {
            this.dropoutLayer.forward(this.attn);
            weights = this.dropoutLayer.getOutput();
        }
        GPUOP.getInstance().bmm(0, 0, this.dk, this.time, this.time, 1.0f, v.getGpuData(), this.dk, headStride, weights.getGpuData(), this.time, scoreStride, 0.0f, this.vaccum.getGpuData(), this.dk, headStride, batchCount);
    }

    /**
     * Back-propagates {@code dvaccum} through {@link #scaledDotProductAttention}, filling
     * {@code dattn}, {@code dvt}, {@code dpreatt}, {@code dqt} and {@code dkt}.
     * {@code softmax_backward} receives no explicit scale — NOTE(review): presumably the kernel
     * derives {@code 1 / sqrt(embedDim / headNum)} from its arguments; confirm.
     */
    public void scaledDotProductAttentionBackward() {
        int batchCount = this.batchSize * this.headNum;
        int headStride = this.time * this.dk;
        int scoreStride = this.time * this.time;
        // V was multiplied by the post-dropout weights in the forward pass, so its gradient uses them too.
        Tensor weights = this.dropout ? this.dropoutLayer.getOutput() : this.attn;
        GPUOP.getInstance().bmm(1, 0, this.time, this.time, this.dk, 1.0f, this.vt.getGpuData(), this.dk, headStride, this.dvaccum.getGpuData(), this.dk, headStride, 0.0f, this.dattn.getGpuData(), this.time, scoreStride, batchCount);
        GPUOP.getInstance().bmm(0, 1, this.dk, this.time, this.time, 1.0f, this.dvaccum.getGpuData(), this.dk, headStride, weights.getGpuData(), this.time, scoreStride, 0.0f, this.dvt.getGpuData(), this.dk, headStride, batchCount);
        // Keep the dropout-adjusted gradient in a local: reassigning this.dattn would permanently
        // alias the field to dropoutLayer.diff, and later passes would write into that layer's buffer.
        Tensor dSoftmaxOut = this.dattn;
        if (this.dropout) {
            this.dropoutLayer.back(this.dattn);
            dSoftmaxOut = this.dropoutLayer.diff;
        }
        this.attentionKernel.softmax_backward(this.dpreatt, dSoftmaxOut, this.attn, this.batchSize, this.time, this.embedDim, this.headNum);
        GPUOP.getInstance().bmm(0, 0, this.dk, this.time, this.time, 1.0f, this.kt.getGpuData(), this.dk, headStride, this.dpreatt.getGpuData(), this.time, scoreStride, 0.0f, this.dqt.getGpuData(), this.dk, headStride, batchCount);
        GPUOP.getInstance().bmm(0, 1, this.dk, this.time, this.time, 1.0f, this.qt.getGpuData(), this.dk, headStride, this.dpreatt.getGpuData(), this.time, scoreStride, 0.0f, this.dkt.getGpuData(), this.dk, headStride, batchCount);
    }

    @Override // com.omega.engine.nn.layer.Layer
    public Tensor getOutput() {
        return this.output;
    }

    /**
     * Back-propagates {@code this.delta} through the whole layer and stores the input gradient in
     * {@code this.diff}. {@code oi} is passed as the second argument to {@code oLinerLayer.back} and
     * is then read by {@code unpermute_backward}. NOTE(review): this assumes {@code back} writes the
     * input gradient into {@code oi}; confirm the {@code FullyLayer.back(Tensor, Tensor)} contract.
     */
    @Override // com.omega.engine.nn.layer.Layer
    public void diff() {
        if (this.dropout) {
            this.dropoutLayer2.back(this.delta);
            this.oLinerLayer.back(this.dropoutLayer2.diff, this.oi);
        } else {
            this.oLinerLayer.back(this.delta, this.oi);
        }
        this.attentionKernel.unpermute_backward(this.dvaccum, this.oi, this.batchSize, this.time, this.headNum, this.dk);
        scaledDotProductAttentionBackward();
        this.attentionKernel.permute_backward(this.dqkv, this.dqt, this.dkt, this.dvt, this.batchSize, this.time, this.headNum, this.dk);
        this.qkvLinerLayer.back(this.dqkv);
        this.diff = this.qkvLinerLayer.diff;
    }

    @Override // com.omega.engine.nn.layer.Layer
    public void forward() {
        init();
        setInput();
        output();
    }

    @Override // com.omega.engine.nn.layer.Layer
    public void back() {
        initBack();
        setDelta();
        diff();
        if (this.network.GRADIENT_CHECK) {
            gradientCheck();
        }
    }

    @Override // com.omega.engine.nn.layer.Layer
    public void forward(Tensor tensor) {
        init(tensor);
        setInput(tensor);
        output();
    }

    @Override // com.omega.engine.nn.layer.Layer
    public void back(Tensor tensor) {
        initBack();
        setDelta(tensor);
        diff();
        if (this.network.GRADIENT_CHECK) {
            gradientCheck();
        }
    }

    /** Updates the parameters of both projection layers. */
    @Override // com.omega.engine.nn.layer.Layer
    public void update() {
        this.qkvLinerLayer.update();
        this.oLinerLayer.update();
    }

    @Override // com.omega.engine.nn.layer.Layer
    public void showDiff() {
    }

    @Override // com.omega.engine.nn.layer.Layer
    public LayerType getLayerType() {
        return LayerType.mutli_head_attention;
    }

    /** Not supported by this GPU-only layer; always returns {@code null}. */
    @Override // com.omega.engine.nn.layer.Layer
    public float[][][][] output(float[][][][] fArr) {
        return null;
    }

    @Override // com.omega.engine.nn.layer.Layer
    public void initCache() {
    }

    @Override // com.omega.engine.nn.layer.Layer
    public void backTemp() {
    }

    /** Manual smoke test: ten forward/backward passes on a tiny 3x3x4 input, printing results. */
    public static void main(String[] args) {
        int batch = 3;
        int time = 3;
        int embedDim = 4;
        int rows = batch * time;
        Transformer transformer = new Transformer();
        transformer.number = rows;
        transformer.time = time;
        Tensor input = new Tensor(rows, 1, 1, embedDim, RandomUtils.order(rows * embedDim, 0.1f, 0.1f), true);
        Tensor delta = new Tensor(rows, 1, 1, embedDim, MatrixUtils.val(rows * embedDim, 1.0f), true);
        FastCausalSelfAttentionLayer layer = new FastCausalSelfAttentionLayer(embedDim, 2, time, false, false, transformer);
        for (int i = 0; i < 10; i++) {
            layer.forward(input);
            layer.getOutput().showShape();
            layer.getOutput().showDM();
            layer.back(delta);
            layer.diff.showDM();
        }
    }

    /**
     * Compares the host copies of two tensors element by element using exact float equality.
     * On the first difference (or a length mismatch) it prints the differing values and returns {@code false}.
     */
    public static boolean same(Tensor tensor, Tensor tensor2) {
        float[] syncHost = tensor.syncHost();
        float[] syncHost2 = tensor2.syncHost();
        if (syncHost.length != syncHost2.length) {
            System.out.println("length:" + syncHost.length + ":" + syncHost2.length);
            return false;
        }
        for (int i = 0; i < syncHost.length; i++) {
            if (syncHost[i] != syncHost2[i]) {
                System.out.println(syncHost[i] + ":" + syncHost2[i] + "[" + i + "]");
                return false;
            }
        }
        return true;
    }
}
