package com.omega.engine.ad.op.sign;

import com.omega.common.data.Tensor;
import com.omega.engine.ad.Tape;
import com.omega.engine.ad.op.OPType;
import com.omega.engine.ad.op.SignOP;
import com.omega.engine.ad.op.TensorOP;
import com.omega.engine.ad.op.gpu.OPKernel;

/* loaded from: input_file:com/omega/engine/ad/op/sign/DivOP.class */
/**
 * Autograd operator for element-wise division.
 *
 * <p>The dividend is {@code tape.getX()}. The divisor is either {@code tape.getY()} or, when the
 * tape carries no second tensor, {@code tape.getScalar()}.
 *
 * <p>A shared instance is obtained through {@link #getInstance()}. Its lazy initialisation performs
 * no synchronization.
 */
public class DivOP extends SignOP {
    private static final long serialVersionUID = 6114922229588936622L;
    public static DivOP op = null;
    public static final OPType opt = OPType.division;

    /**
     * Returns the shared {@code DivOP}, creating it on first use and tagging it with {@link #opt}.
     *
     * @return the singleton instance stored in {@link #op}
     */
    public static DivOP getInstance() {
        if (op != null) {
            return op;
        }
        op = new DivOP();
        op.setOpType(opt);
        return op;
    }

    /**
     * Writes {@code x / y} (or {@code x / scalar} when the tape has no Y tensor) into the tape's
     * output tensor.
     *
     * <p>The output is flagged as requiring a gradient when either tensor operand does. It is never
     * un-flagged here.
     *
     * @param tape holds the operands and the pre-allocated output tensor
     * @return the tape's output tensor
     */
    @Override // com.omega.engine.ad.op.OP
    public Tensor forward(Tape tape) {
        final Tensor dividend = tape.getX();
        final Tensor divisor = tape.getY();
        final Tensor result = tape.getOutput();
        final boolean scalarDivisor = divisor == null;

        if (scalarDivisor) {
            TensorOP.div(dividend, tape.getScalar(), result);
        } else {
            TensorOP.div(dividend, divisor, result);
        }

        final boolean needsGrad =
                dividend.isRequiresGrad() || (!scalarDivisor && divisor.isRequiresGrad());
        if (needsGrad) {
            result.setRequiresGrad(true);
        }
        return result;
    }

    /**
     * Propagates the upstream gradient {@code tensor} back to the operands of the division.
     *
     * @param tensor upstream gradient of the division's output
     * @param tape holds the operands recorded during {@link #forward(Tape)}
     */
    @Override // com.omega.engine.ad.op.OP
    public void backward(Tensor tensor, Tape tape) {
        final Tensor dividend = tape.getX();
        final Tensor divisor = tape.getY();
        final boolean scalarDivisor = divisor == null;

        // Dividend gradient. NOTE(review): divPlus presumably accumulates tensor / divisor into
        // the dividend's grad (d(x/y)/dx = 1/y) — confirm against TensorOP.
        if (dividend.isRequiresGrad()) {
            if (scalarDivisor) {
                TensorOP.divPlus(tensor, tape.getScalar(), dividend.getGrad());
            } else {
                TensorOP.divPlus(tensor, divisor, dividend.getGrad());
            }
        }

        // Divisor gradient (-x / y^2 scaled by the upstream gradient). A scalar divisor has no
        // gradient slot to update, and a divisor that does not require grad is skipped.
        if (!scalarDivisor && divisor.isRequiresGrad()) {
            if (divisor.getGrad().isHasGPU()) {
                OPKernel.getInstance().div_bGrad_gpu(tensor, dividend, divisor, divisor.getGrad());
            } else {
                bGrad(tensor.data, dividend.data, divisor.data, divisor.getGrad().data);
            }
        }
    }

    /**
     * CPU accumulation of the divisor gradient for {@code x / y}.
     *
     * <p>For every index {@code k} in {@code [0, delta.length)} this performs
     * {@code yGrad[k] += -delta[k] * x[k] / (y[k] * y[k])}. The other arrays must be at least as
     * long as {@code delta}.
     *
     * @param delta upstream gradient values
     * @param x dividend values
     * @param y divisor values
     * @param yGrad divisor gradient buffer, updated in place
     */
    public static void bGrad(float[] delta, float[] x, float[] y, float[] yGrad) {
        final int count = delta.length;
        for (int k = 0; k < count; k++) {
            yGrad[k] += -delta[k] * x[k] / (y[k] * y[k]);
        }
    }
}
