package ai.libs.jaicore.ml.classification.multilabel.evaluation.loss;

import ai.libs.jaicore.basic.ArrayUtil;
import java.util.Iterator;
import java.util.List;
import java.util.OptionalDouble;
import java.util.stream.IntStream;
import org.api4.java.ai.ml.classification.multilabel.evaluation.IMultiLabelClassification;

/* loaded from: input_file:ai/libs/jaicore/ml/classification/multilabel/evaluation/loss/RankLoss.class */
/**
 * Ranking loss for multi-label classification.
 *
 * <p>For a single instance, the labels with the maximal expected value are treated as relevant and
 * the labels with the minimal expected value as irrelevant. Every (relevant, irrelevant) pair is
 * inspected:
 * <ul>
 * <li>it costs {@code 1} if the predicted score of the relevant label is strictly lower than the
 * score of the irrelevant label (the pair is ranked the wrong way round);</li>
 * <li>it costs {@code tieLoss} if both predicted scores are equal;</li>
 * <li>it costs nothing otherwise.</li>
 * </ul>
 * The per-instance loss is the total cost divided by the number of such pairs, so it lies in
 * [0, 1] whenever {@code tieLoss} lies in [0, 1]. Instances without any relevant/irrelevant pair
 * (all expected labels equal, or no labels at all) contribute a loss of 0. The overall loss is the
 * arithmetic mean of the per-instance losses.
 */
public class RankLoss extends AMultiLabelClassificationMeasure {
    /** Cost charged for a tied relevant/irrelevant pair when no explicit value is given (0.0). */
    private static final double DEFAULT_TIE_LOSS = 0.0d;

    /** Cost added for every relevant/irrelevant pair whose predicted scores are exactly equal. */
    private final double tieLoss;

    /** Creates a rank loss that does not penalize tied scores. */
    public RankLoss() {
        this(DEFAULT_TIE_LOSS);
    }

    /**
     * Creates a rank loss with a custom penalty for ties.
     *
     * @param tieLoss cost added for each relevant/irrelevant label pair whose predicted scores are equal
     */
    public RankLoss(final double tieLoss) {
        this.tieLoss = tieLoss;
    }

    /**
     * Computes the ranking loss of a single instance.
     *
     * @param expected the expected label vector; indices attaining its maximum are treated as
     *        relevant, indices attaining its minimum as irrelevant
     * @param classification the prediction whose {@code getPrediction()} scores are compared
     * @return the (tie-weighted) fraction of relevant/irrelevant pairs that are ranked incorrectly,
     *         or 0 if the instance has no such pair
     */
    private double rankingLoss(final int[] expected, final IMultiLabelClassification classification) {
        List<Integer> relevant = ArrayUtil.argMax(expected);
        List<Integer> irrelevant = ArrayUtil.argMin(expected);

        // If max == min (all labels equal), argMax and argMin overlap instead of being disjoint.
        // Then there is no relevant/irrelevant pair to rank, so the instance costs nothing.
        // Pairing the labels with themselves would yield a meaningless value.
        if (relevant.isEmpty() || irrelevant.isEmpty() || relevant.contains(irrelevant.get(0))) {
            return 0.0d;
        }

        double[] scores = classification.getPrediction();
        double cost = 0.0d;
        for (int relevantIndex : relevant) {
            double relevantScore = scores[relevantIndex];
            for (int irrelevantIndex : irrelevant) {
                double irrelevantScore = scores[irrelevantIndex];
                if (relevantScore < irrelevantScore) {
                    cost += 1.0d;
                } else if (relevantScore == irrelevantScore) {
                    cost += this.tieLoss;
                }
            }
        }

        // Normalize by the number of compared pairs (product, not sum). This keeps the loss a fraction.
        return cost / ((double) relevant.size() * irrelevant.size());
    }

    /**
     * Computes the mean ranking loss over all instances.
     *
     * @param expected the expected label vectors, one per instance
     * @param predicted the predictions, aligned index-by-index with {@code expected}
     * @return the arithmetic mean of the per-instance ranking losses
     * @throws IllegalStateException if no instance is given, so no average can be formed
     */
    @Override
    public double loss(final List<? extends int[]> expected, final List<? extends IMultiLabelClassification> predicted) {
        checkConsistency(expected, predicted);
        OptionalDouble average = IntStream.range(0, expected.size())
                .mapToDouble(i -> this.rankingLoss(expected.get(i), predicted.get(i)))
                .average();
        return average.orElseThrow(() -> new IllegalStateException("The ranking loss could not be averaged across all the instances."));
    }
}
