package com.omega.engine.ad;

import com.omega.common.data.Tensor;
import com.omega.common.utils.JsonUtils;
import com.omega.common.utils.MatrixUtils;
import com.omega.common.utils.PrintUtils;
import com.omega.common.utils.RandomUtils;
import com.omega.engine.ad.op.OP;
import com.omega.engine.ad.op.OPType;
import com.omega.engine.ad.op.data.GetOP;
import com.omega.engine.ad.op.data.SetOP;
import com.omega.engine.ad.op.functions.ATanOP;
import com.omega.engine.ad.op.functions.ClampOP;
import com.omega.engine.ad.op.functions.CosOP;
import com.omega.engine.ad.op.functions.ExpOP;
import com.omega.engine.ad.op.functions.LogOP;
import com.omega.engine.ad.op.functions.MaxOP;
import com.omega.engine.ad.op.functions.MaximumOP;
import com.omega.engine.ad.op.functions.MinimumOP;
import com.omega.engine.ad.op.functions.PowOP;
import com.omega.engine.ad.op.functions.SinOP;
import com.omega.engine.ad.op.functions.SumOP;
import com.omega.engine.ad.op.functions.TanOP;
import com.omega.engine.ad.op.functions.TransposeOP;
import com.omega.engine.ad.op.sign.AddOP;
import com.omega.engine.ad.op.sign.DivOP;
import com.omega.engine.ad.op.sign.DotOP;
import com.omega.engine.ad.op.sign.MulOP;
import com.omega.engine.ad.op.sign.ScalarDivOP;
import com.omega.engine.ad.op.sign.ScalarSubOP;
import com.omega.engine.ad.op.sign.SubOP;
import com.omega.engine.gpu.CUDAMemoryManager;
import com.omega.engine.gpu.CUDAModules;
import com.omega.example.yolo.utils.YoloImageUtils;
import java.util.ArrayList;
import java.util.List;

/* loaded from: input_file:com/omega/engine/ad/Graph.class */
public class Graph {
    /** Recorded operations, in the order they were created by {@link #getTape}. */
    private List<Tape> tapes = new ArrayList<Tape>();
    /** Position of the next tape handed out by {@link #getTape} while the graph is locked. */
    public int tapeIndex = 0;
    /** When true, {@link #getTape} hands out already recorded tapes instead of creating new ones. */
    private boolean lock = false;

    /** Rewinds the replay cursor so the next locked {@code getTape} call returns the first tape. */
    public void start() {
        tapeIndex = 0;
    }

    /** Switches the graph into locked mode, in which recorded tapes are reused rather than created. */
    public void lock() {
        lock = true;
    }

    /** Leaves locked mode so that subsequent operations record new tapes again. */
    public void unlock() {
        lock = false;
    }

    /**
     * Prints every recorded tape to standard output: first its index and op type,
     * then its two inputs and its output.
     */
    public void showGraph() {
        int position = 0;
        for (Tape tape : this.tapes) {
            System.out.println(position + ":[" + tape.getOp().getOpType() + "]");
            System.out.println("x1:[" + tape.getX() + "]|x2:[" + tape.getY() + "]|out:[" + tape.getOutput() + "]");
            position++;
        }
    }

    /** Discards all recorded tapes. The lock flag and the replay cursor are left untouched. */
    public void reset() {
        tapes.clear();
    }

    /** Calls {@code zeroGrad()} on every recorded tape, in recording order. */
    public void clearGrad() {
        for (Tape tape : this.tapes) {
            tape.zeroGrad();
        }
    }

    /**
     * Appends a tape to the end of the recorded sequence.
     *
     * @param tape the tape to record
     */
    public void add(Tape tape) {
        tapes.add(tape);
    }

    /**
     * Returns the tape to use for one operation.
     *
     * <p>Unlocked: a new tape is built from the arguments, existing tapes that share a tensor
     * with it are flagged through {@link #checkSubTape}, and the new tape is recorded.
     * Locked: the arguments are ignored and the tape at {@code tapeIndex} is returned, after
     * which the cursor advances. A replayed tape whose op type equals {@code OPType.sum} has its
     * output filled with zero first.
     *
     * @param op      operation for a newly created tape
     * @param tensor  first input
     * @param tensor2 second input, may be null
     * @param f       first scalar argument
     * @param f2      second scalar argument
     * @param iArr    integer arguments, may be null
     * @return the created or replayed tape
     */
    public Tape getTape(OP op, Tensor tensor, Tensor tensor2, float f, float f2, int[] iArr) {
        if (!this.lock) {
            Tape created = new Tape(op, tensor, tensor2, f, f2, iArr, this);
            checkSubTape(tensor, tensor2);
            add(created);
            return created;
        }
        Tape recorded = this.tapes.get(this.tapeIndex);
        if (recorded.getOp().getOpType().equals(OPType.sum)) {
            // A replayed sum starts again from a zeroed output.
            recorded.getOutput().fill(0.0f);
        }
        this.tapeIndex++;
        return recorded;
    }

