package com.omega.engine.ad.op.sign;

import com.omega.common.data.Tensor;
import com.omega.engine.ad.Tape;
import com.omega.engine.ad.op.OPType;
import com.omega.engine.ad.op.SignOP;
import com.omega.engine.ad.op.TensorOP;
import com.omega.engine.ad.op.gpu.OPKernel;

/* loaded from: input_file:com/omega/engine/ad/op/sign/ScalarDivOP.class */
/**
 * Autograd operator for dividing a scalar by a tensor: {@code y = scalar / x}.
 *
 * <p>The scalar, the operand {@code x} and the output tensor are all taken from the {@link Tape}.
 * The local derivative is {@code dy/dx = -scalar / x^2}, which {@link #backward} accumulates into
 * the gradient of {@code x}.
 */
public class ScalarDivOP extends SignOP {
    private static final long serialVersionUID = 3087002822041265440L;

    /** Lazily created shared instance returned by {@link #getInstance()}. */
    public static ScalarDivOP op = null;

    /** Operator type tag assigned to the shared instance when it is created. */
    public static final OPType opt = OPType.scalarDivision;

    /**
     * Returns the shared instance, creating it and tagging it with {@link #opt} on first use.
     *
     * <p>NOTE(review): the lazy initialisation is not synchronized, so concurrent first calls may
     * each construct their own instance. Keep this in mind if operators are built from several
     * threads.
     *
     * @return the shared {@code ScalarDivOP}
     */
    public static ScalarDivOP getInstance() {
        if (op == null) {
            op = new ScalarDivOP();
            op.setOpType(opt);
        }
        return op;
    }

    /**
     * Computes {@code output = scalar / x} from the tape's scalar, operand and output tensor.
     *
     * <p>If {@code x} requires a gradient, the output is marked as requiring one too. Otherwise
     * the output's existing flag is left untouched (it is not reset to {@code false}).
     *
     * @param tape record supplying {@code x}, the scalar and the tensor the result is written to
     * @return the tape's output tensor, now holding the quotient
     */
    @Override
    public Tensor forward(Tape tape) {
        Tensor x = tape.getX();
        Tensor output = tape.getOutput();
        TensorOP.div(tape.getScalar(), x, output);
        if (x.isRequiresGrad()) {
            output.setRequiresGrad(true);
        }
        return output;
    }

    /**
     * Propagates {@code delta}, the incoming gradient of the output, back to {@code x}.
     *
     * <p>Does nothing unless {@code x} requires a gradient. If {@code x}'s gradient tensor reports
     * {@code isHasGPU()}, the work is delegated to {@link OPKernel#div_scalar_bGrad_gpu}.
     * Otherwise {@link #bGrad} accumulates {@code -delta * scalar / x^2} into the host arrays.
     *
     * @param delta gradient flowing in from the output of this operator
     * @param tape record supplying {@code x} and the scalar used in the forward pass
     */
    @Override
    public void backward(Tensor delta, Tape tape) {
        Tensor x = tape.getX();
        if (!x.isRequiresGrad()) {
            return;
        }
        Tensor xGrad = x.getGrad();
        if (xGrad.isHasGPU()) {
            OPKernel.getInstance().div_scalar_bGrad_gpu(delta, tape.getScalar(), x, xGrad);
        } else {
            bGrad(delta.data, tape.getScalar(), x.data, xGrad.data);
        }
    }

    /**
     * Host-side gradient of {@code y = scalar / x}.
     *
     * <p>For every index {@code i} it computes
     * {@code xGrad[i] += -delta[i] * scalar / (x[i] * x[i])}. The result is added to
     * {@code xGrad}, not written over it. The loop runs over {@code delta.length}, so {@code x}
     * and {@code xGrad} must be at least that long. A zero in {@code x} gives an infinite or NaN
     * contribution.
     *
     * @param delta incoming gradient of the output
     * @param scalar numerator used in the forward pass
     * @param x denominator values used in the forward pass
     * @param xGrad gradient accumulator for {@code x}; updated in place
     */
    public static void bGrad(float[] delta, float scalar, float[] x, float[] xGrad) {
        for (int i = 0; i < delta.length; i++) {
            float xi = x[i];
            // Unary minus is an exact sign flip, identical to the original (-1.0f * delta[i]).
            xGrad[i] += -delta[i] * scalar / (xi * xi);
        }
    }
}
