package com.omega.engine.ad.op.functions;

import com.omega.common.data.Tensor;
import com.omega.common.utils.RandomUtils;
import com.omega.engine.ad.Tape;
import com.omega.engine.ad.op.FunctionOP;
import com.omega.engine.ad.op.OPType;
import com.omega.engine.ad.op.TensorOP;

/* loaded from: input_file:com/omega/engine/ad/op/functions/TransposeOP.class */
/**
 * Autodiff function op that transposes the tape's input tensor.
 *
 * <p>The forward pass delegates to {@code TensorOP.transpose}. The backward pass
 * transposes the incoming gradient and folds it into the input's gradient via
 * {@code TensorOP.mulPlus} with a factor of {@code 1.0f}.
 */
public class TransposeOP extends FunctionOP {
    private static final long serialVersionUID = -3857343378511617891L;

    /** Lazily created shared instance; see {@link #getInstance()}. */
    public static TransposeOP op = null;

    /** Operation type tag assigned to the shared instance. */
    public static final OPType opt = OPType.transpose;

    /**
     * Returns the shared {@code TransposeOP}, creating it and tagging it with
     * {@link #opt} on first use.
     *
     * <p>NOTE(review): initialisation is not synchronized, so concurrent first calls
     * could each construct an instance.
     *
     * @return the shared instance
     */
    public static TransposeOP getInstance() {
        if (op != null) {
            return op;
        }
        op = new TransposeOP();
        op.setOpType(opt);
        return op;
    }

    /**
     * Writes the transpose of the tape's input into the tape's output tensor.
     * If the input requires a gradient, the output is marked as requiring one too.
     *
     * @param tape tape supplying the input ({@code getX}) and the destination ({@code getOutput})
     * @return the tape's output tensor, now holding the transposed values
     */
    @Override // com.omega.engine.ad.op.OP
    public Tensor forward(Tape tape) {
        Tensor input = tape.getX();
        Tensor result = tape.getOutput();
        TensorOP.transpose(input, result);
        // Only switch the flag on; an output already requiring grad is left alone.
        if (input.isRequiresGrad()) {
            result.setRequiresGrad(true);
        }
        return result;
    }

    /**
     * Propagates the upstream gradient back to the input, if the input tracks gradients.
     * The gradient of a transpose is the transposed gradient. It is staged in the tape's
     * scratch tensor and then passed to {@code TensorOP.mulPlus} with factor {@code 1.0f}
     * against {@code x.getGrad()} (presumably an in-place accumulate — TODO confirm).
     *
     * @param tensor upstream gradient with respect to this op's output
     * @param tape   tape supplying the input ({@code getX}) and scratch space ({@code getTmp})
     */
    @Override // com.omega.engine.ad.op.OP
    public void backward(Tensor tensor, Tape tape) {
        Tensor input = tape.getX();
        if (!input.isRequiresGrad()) {
            return;
        }
        Tensor scratch = tape.getTmp();
        TensorOP.transpose(tensor, scratch);
        TensorOP.mulPlus(scratch, 1.0f, input.getGrad());
    }

    /** Manual entry point that runs the permute demo. */
    public static void main(String[] args) {
        testPermute();
    }

    /**
     * Manual demo of {@code TensorOP.permute}: builds a (2, 5, 2, 2) tensor filled via
     * {@code RandomUtils.order}, swaps axes 1 and 2 into a (2, 2, 5, 2) tensor, and
     * prints both with {@code showDM}. It exercises permute, not this op's transpose.
     */
    public static void testPermute() {
        final int elementCount = 2 * 5 * 2 * 2;
        final int[] axisOrder = {0, 2, 1, 3};
        Tensor source = new Tensor(2, 5, 2, 2, RandomUtils.order(elementCount, 1.0f, 0.0f), true);
        Tensor permuted = new Tensor(2, 2, 5, 2, true);
        source.showDM();
        TensorOP.permute(source, permuted, axisOrder);
        permuted.showDM();
    }
}
