package org.tribuo.classification.sgd.objectives;

import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance;
import com.oracle.labs.mlrg.olcut.provenance.impl.ConfiguredObjectProvenanceImpl;
import com.oracle.labs.mlrg.olcut.util.Pair;
import org.tribuo.classification.sgd.LabelObjective;
import org.tribuo.math.la.SGDVector;
import org.tribuo.math.la.SparseVector;
import org.tribuo.math.util.NoopNormalizer;
import org.tribuo.math.util.VectorNormalizer;

/* loaded from: input_file:org/tribuo/classification/sgd/objectives/Hinge.class */
/**
 * A multiclass hinge objective.
 * <p>
 * The true label's score is reduced by the margin, and the result is compared against the
 * highest scoring label. The gradient places {@code +margin} on the true label and
 * {@code -margin} on the offending label.
 * <p>
 * This objective is not probabilistic. It uses a {@link NoopNormalizer}, and
 * {@link #isProbabilistic()} returns {@code false}.
 */
public class Hinge implements LabelObjective {

    /** Margin used by the no-argument constructor and as the field default. */
    private static final double DEFAULT_MARGIN = 1.0;

    // Not final: the @Config annotation suggests this field is injected by the configuration system.
    @Config(description = "The classification margin.")
    private double margin = DEFAULT_MARGIN;

    /**
     * Constructs a hinge objective with the supplied margin.
     *
     * @param margin The classification margin.
     */
    public Hinge(double margin) {
        this.margin = margin;
    }

    /**
     * Constructs a hinge objective with a margin of 1.0.
     */
    public Hinge() {
        this(DEFAULT_MARGIN);
    }

    /**
     * Delegates to {@link #lossAndGradient(Integer, SGDVector)}.
     *
     * @param truth      The true label index.
     * @param prediction The predicted scores (modified in place).
     * @return The loss and gradient.
     * @deprecated Use {@link #lossAndGradient(Integer, SGDVector)} instead.
     */
    @Deprecated
    @Override
    public Pair<Double, SGDVector> valueAndGradient(int truth, SGDVector prediction) {
        return lossAndGradient(truth, prediction);
    }

    /**
     * Computes the hinge loss and its gradient for a single example.
     * <p>
     * Note: {@code margin} is subtracted in place from the true label's entry of
     * {@code prediction} before the highest scoring label is found.
     * <p>
     * If the true label is still the highest scoring one, the loss is 0. In that case the
     * gradient is an empty sparse vector.
     * <p>
     * Otherwise the loss is {@code prediction[truth] - prediction[argmax]}, computed on the
     * shifted vector, so it is non-positive. The gradient is then a two-entry sparse vector
     * with {@code +margin} at the true label and {@code -margin} at the highest scoring label.
     *
     * @param truth      The true label index.
     * @param prediction The predicted scores (modified in place).
     * @return A pair of (loss, gradient).
     */
    @Override
    public Pair<Double, SGDVector> lossAndGradient(Integer truth, SGDVector prediction) {
        final int trueIndex = truth;
        prediction.add(trueIndex, -margin);
        final int predIndex = prediction.indexOfMax();
        final int dimension = prediction.size();

        if (trueIndex == predIndex) {
            return new Pair<>(0.0, SparseVector.createSparseVector(dimension, new int[0], new double[0]));
        }

        // The two indices are emitted in ascending order.
        // NOTE(review): presumably createSparseVector expects sorted indices; confirm.
        final int[] indices;
        final double[] values;
        if (trueIndex < predIndex) {
            indices = new int[] {trueIndex, predIndex};
            values = new double[] {margin, -margin};
        } else {
            indices = new int[] {predIndex, trueIndex};
            values = new double[] {-margin, margin};
        }

        final double loss = prediction.get(trueIndex) - prediction.get(predIndex);
        return new Pair<>(loss, SparseVector.createSparseVector(dimension, indices, values));
    }

    /**
     * Returns a no-op normalizer, as hinge scores are not converted into probabilities.
     *
     * @return A {@link NoopNormalizer}.
     */
    @Override
    public VectorNormalizer getNormalizer() {
        return new NoopNormalizer();
    }

    /**
     * Hinge loss does not produce probabilities.
     *
     * @return {@code false}.
     */
    @Override
    public boolean isProbabilistic() {
        return false;
    }

    @Override
    public String toString() {
        return "Hinge(margin=" + this.margin + ")";
    }

    /**
     * Returns the provenance of this configured objective.
     *
     * @return A provenance tagged with the "LabelObjective" host type.
     */
    @Override
    public ConfiguredObjectProvenance getProvenance() {
        return new ConfiguredObjectProvenanceImpl(this, "LabelObjective");
    }
}
