package com.omega.engine.loss;

import com.omega.common.data.Tensor;
import com.omega.engine.ad.Graph;

/* loaded from: input_file:com/omega/engine/loss/MultiLabelSoftMargin.class */
/**
 * Multi-label soft-margin loss evaluated on the autograd {@link Graph}.
 *
 * <p>For logits {@code x} and targets {@code y} the forward pass builds
 * {@code -(y * log(sigmoid(x)) + (1 - y) * log(sigmoid(-x)))}, sums it with {@code sum(1)},
 * divides by {@code channel * height * width}, then sums with {@code sum(0)} and divides by
 * {@code number}. NOTE(review): this presumably averages over each sample's elements and then
 * over the batch (like PyTorch's {@code MultiLabelSoftMarginLoss}, {@code reduction='mean'}) —
 * confirm the axis semantics of {@code Tensor.sum(int)}.
 *
 * <p>Gradients are obtained by replaying the graph recorded during {@code loss}: the
 * {@code diff} overloads call {@code clearGrad()} and {@code backward()} on the logits tensor's
 * graph and return its gradient, so they rely on {@code loss} having been run on that tensor.
 */
public class MultiLabelSoftMargin extends LossFunction {
    /** Lazily created shared instance; creation in {@link #operation()} is not synchronized. */
    private static MultiLabelSoftMargin instance;

    public final LossType lossType = LossType.multiLabel_soft_margin;

    /** Graph attached to input tensors that do not already belong to a graph. */
    private final Graph g = new Graph();

    /**
     * Returns the shared instance, creating it on first use.
     *
     * @return the singleton loss object
     */
    public static MultiLabelSoftMargin operation() {
        if (instance == null) {
            instance = new MultiLabelSoftMargin();
        }
        return instance;
    }

    /**
     * Attaches this loss's graph to each tensor that does not yet have one; tensors that already
     * belong to a graph are left unchanged.
     *
     * @param tensor first tensor (the logits)
     * @param tensor2 second tensor (the targets)
     */
    public void initGraph(Tensor tensor, Tensor tensor2) {
        if (tensor.getG() == null) {
            tensor.setG(this.g);
        }
        if (tensor2.getG() == null) {
            tensor2.setG(this.g);
        }
    }

    /**
     * Builds {@code 1 / (1 + exp(-x))} as graph operations.
     *
     * <p>NOTE(review): assumes {@code scalarDiv(1.0f)} computes {@code 1 / tensor} (scalar
     * divided by tensor) — confirm against {@code Tensor.scalarDiv}.
     *
     * @param tensor input tensor {@code x}
     * @return a new tensor holding the element-wise sigmoid of {@code x}
     */
    public static Tensor sigmoid(Tensor tensor) {
        return tensor.mul(-1.0f).exp().add(1.0f).scalarDiv(1.0f);
    }

    /**
     * Computes the loss and records it on the logits tensor's graph.
     *
     * @param tensor logits; marked as requiring gradients
     * @param tensor2 targets
     * @return the scalar loss tensor
     */
    @Override // com.omega.engine.loss.LossFunction
    public Tensor loss(Tensor tensor, Tensor tensor2) {
        return buildLossGraph(tensor, tensor2);
    }

    /**
     * Back-propagates through the graph recorded by {@code loss} and returns the logits gradient.
     *
     * @param tensor logits passed to the preceding {@code loss} call
     * @param tensor2 targets (unused)
     * @return the gradient of the loss with respect to {@code tensor}
     */
    @Override // com.omega.engine.loss.LossFunction
    public Tensor diff(Tensor tensor, Tensor tensor2) {
        return backpropagate(tensor);
    }

    /** Not supported by this loss; always returns {@code null}. */
    @Override // com.omega.engine.loss.LossFunction
    public Tensor[] loss(Tensor[] tensorArr, Tensor tensor) {
        return null;
    }

    /** Not supported by this loss; always returns {@code null}. */
    @Override // com.omega.engine.loss.LossFunction
    public Tensor[] diff(Tensor[] tensorArr, Tensor tensor) {
        return null;
    }

    /** @return {@link LossType#multiLabel_soft_margin} */
    @Override // com.omega.engine.loss.LossFunction
    public LossType getLossType() {
        return LossType.multiLabel_soft_margin;
    }

    /**
     * Same as {@link #loss(Tensor, Tensor)}; the third tensor is ignored.
     *
     * @param tensor logits; marked as requiring gradients
     * @param tensor2 targets
     * @param tensor3 ignored
     * @return the scalar loss tensor
     */
    @Override // com.omega.engine.loss.LossFunction
    public Tensor loss(Tensor tensor, Tensor tensor2, Tensor tensor3) {
        return buildLossGraph(tensor, tensor2);
    }

    /**
     * Same as {@link #diff(Tensor, Tensor)}; the second and third tensors are ignored.
     *
     * @param tensor logits passed to the preceding {@code loss} call
     * @param tensor2 ignored
     * @param tensor3 ignored
     * @return the gradient of the loss with respect to {@code tensor}
     */
    @Override // com.omega.engine.loss.LossFunction
    public Tensor diff(Tensor tensor, Tensor tensor2, Tensor tensor3) {
        return backpropagate(tensor);
    }

    /** Not supported by this loss; always returns {@code null}. */
    @Override // com.omega.engine.loss.LossFunction
    public Tensor loss(Tensor tensor, Tensor tensor2, int i) {
        return null;
    }

    /** Not supported by this loss; always returns {@code null}. */
    @Override // com.omega.engine.loss.LossFunction
    public Tensor diff(Tensor tensor, Tensor tensor2, int i) {
        return null;
    }

    /**
     * Shared forward pass for both {@code loss} overloads: starts the logits' graph and records
     * the soft-margin expression on it. The operations are issued in the same order as the
     * original single-expression form, so the recorded graph is unchanged.
     */
    private Tensor buildLossGraph(Tensor logits, Tensor target) {
        initGraph(logits, target);
        logits.getG().start();
        logits.setRequiresGrad(true);
        // NOTE(review): log(sigmoid(x)) is evaluated as log(1 / (1 + exp(-x))), which can reach
        // log(0) = -inf (and NaN gradients) for large |x|; a log-sigmoid built from
        // max/abs ops would be stable if Tensor provides them.
        // y * log(sigmoid(x))
        Tensor positive = target.mul(sigmoid(logits).log());
        // log(sigmoid(-x)) * (1 - y); assumes scalarSub(1.0f) computes 1 - y — confirm.
        Tensor negative = sigmoid(logits.mul(-1.0f)).log().mul(target.scalarSub(1.0f));
        int elementsPerSample = logits.channel * logits.height * logits.width;
        return positive.add(negative)
                .mul(-1.0f)
                .sum(1)
                .div(elementsPerSample)
                .sum(0)
                .div(logits.number);
    }

    /** Shared backward pass for both {@code diff} overloads. */
    private static Tensor backpropagate(Tensor logits) {
        // Reset accumulated gradients so repeated calls do not add up.
        logits.getG().clearGrad();
        logits.getG().backward();
        return logits.getGrad();
    }
}