    /**
     * Flags as "sub" every recorded tape whose X, Y or output is the same object as
     * {@code tensor} or {@code tensor2}. Tapes already flagged are skipped.
     *
     * <p>NOTE(review): comparison is by reference, so a null {@code tensor2} matches any tape
     * whose Y is null — confirm this is intended.
     *
     * @param tensor  first tensor to look for
     * @param tensor2 second tensor to look for, may be null
     */
    public void checkSubTape(Tensor tensor, Tensor tensor2) {
        for (Tape candidate : this.tapes) {
            if (candidate.isSub()) {
                continue;
            }
            boolean sharesTensor = candidate.getX() == tensor
                    || candidate.getY() == tensor
                    || candidate.getX() == tensor2
                    || candidate.getY() == tensor2
                    || candidate.getOutput() == tensor
                    || candidate.getOutput() == tensor2;
            if (sharesTensor) {
                candidate.setSub(true);
            }
        }
    }

    /**
     * Runs a two-tensor operation and returns its forward result.
     *
     * <p>The op instance is chosen from {@code oPType}; the cases below select the add, sub, mul,
     * div, maximum, minimum and dot ops. Any other type leaves {@code op} null and is rejected.
     *
     * @param oPType  which operation to run
     * @param tensor  first operand
     * @param tensor2 second operand
     * @return the tensor returned by the tape's {@code forward()}, with this graph set on it
     * @throws RuntimeException if {@code oPType} selects no op in this overload
     */
    public Tensor OP(OPType oPType, Tensor tensor, Tensor tensor2) {
        OP op = null;
        // NOTE(review): AnonymousClass1.$SwitchMap... looks like a decompiler-generated enum switch
        // table; the mapping from these case numbers to OPType constants is not visible in this file.
        switch (AnonymousClass1.$SwitchMap$com$omega$engine$ad$op$OPType[oPType.ordinal()]) {
            case 1:
                op = AddOP.getInstance();
                break;
            case 2:
                op = SubOP.getInstance();
                break;
            case 3:
                op = MulOP.getInstance();
                break;
            case 4:
                op = DivOP.getInstance();
                break;
            case 5:
                op = MaximumOP.getInstance();
                break;
            case 6:
                op = MinimumOP.getInstance();
                break;
            // NOTE(review): the constant is presumably an unrelated value the decompiler matched to 7 — confirm.
            case YoloImageUtils.GRID_SIZE /* 7 */:
                op = DotOP.getInstance();
                break;
        }
        if (op == null) {
            throw new RuntimeException("the op is not support.");
        }
        // No scalar or integer arguments are used by this overload.
        Tensor forward = getTape(op, tensor, tensor2, 0.0f, 0.0f, null).forward();
        forward.setG(this);
        return forward;
    }

    /**
     * Runs a tensor-with-scalar operation and returns its forward result.
     *
     * <p>The op instance is chosen from {@code oPType}; the cases below select the add, sub, mul,
     * div, scalar-sub, scalar-div and pow ops. Any other type leaves {@code op} null and is rejected.
     *
     * @param oPType which operation to run
     * @param tensor tensor operand
     * @param f      scalar operand, passed to the tape as its first scalar argument
     * @return the tensor returned by the tape's {@code forward()}, with this graph set on it
     * @throws RuntimeException if {@code oPType} selects no op in this overload
     */
    public Tensor OP(OPType oPType, Tensor tensor, float f) {
        OP op = null;
        // NOTE(review): AnonymousClass1.$SwitchMap... looks like a decompiler-generated enum switch
        // table; the mapping from these case numbers to OPType constants is not visible in this file.
        switch (AnonymousClass1.$SwitchMap$com$omega$engine$ad$op$OPType[oPType.ordinal()]) {
            case 1:
                op = AddOP.getInstance();
                break;
            case 2:
                op = SubOP.getInstance();
                break;
            case 3:
                op = MulOP.getInstance();
                break;
            case 4:
                op = DivOP.getInstance();
                break;
            case 8:
                op = ScalarSubOP.getInstance();
                break;
            case 9:
                op = ScalarDivOP.getInstance();
                break;
            case 10:
                op = PowOP.getInstance();
                break;
        }
        if (op == null) {
            throw new RuntimeException("the op is not support.");
        }
        // There is no second tensor for this overload, so null is passed in its place.
        Tensor forward = getTape(op, tensor, null, f, 0.0f, null).forward();
        forward.setG(this);
        return forward;
    }

    /**
     * Runs a tensor operation that takes two scalars; only {@code clamp} is handled here.
     *
     * @param oPType which operation to run
     * @param tensor tensor operand
     * @param f      first scalar argument passed to the tape
     * @param f2     second scalar argument passed to the tape
     * @return the tensor returned by the tape's {@code forward()}, with this graph set on it
     * @throws RuntimeException if {@code oPType} is not handled by this overload
     */
    public Tensor OP(OPType oPType, Tensor tensor, float f, float f2) {
        final OP selected;
        switch (oPType) {
            case clamp:
                selected = ClampOP.getInstance();
                break;
            default:
                selected = null;
                break;
        }
        if (selected == null) {
            throw new RuntimeException("the op is not support.");
        }
        Tape tape = getTape(selected, tensor, null, f, f2, null);
        Tensor result = tape.forward();
        result.setG(this);
        return result;
    }

