package ai.libs.jaicore.ml.classification.multilabel.evaluation.loss.nonadditive;

import ai.libs.jaicore.ml.classification.multilabel.evaluation.loss.AMultiLabelClassificationMeasure;
import ai.libs.jaicore.ml.classification.multilabel.evaluation.loss.nonadditive.owa.IOWAValueFunction;
import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.commons.math3.stat.descriptive.DescriptiveStatistics;
import org.api4.java.ai.ml.classification.multilabel.evaluation.IMultiLabelClassification;

/**
 * Multi-label loss that aggregates per-label relevance fulfillment with an ordered weighted average (OWA).
 *
 * <p>For every instance, each label contributes a fulfillment degree {@code 1 - |predicted - expected|}. The degrees
 * are sorted in ascending order ({@code x_(1) <= ... <= x_(m)}) and aggregated as
 * {@code sum_{i=1..m} (v(m - i + 1, m) - v(m - i, m)) * x_(i)}, where {@code v} is the configured
 * {@link IOWAValueFunction}. The instance loss is one minus this aggregate. The dataset loss is the arithmetic mean
 * of the instance losses, and the score is {@code 1 - loss}.
 */
public class OWARelevanceLoss extends AMultiLabelClassificationMeasure {

    /** Value function whose successive differences define the OWA weights. */
    private final IOWAValueFunction valueFunction;

    /**
     * Creates the loss with the given OWA value function.
     *
     * @param valueFunction the value function from which the OWA weights are derived
     */
    public OWARelevanceLoss(final IOWAValueFunction valueFunction) {
        this.valueFunction = valueFunction;
    }

    /**
     * Computes the fulfillment degree of a single label.
     *
     * @param expected the ground-truth label value
     * @param predicted the predicted value for that label
     * @return {@code 1 - |predicted - expected|}; NOTE(review): this is in [0, 1] only if predictions lie in [0, 1]
     */
    private double fcrit(final int expected, final double predicted) {
        return 1.0d - Math.abs(predicted - expected);
    }

    /**
     * Computes the OWA-based loss of a single instance.
     *
     * @param expected the ground-truth label vector
     * @param predicted the predicted label vector; must have the same length as {@code expected}
     * @return one minus the OWA aggregate of the per-label fulfillment degrees
     * @throws IllegalArgumentException if the two vectors differ in length
     */
    private double instanceLoss(final int[] expected, final double[] predicted) {
        if (expected.length != predicted.length) {
            throw new IllegalArgumentException("Expected and predicted label vectors differ in length: " + expected.length + " vs. " + predicted.length);
        }
        final double m = expected.length;
        // Sort the degrees on their own rather than mixing in a 0.0 sentinel: a sentinel only ends up at index 0
        // when every degree is non-negative, and otherwise pushes the smallest degree(s) out of the sum.
        final double[] degrees = IntStream.range(0, expected.length).mapToDouble(j -> this.fcrit(expected[j], predicted[j])).sorted().toArray();
        double aggregate = 0.0d;
        for (int i = 1; i <= degrees.length; i++) {
            // Weight of the i-th smallest degree: increment of the value function from m - i to m - i + 1.
            final double weight = this.valueFunction.transform((m - i) + 1.0d, m) - this.valueFunction.transform(m - i, m);
            aggregate += weight * degrees[i - 1];
        }
        return 1.0d - aggregate;
    }

    /**
     * Computes the mean instance loss over a dataset.
     *
     * @param expected ground-truth label vectors, one per instance
     * @param predicted predictions, one per instance, aligned with {@code expected}
     * @return the arithmetic mean of the per-instance OWA losses
     */
    @Override
    public double loss(final List<? extends int[]> expected, final List<? extends IMultiLabelClassification> predicted) {
        this.checkConsistency(expected, predicted);
        final DescriptiveStatistics stats = new DescriptiveStatistics();
        for (int i = 0; i < expected.size(); i++) {
            stats.addValue(this.instanceLoss(expected.get(i), predicted.get(i).getPrediction()));
        }
        return stats.getMean();
    }

    /**
     * Computes the score as the complement of {@link #loss(List, List)}.
     *
     * @param expected ground-truth label vectors, one per instance
     * @param predicted predictions, one per instance, aligned with {@code expected}
     * @return {@code 1 - loss(expected, predicted)}
     */
    @Override
    public double score(final List<? extends int[]> expected, final List<? extends IMultiLabelClassification> predicted) {
        return 1.0d - this.loss(expected, predicted);
    }
}
