package com.omega.engine.ad.op.sign;

import com.omega.common.data.Tensor;
import com.omega.engine.ad.Tape;
import com.omega.engine.ad.op.OPType;
import com.omega.engine.ad.op.SignOP;
import com.omega.engine.ad.op.TensorOP;

/* loaded from: input_file:com/omega/engine/ad/op/sign/SubOP.class */
/**
 * Subtraction operator for the autodiff tape.
 *
 * <p>Computes {@code output = x - y} when the tape carries a second tensor, or
 * {@code output = x - scalar} (using {@code tape.getScalar()}) when it does not.
 *
 * <p>A shared instance is obtained through {@link #getInstance()}.
 * NOTE(review): that lazy initialisation is not synchronized; confirm callers only
 * use it from a single thread.
 */
public class SubOP extends SignOP {
    private static final long serialVersionUID = -3681016263960474439L;

    /** Lazily created shared instance; populated by {@link #getInstance()}. */
    public static SubOP op = null;

    /** Operation type tag assigned to the shared instance on creation. */
    public static final OPType opt = OPType.subtraction;

    /**
     * Returns the shared {@code SubOP}, creating it on first use.
     *
     * <p>On first use the instance is tagged with {@link #opt}.
     *
     * @return the shared instance
     */
    public static SubOP getInstance() {
        if (op == null) {
            op = new SubOP();
            op.setOpType(opt);
        }
        return op;
    }

    /**
     * Writes the difference into the tape's output tensor.
     *
     * <p>If {@code tape.getY()} is non-null, computes {@code x - y}. Otherwise computes
     * {@code x - tape.getScalar()}. The output is marked as requiring gradients when any
     * tensor operand requires them.
     *
     * @param tape tape holding the operands ({@code getX()}, optional {@code getY()} or
     *     {@code getScalar()}) and the preallocated output tensor
     * @return the tape's output tensor, filled with the result
     */
    @Override // com.omega.engine.ad.op.OP
    public Tensor forward(Tape tape) {
        Tensor x = tape.getX();
        Tensor y = tape.getY();
        Tensor output = tape.getOutput();
        if (y != null) {
            TensorOP.sub(x, y, output);
        } else {
            TensorOP.sub(x, tape.getScalar(), output);
        }
        // y is null on the scalar path. Guard it before querying its grad flag; otherwise
        // a scalar subtraction on an x that does not require grad throws NullPointerException.
        if (x.isRequiresGrad() || (y != null && y.isRequiresGrad())) {
            output.setRequiresGrad(true);
        }
        return output;
    }

    /**
     * Propagates the incoming gradient to the operands.
     *
     * <p>The derivative of {@code x - y} is {@code +1} with respect to {@code x} and
     * {@code -1} with respect to {@code y}. The incoming gradient is therefore applied to
     * {@code x}'s gradient with factor {@code 1.0f} and to {@code y}'s gradient with factor
     * {@code -1.0f}. Operands that do not require gradients, or a missing {@code y} on the
     * scalar path, are skipped.
     *
     * @param tensor gradient flowing into this op's output
     * @param tape tape holding the operands recorded during {@link #forward(Tape)}
     */
    @Override // com.omega.engine.ad.op.OP
    public void backward(Tensor tensor, Tape tape) {
        Tensor x = tape.getX();
        Tensor y = tape.getY();
        if (x.isRequiresGrad()) {
            // presumably accumulates tensor * 1.0 into x's grad (TensorOP.mulPlus not visible here)
            TensorOP.mulPlus(tensor, 1.0f, x.getGrad());
        }
        if (y == null || !y.isRequiresGrad()) {
            return;
        }
        // presumably accumulates tensor * -1.0 into y's grad
        TensorOP.mulPlus(tensor, -1.0f, y.getGrad());
    }
}