    /**
     * Runs a single-tensor operation: log, sin, cos, tan, atan, exp or transpose.
     *
     * @param oPType which operation to run
     * @param tensor the operand
     * @return the tensor returned by the tape's {@code forward()}, with this graph set on it
     * @throws RuntimeException if {@code oPType} is not handled by this overload
     */
    public Tensor OP(OPType oPType, Tensor tensor) {
        final OP selected;
        switch (oPType) {
            case log:
                selected = LogOP.getInstance();
                break;
            case sin:
                selected = SinOP.getInstance();
                break;
            case cos:
                selected = CosOP.getInstance();
                break;
            case tan:
                selected = TanOP.getInstance();
                break;
            case atan:
                selected = ATanOP.getInstance();
                break;
            case exp:
                selected = ExpOP.getInstance();
                break;
            case transpose:
                selected = TransposeOP.getInstance();
                break;
            default:
                selected = null;
                break;
        }
        if (selected == null) {
            throw new RuntimeException("the op is not support.");
        }
        Tape tape = getTape(selected, tensor, null, 0.0f, 0.0f, null);
        Tensor result = tape.forward();
        result.setG(this);
        return result;
    }

    /**
     * Runs a single-tensor operation that takes integer arguments: get, sum or max.
     *
     * @param oPType which operation to run
     * @param tensor the operand
     * @param iArr   integer arguments handed to the tape unchanged
     * @return the tensor returned by the tape's {@code forward()}, with this graph set on it
     * @throws RuntimeException if {@code oPType} is not handled by this overload
     */
    public Tensor OP(OPType oPType, Tensor tensor, int[] iArr) {
        final OP selected;
        switch (oPType) {
            case get:
                selected = GetOP.getInstance();
                break;
            case sum:
                selected = SumOP.getInstance();
                break;
            case max:
                selected = MaxOP.getInstance();
                break;
            default:
                selected = null;
                break;
        }
        if (selected == null) {
            throw new RuntimeException("the op is not support.");
        }
        Tape tape = getTape(selected, tensor, null, 0.0f, 0.0f, iArr);
        Tensor result = tape.forward();
        result.setG(this);
        return result;
    }

    /**
     * Runs a two-tensor operation that takes integer arguments; only {@code set} is handled here.
     *
     * @param oPType  which operation to run
     * @param tensor  first operand
     * @param tensor2 second operand
     * @param iArr    integer arguments handed to the tape unchanged
     * @return the tensor returned by the tape's {@code forward()}, with this graph set on it
     * @throws RuntimeException if {@code oPType} is not handled by this overload
     */
    public Tensor OP(OPType oPType, Tensor tensor, Tensor tensor2, int[] iArr) {
        final OP selected;
        switch (oPType) {
            case set:
                selected = SetOP.getInstance();
                break;
            default:
                selected = null;
                break;
        }
        if (selected == null) {
            throw new RuntimeException("the op is not support.");
        }
        Tape tape = getTape(selected, tensor, tensor2, 0.0f, 0.0f, iArr);
        Tensor result = tape.forward();
        result.setG(this);
        return result;
    }

    /**
     * Runs the backward pass from the last recorded tape to the first.
     * The last tape receives {@code tensor}; every earlier tape is called without arguments.
     * The graph is left locked and the replay cursor is reset to zero.
     *
     * @param tensor value handed to the last tape's {@code backward}
     */
    public void backward(Tensor tensor) {
        this.lock = true;
        int idx = this.tapes.size();
        while (--idx >= 0) {
            Tape current = this.tapes.get(idx);
            boolean isLast = idx == this.tapes.size() - 1;
            if (isLast) {
                current.backward(tensor);
            } else {
                current.backward();
            }
        }
        this.tapeIndex = 0;
    }

    /**
     * Runs the backward pass from the last recorded tape to the first.
     * Before a tape that is not flagged as "sub" runs, the gradient of its output is filled
     * with 1. The graph is left locked and the replay cursor is reset to zero.
     */
    public void backward() {
        this.lock = true;
        int idx = this.tapes.size();
        while (--idx >= 0) {
            Tape current = this.tapes.get(idx);
            if (!current.isSub()) {
                current.getOutput().getGrad().fill(1.0f);
            }
            current.backward();
        }
        this.tapeIndex = 0;
    }

    /**
     * Demo: evaluates {@code log(x) + x * y - sin(y)} for single-element tensors x = 2 and y = 5,
     * ten times over, printing the result and both gradients after each backward pass.
     */
    public void formula1() {
        int length = 1 * 1 * 1 * 1;
        Tensor x = new Tensor(1, 1, 1, 1, MatrixUtils.val(length, 2.0f));
        Tensor y = new Tensor(1, 1, 1, 1, MatrixUtils.val(length, 5.0f));
        x.setRequiresGrad(true);
        y.setRequiresGrad(true);
        int round = 0;
        while (round < 10) {
            clearGrad();
            Tensor logPlusProduct = x.log().add(x.mul(y));
            Tensor z = logPlusProduct.sub(y.sin());
            backward();
            System.out.println("z:" + JsonUtils.toJson(z.data));
            System.out.println("dx:" + JsonUtils.toJson(x.getGrad()));
            System.out.println("dy:" + JsonUtils.toJson(y.getGrad()));
            round++;
        }
    }

    /**
     * Demo: builds an expression from slices of {@code tensor} and {@code tensor2}, runs this
     * graph's backward pass, syncs the result and the gradient to the host and prints the
     * elapsed time in milliseconds.
     *
     * @param tensor  input that requires a gradient
     * @param tensor2 second input, sliced with {@code get(1, 4, 2)}
     */
    public void sigmoid_gpu(Tensor tensor, Tensor tensor2) {
        tensor.setRequiresGrad(true);
        tensor.hostToDevice();
        tensor2.hostToDevice();
        long startedAt = System.nanoTime();
        // Same chain as the static sigmoid helper, applied to the slice get(1, 0, 2).
        Tensor firstTerm = tensor.get(1, 0, 2).mul(-1.0f).exp().add(1.0f).scalarDiv(1.0f);
        Tensor partial = firstTerm.add(tensor.get(1, 2, 2));
        Tensor reference = tensor2.get(1, 4, 2);
        Tensor secondTerm = tensor.get(1, 4, 2).mul(-1.0f).exp().add(1.0f).scalarDiv(1.0f);
        Tensor result = partial.add(reference.sub(secondTerm).pow(2.0f));
        backward();
        result.syncHost();
        tensor.getGrad().syncHost();
        double elapsedMs = (System.nanoTime() - startedAt) / 1000000.0d;
        System.out.println(elapsedMs + "ms.");
    }

    /**
     * Demo: takes two slices of a 64x128x32x32 tensor, squares the first, prints a fresh graph,
     * runs its backward pass and reports the elapsed time in milliseconds.
     */
    public static void get_gpu() {
        Graph graph = new Graph();
        int number = 64;
        int channel = 128;
        int height = 32;
        int width = 32;
        float[] values = MatrixUtils.order(number * channel * height * width, 0, 1);
        Tensor input = new Tensor(number, channel, height, width, values, true);
        long startedAt = System.nanoTime();
        input.setRequiresGrad(true);
        input.hostToDevice();
        Tensor squaredSlice = input.get(1, 1, 10).pow(2.0f);
        Tensor plainSlice = input.get(1, 14, 10);
        graph.showGraph();
        graph.backward();
        squaredSlice.syncHost();
        plainSlice.syncHost();
        input.getGrad().syncHost();
        double elapsedMs = (System.nanoTime() - startedAt) / 1000000.0d;
        System.out.println(elapsedMs + "ms.");
    }

    /**
     * Demo: raises a 2x3x5x5 tensor to the power 3, prints a fresh graph, runs its backward
     * pass and reports the elapsed time in milliseconds.
     */
    public static void pow_gpu() {
        Graph graph = new Graph();
        int number = 2;
        int channel = 3;
        int height = 5;
        int width = 5;
        float[] values = MatrixUtils.order(number * channel * height * width, 0, 1);
        Tensor input = new Tensor(number, channel, height, width, values, true);
        long startedAt = System.nanoTime();
        input.setRequiresGrad(true);
        input.hostToDevice();
        Tensor cubed = input.pow(3.0f);
        graph.showGraph();
        graph.backward();
        cubed.syncHost();
        double elapsedMs = (System.nanoTime() - startedAt) / 1000000.0d;
        System.out.println(elapsedMs + "ms.");
    }

    /**
     * Demo: copies part of a 10x5x5x5 tensor into a tensor with 2 in the second dimension on
     * the CPU and prints the result. The source offset is shifted by one 5x5 plane (the
     * {@code + 1} below).
     */
    public static void show() {
        Tensor source = new Tensor(10, 5, 5, 5, MatrixUtils.order(10 * 5 * 5 * 5, 0, 1));
        Tensor picked = new Tensor(source.number, 2, source.height, source.width, source.isHasGPU());
        for (int i = 0; i < picked.dataLength; i++) {
            // Size of the second dimension of the destination, derived from its length.
            int planes = ((picked.dataLength / 10) / 5) / 5;
            int n = i / ((planes * 5) * 5);
            int c = ((i / 5) / 5) % planes;
            int h = (i / 5) % 5;
            int w = i % 5;
            int sourceIndex = (n * 5 * 5 * 5) + ((c + 1) * 5 * 5) + (h * 5) + w;
            picked.data[i] = source.data[sourceIndex];
        }
        PrintUtils.printImage(picked);
    }

    /**
     * Demo: sums four BCE-on-sigmoid terms and two MSE terms computed on pairs of slices of a
     * 3x18x5x5 input and label tensor, runs a fresh graph's backward pass, and prints the loss,
     * the input gradient and the elapsed time.
     */
    public static void yolov3_loss() {
        int length = 3 * 18 * 5 * 5;
        Graph graph = new Graph();
        Tensor input = new Tensor(3, 18, 5, 5, MatrixUtils.val(length, 0.6f), true);
        Tensor label = new Tensor(3, 18, 5, 5, MatrixUtils.val(length, 1.0f), true);
        input.setRequiresGrad(true);
        input.hostToDevice();
        label.hostToDevice();
        long startedAt = System.nanoTime();
        Tensor bce0 = BCELoss(sigmoid(input.get(1, 0, 2)), label.get(1, 0, 2));
        Tensor mse2 = MSELoss(input.get(1, 2, 2), label.get(1, 2, 2));
        Tensor bce4 = BCELoss(sigmoid(input.get(1, 4, 2)), label.get(1, 4, 2));
        Tensor bce6 = BCELoss(sigmoid(input.get(1, 6, 2)), label.get(1, 6, 2));
        // The last two terms are built only after the first three additions, as in the chained form.
        Tensor total = bce0.add(mse2).add(bce4).add(bce6);
        total = total.add(MSELoss(input.get(1, 8, 2), label.get(1, 8, 2)));
        total = total.add(BCELoss(sigmoid(input.get(1, 10, 2)), label.get(1, 10, 2)));
        graph.backward();
        total.syncHost();
        System.out.println("z:" + JsonUtils.toJson(total.data));
        input.getGrad().syncHost();
        System.out.println("dx:" + JsonUtils.toJson(input.getGrad()));
        double elapsedMs = (System.nanoTime() - startedAt) / 1000000.0d;
        System.out.println(elapsedMs + "ms.");
        PrintUtils.printImage(total);
        PrintUtils.printImage(input.getGrad());
    }

    /**
     * Builds {@code tensor.mul(-1).exp().add(1).scalarDiv(1)}.
     * Presumably 1 / (1 + e^-x), assuming {@code scalarDiv(1)} divides 1 by the tensor — TODO confirm.
     *
     * @param tensor the input
     * @return the tensor produced by the last operation of the chain
     */
    public static Tensor sigmoid(Tensor tensor) {
        Tensor negated = tensor.mul(-1.0f);
        Tensor onePlusExp = negated.exp().add(1.0f);
        return onePlusExp.scalarDiv(1.0f);
    }

    /**
     * Builds {@code e.scalarSub(1).div(e.add(1))} with {@code e = exp(-2 * tensor)}.
     * Presumably (1 - e) / (1 + e), assuming {@code scalarSub(1)} subtracts the tensor from 1 — TODO confirm.
     *
     * @param tensor the input
     * @return the tensor produced by the final division
     */
    public static Tensor tanh(Tensor tensor) {
        Tensor expTerm = tensor.mul(-2.0f).exp();
        Tensor numerator = expTerm.scalarSub(1.0f);
        Tensor denominator = expTerm.add(1.0f);
        return numerator.div(denominator);
    }

    /**
     * Builds the squared difference {@code (tensor - tensor2)} raised to the power 2.
     * No sum or mean is applied here.
     *
     * @param tensor  first operand
     * @param tensor2 second operand, subtracted from the first
     * @return the tensor produced by {@code pow(2)}
     */
    public static Tensor MSELoss(Tensor tensor, Tensor tensor2) {
        Tensor difference = tensor.sub(tensor2);
        return difference.pow(2.0f);
    }

    /**
     * Builds {@code (-tensor2 * log(tensor)) - (tensor2.scalarSub(1) * log(tensor.scalarSub(1)))}.
     * Presumably binary cross-entropy with {@code tensor} as the prediction and {@code tensor2}
     * as the label, assuming {@code scalarSub(1)} computes 1 - t — TODO confirm.
     *
     * @param tensor  first argument, passed through {@code log}
     * @param tensor2 second argument, used as the multiplier
     * @return the tensor produced by the final subtraction
     */
    public static Tensor BCELoss(Tensor tensor, Tensor tensor2) {
        Tensor labelTerm = tensor2.mul(-1.0f).mul(tensor.log());
        Tensor complementTerm = tensor2.scalarSub(1.0f).mul(tensor.scalarSub(1.0f).log());
        return labelTerm.sub(complementTerm);
    }

    /**
     * Demo: evaluates a loss built from sigmoid, log, sum and div on a 2x1x1x4 input and label,
     * twenty times over. Each round clears and runs a fresh graph's backward pass, then prints
     * the loss, the input gradient and the elapsed time in milliseconds.
     */
    public static void multiLabelSoftMarginLoss() {
        // Divisor applied after sum(1).
        int i2 = 1 * 1 * 4;
        Graph graph = new Graph();
        Tensor tensor = new Tensor(2, 1, 1, 4, new float[]{0.2f, 0.5f, 0.0f, 0.0f, 0.1f, 0.5f, 0.0f, 0.8f}, true);
        Tensor tensor2 = new Tensor(2, 1, 1, 4, new float[]{1.0f, 1.0f, 0.0f, 0.0f, 0.0f, 1.0f, 0.0f, 1.0f}, true);
        tensor.setRequiresGrad(true);
        tensor.hostToDevice();
        tensor2.hostToDevice();
        for (int i3 = 0; i3 < 20; i3++) {
            long nanoTime = System.nanoTime();
            Tensor div = tensor2.mul(sigmoid(tensor).log()).add(sigmoid(tensor.mul(-1.0f)).log().mul(tensor2.scalarSub(1.0f))).mul(-1.0f).sum(1).div(i2).sum(0).div(tensor.number);
            graph.clearGrad();
            graph.backward();
            div.syncHost();
            System.out.println("loss:" + JsonUtils.toJson(div.data));
            tensor.getGrad().syncHost();
            System.out.println("dx:" + JsonUtils.toJson(tensor.getGrad()));
            System.out.println(((System.nanoTime() - nanoTime) / 1000000.0d) + "ms.");
            PrintUtils.printImage(tensor.getGrad());
        }
    }

    /**
     * Demo: runs the same loss expression as {@code multiLabelSoftMarginLoss} 200 times on
     * 64x128x32x32 random tensors bound to one graph. Each round rewinds the graph, rebuilds the
     * expression, locks the graph, runs backward and prints the elapsed time in milliseconds.
     */
    public static void multiLabelSoftMarginLoss2() {
        int total = 64 * 128 * 32 * 32;
        int perSample = 128 * 32 * 32;
        float[] inputValues = RandomUtils.gaussianRandom(total, 0.1f);
        float[] labelValues = RandomUtils.gaussianRandom(total, 0.1f);
        Graph graph = new Graph();
        Tensor input = new Tensor(64, 128, 32, 32, inputValues, true, graph);
        Tensor label = new Tensor(64, 128, 32, 32, labelValues, true, graph);
        input.setRequiresGrad(true);
        input.hostToDevice();
        label.hostToDevice();
        int round = 0;
        while (round < 200) {
            long startedAt = System.nanoTime();
            graph.start();
            Tensor labelTerm = label.mul(sigmoid(input).log());
            Tensor complementTerm = sigmoid(input.mul(-1.0f)).log().mul(label.scalarSub(1.0f));
            Tensor loss = labelTerm.add(complementTerm).mul(-1.0f).sum(1).div(perSample).sum(0).div(input.number);
            graph.lock();
            graph.clearGrad();
            graph.backward();
            loss.syncHost();
            double elapsedMs = (System.nanoTime() - startedAt) / 1000000.0d;
            System.out.println(elapsedMs + "ms.");
            round++;
        }
    }

    /**
     * Demo: twenty rounds of comparing a CPU-side value with the graph gradient.
     * Each round refills both 3x18x5x5 tensors with random data, prints {@code tensor - tensor2}
     * via {@link #sq_back_cpu}, builds {@code (tensor2 - tensor)^2 / 2}, runs a fresh graph's
     * backward pass and prints the gradient of {@code tensor}.
     */
    public static void sq() {
        int i = 3 * 18 * 5 * 5;
        float[] gaussianRandom = RandomUtils.gaussianRandom(i, 0.1f);
        float[] gaussianRandom2 = RandomUtils.gaussianRandom(i, 0.1f);
        Graph graph = new Graph();
        Tensor tensor = new Tensor(3, 18, 5, 5, gaussianRandom, true);
        Tensor tensor2 = new Tensor(3, 18, 5, 5, gaussianRandom2, true);
        for (int i3 = 0; i3 < 20; i3++) {
            tensor.data = RandomUtils.gaussianRandom(i, 0.1f);
            tensor2.data = RandomUtils.gaussianRandom(i, 0.1f);
            sq_back_cpu(tensor, tensor2);
            tensor.setRequiresGrad(true);
            tensor.hostToDevice();
            tensor2.hostToDevice();
            // The result is not kept; only the gradient of tensor is read afterwards.
            tensor2.sub(tensor).pow(2.0f).div(2.0f);
            graph.clearGrad();
            graph.backward();
            tensor.getGrad().syncHost();
            System.out.println("dx_gpu:" + JsonUtils.toJson(tensor.getGrad().data));
        }
    }

    /**
     * Computes {@code tensor - tensor2} element by element on the host and prints it as JSON
     * with the prefix {@code dx_cpu:}.
     *
     * @param tensor  minuend; also supplies the shape of the scratch tensor
     * @param tensor2 subtrahend
     */
    public static void sq_back_cpu(Tensor tensor, Tensor tensor2) {
        Tensor difference = new Tensor(tensor.number, tensor.channel, tensor.height, tensor.width, true);
        int index = 0;
        while (index < tensor.getDataLength()) {
            difference.data[index] = tensor.data[index] - tensor2.data[index];
            index++;
        }
        System.out.println("dx_cpu:" + JsonUtils.toJson(difference.data));
    }

    /**
     * Demo: calls {@code sum(1)} on a 3x18x5x5 tensor filled with 0.6, runs backward and prints
     * the result and the input gradient.
     */
    public static void sum() {
        Graph graph = new Graph();
        float[] values = MatrixUtils.val(3 * 18 * 5 * 5, 0.6f);
        Tensor input = new Tensor(3, 18, 5, 5, values, true, graph);
        input.setRequiresGrad(true);
        input.hostToDevice();
        Tensor summed = input.sum(1);
        graph.backward();
        summed.syncHost();
        System.out.println("z:" + JsonUtils.toJson(summed.data));
        Tensor inputGrad = input.getGrad();
        inputGrad.syncHost();
        System.out.println("dx:" + JsonUtils.toJson(input.getGrad()));
        PrintUtils.printImage(input.getGrad());
    }

    /**
     * Demo: calls {@code maximum} on two 1x1x1x5 tensors, shows the result, then runs backward
     * and shows both input gradients.
     */
    public static void maximum() {
        Graph graph = new Graph();
        float[] leftValues = new float[]{0.1f, 1.0f, 0.06f, -1.0f, 1.3f};
        float[] rightValues = new float[]{-0.1f, 1.0f, 0.07f, -1.2f, 0.003f};
        Tensor left = new Tensor(1, 1, 1, 5, leftValues, true, graph);
        Tensor right = new Tensor(1, 1, 1, 5, rightValues, true, graph);
        left.setRequiresGrad(true);
        right.setRequiresGrad(true);
        Tensor larger = left.maximum(right);
        larger.showDM();
        graph.clearGrad();
        graph.backward();
        left.getGrad().showDM();
        right.getGrad().showDM();
    }

    /**
     * Demo: calls {@code minimum} on two 1x1x1x5 tensors, shows the result, then runs backward
     * and shows both input gradients.
     */
    public static void minimum() {
        Graph graph = new Graph();
        float[] leftValues = new float[]{0.1f, 1.0f, 0.06f, -1.0f, 1.3f};
        float[] rightValues = new float[]{-0.1f, 1.0f, 0.07f, -1.2f, 0.003f};
        Tensor left = new Tensor(1, 1, 1, 5, leftValues, true, graph);
        Tensor right = new Tensor(1, 1, 1, 5, rightValues, true, graph);
        left.setRequiresGrad(true);
        right.setRequiresGrad(true);
        Tensor smaller = left.minimum(right);
        smaller.showDM();
        graph.clearGrad();
        graph.backward();
        left.getGrad().showDM();
        right.getGrad().showDM();
    }

    /**
     * Demo: builds a box-overlap expression from two 1x4x1x1 tensors, shows it, runs backward
     * and shows the gradient of the first tensor.
     *
     * <p>NOTE(review): the local names assume each tensor holds (centre x, centre y, width,
     * height) and that the no-argument {@code pow()} squares its input; the method name
     * suggests a CIoU-style loss — confirm against the Tensor API.
     */
    public static void Lciou() {
        Graph graph = new Graph();
        Tensor pred = new Tensor(1, 4, 1, 1, new float[]{0.5f, 0.02f, 0.3f, 0.6f}, true, graph);
        Tensor target = new Tensor(1, 4, 1, 1, new float[]{0.3f, 0.2f, 0.03f, 0.12f}, true, graph);
        pred.setRequiresGrad(true);

        // Four single-channel slices of the first tensor, then its halved third and fourth slices.
        Tensor predCx = pred.get(1, 0, 1);
        Tensor predCy = pred.get(1, 1, 1);
        Tensor predW = pred.get(1, 2, 1);
        Tensor predH = pred.get(1, 3, 1);
        Tensor predHalfW = predW.div(2.0f);
        Tensor predHalfH = predH.div(2.0f);

        // The same slices for the second tensor.
        Tensor targetCx = target.get(1, 0, 1);
        Tensor targetCy = target.get(1, 1, 1);
        Tensor targetW = target.get(1, 2, 1);
        Tensor targetH = target.get(1, 3, 1);
        Tensor targetHalfW = targetW.div(2.0f);
        Tensor targetHalfH = targetH.div(2.0f);

        // Centre minus/plus half extent for both tensors.
        Tensor predLeft = predCx.sub(predHalfW);
        Tensor predRight = predCx.add(predHalfW);
        Tensor predTop = predCy.sub(predHalfH);
        Tensor predBottom = predCy.add(predHalfH);
        Tensor targetLeft = targetCx.sub(targetHalfW);
        Tensor targetRight = targetCx.add(targetHalfW);
        Tensor targetTop = targetCy.sub(targetHalfH);
        Tensor targetBottom = targetCy.add(targetHalfH);

        Tensor intersection = predRight.minimum(targetRight).sub(predLeft.maximum(targetLeft)).mul(predBottom.minimum(targetBottom).sub(predTop.maximum(targetTop)));
        Tensor iou = intersection.div(predW.mul(predH).add(targetW.mul(targetH)).sub(intersection));
        Tensor enclosingTerm = predRight.maximum(targetRight).sub(predLeft.minimum(targetLeft)).pow().add(predBottom.maximum(targetBottom).sub(predTop.minimum(targetTop)).pow());
        Tensor centerTerm = targetLeft.add(targetRight).sub(predLeft).sub(predRight).pow().add(targetTop.add(targetBottom).sub(predTop).sub(predBottom).pow()).div(4.0f);
        // 0.40528473f is approximately 4 / pi^2.
        Tensor aspectTerm = targetW.div(targetH).atan().sub(predW.div(predH).atan()).pow().mul(0.40528473f);
        Tensor result = iou.sub(centerTerm.div(enclosingTerm).add(aspectTerm.mul(aspectTerm.div(aspectTerm.sub(iou).add(1.0f + 1.0E-7f)))));

        System.out.println("===================");
        result.showDM();
        graph.clearGrad();
        graph.backward();
        System.out.println("==========grad=========");
        pred.getGrad().showDM();
    }

    /**
     * Demo: calls {@code atan} on a 1x4x1x1 tensor, runs backward, then shows the result and
     * the input gradient.
     */
    public static void atan() {
        Graph graph = new Graph();
        float[] values = new float[]{0.5f, 0.02f, 0.3f, 0.6f};
        Tensor input = new Tensor(1, 4, 1, 1, values, true, graph);
        input.setRequiresGrad(true);
        Tensor result = input.atan();
        graph.clearGrad();
        graph.backward();
        result.showDM();
        input.getGrad().showDM();
    }

    /**
     * Demo: builds {@code x * sigmoid(x)} for a 1x1x1x4 tensor, runs backward, shows the result
     * and the input gradient, and finally shows {@code y + s * y.scalarSub(1)} built from the
     * product {@code y} and the sigmoid {@code s} after the backward pass.
     */
    public static void silu() {
        Graph graph = new Graph();
        float[] values = new float[]{0.5f, 0.02f, 0.3f, 0.6f};
        Tensor input = new Tensor(1, 1, 1, 4, values, true, graph);
        input.setRequiresGrad(true);
        Tensor gate = sigmoid(input);
        gate.showDM();
        Tensor output = input.mul(gate);
        graph.clearGrad();
        graph.backward();
        output.showDM();
        input.getGrad().showDM();
        // Built after backward(), i.e. while the graph is already locked.
        Tensor check = output.add(gate.mul(output.scalarSub(1.0f)));
        check.showDM();
    }

    /**
     * Demo: three recurrent steps of {@code h = tanh(x_t.dot(w) + h.dot(u))}, with the recurrent
     * term skipped on the first step. Each step's state is written into an output tensor. The
     * method then runs backward and shows the input, the output and four gradients.
     *
     * <p>NOTE(review): the slice is read with {@code get(0, step, 2)} but written with
     * {@code set(h, 0, step * 2)}, and the fourth weight tensor is never used in the forward
     * expression — confirm both are intended.
     */
    public static void RNN() {
        Graph graph = new Graph();
        int rows = 3 * 2;
        Tensor x = new Tensor(rows, 1, 1, 3, MatrixUtils.order(3 * 2 * 3, 0.0f, 0.1f), true, graph);
        x.setRequiresGrad(true);
        Tensor w = new Tensor(1, 1, 3, 5, MatrixUtils.val(3 * 5, 0.1f), true, graph);
        w.setRequiresGrad(true);
        Tensor u = new Tensor(1, 1, 5, 5, MatrixUtils.val(5 * 5, 0.2f), true, graph);
        u.setRequiresGrad(true);
        Tensor v = new Tensor(1, 1, 5, 5, MatrixUtils.val(5 * 5, 0.01f), true, graph);
        v.setRequiresGrad(true);
        Tensor out = new Tensor(rows, 1, 1, 5, true, graph);
        Tensor hidden = null;
        for (int step = 0; step < 3; step++) {
            Tensor preActivation = x.get(0, step, 2).dot(w);
            if (step != 0) {
                preActivation = preActivation.add(hidden.dot(u));
            }
            hidden = tanh(preActivation);
            out.set(hidden, 0, step * 2);
        }
        graph.clearGrad();
        graph.backward();
        System.out.println("x:");
        x.showDM();
        System.out.println("out:");
        out.showDM();
        System.out.println("x-grad:");
        x.getGrad().showDM();
        System.out.println("w-grad:");
        w.getGrad().showDM();
        System.out.println("u-grad:");
        u.getGrad().showDM();
        System.out.println("v-grad:");
        v.getGrad().showDM();
    }

    /**
     * Demo: projects a 2x1x1x5 input with three 5x5 weight tensors, computes
     * {@code softmax(q.dot(k.transpose())).dot(v)}, shows the intermediate and final tensors,
     * then runs backward and shows the gradients of the input and the three weights.
     */
    public static void selfAttention() {
        Graph graph = new Graph();
        Tensor x = new Tensor(2, 1, 1, 5, MatrixUtils.order(2 * 5, 0.0f, 0.1f), true, graph);
        x.setRequiresGrad(true);
        Tensor qw = new Tensor(1, 1, 5, 5, MatrixUtils.val(5 * 5, 0.1f), true, graph);
        qw.setRequiresGrad(true);
        Tensor kw = new Tensor(1, 1, 5, 5, MatrixUtils.val(5 * 5, 0.2f), true, graph);
        kw.setRequiresGrad(true);
        Tensor vw = new Tensor(1, 1, 5, 5, MatrixUtils.val(5 * 5, 0.01f), true, graph);
        vw.setRequiresGrad(true);
        Tensor q = linear(x, qw);
        Tensor k = linear(x, kw);
        Tensor v = linear(x, vw);
        Tensor scores = q.dot(k.transpose());
        Tensor sf = softmax(scores);
        System.out.println("sf:");
        sf.showDM();
        System.out.println("v:");
        v.showDM();
        Tensor output = sf.dot(v);
        System.out.println("output:");
        output.showDM();
        graph.clearGrad();
        graph.backward();
        System.out.println("x-grad:");
        x.getGrad().showDM();
        System.out.println("qw-grad:");
        qw.getGrad().showDM();
        System.out.println("kw-grad:");
        kw.getGrad().showDM();
        System.out.println("vw-grad:");
        vw.getGrad().showDM();
    }

    /**
     * Returns {@code tensor.dot(tensor2)}; no bias term is added.
     *
     * @param tensor  left operand
     * @param tensor2 right operand
     * @return the result of the dot operation
     */
    public static Tensor linear(Tensor tensor, Tensor tensor2) {
        Tensor product = tensor.dot(tensor2);
        return product;
    }

    /**
     * Builds {@code e.div(e.sum(1))} with {@code e = exp(tensor - tensor.max(1))}.
     *
     * @param tensor the input
     * @return the tensor produced by the final division
     */
    public static Tensor softmax(Tensor tensor) {
        Tensor maxAlongAxis1 = tensor.max(1);
        Tensor shifted = tensor.sub(maxAlongAxis1);
        Tensor exponentials = shifted.exp();
        Tensor sumAlongAxis1 = exponentials.sum(1);
        return exponentials.div(sumAlongAxis1);
    }

    /**
     * Demo: applies {@link #softmax} to a 2x1x1x10 tensor, runs backward with an explicitly
     * supplied tensor of the same shape, then shows the softmax output and its gradient.
     */
    public static void softmax_test() {
        Graph graph = new Graph();
        Tensor input = new Tensor(2, 1, 1, 10, MatrixUtils.order(2 * 10, 0.0f, 0.1f), true, graph);
        input.setRequiresGrad(true);
        Tensor output = softmax(input);
        graph.clearGrad();
        Tensor upstream = new Tensor(2, 1, 1, 10, MatrixUtils.order(2 * 10, 0.0f, 0.1f), true);
        graph.backward(upstream);
        output.showDM();
        output.getGrad().showDM();
    }

    /**
     * Entry point: initialises the CUDA context, runs {@link #softmax_test()} and always
     * releases memory through {@code CUDAMemoryManager.free()}. Any exception is printed
     * rather than propagated.
     *
     * @param strArr command-line arguments, unused
     */
    public static void main(String[] strArr) {
        try {
            CUDAModules.initContext();
            softmax_test();
        } catch (Exception failure) {
            failure.printStackTrace();
        } finally {
            CUDAMemoryManager.free();
        }
    }
}
