package ai.libs.jaicore.ml.ranking.loss;

import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.OptionalDouble;
import java.util.stream.IntStream;
import org.api4.java.ai.ml.ranking.IRanking;
import org.api4.java.ai.ml.ranking.loss.IRankingPredictionPerformanceMeasure;

/* loaded from: input_file:ai/libs/jaicore/ml/ranking/loss/NDCGLoss.class */
/**
 * Normalized discounted cumulative gain (NDCG) of a ranking with respect to a reference ranking,
 * truncated at the first {@code l} positions.
 *
 * <p>The item at zero-based position {@code i} of the reference ranking receives the relevance
 * {@code -(i + 1)}, i.e. the gain {@code 2^-(i+1) - 1}. Both the ideal DCG (computed on the reference
 * ranking) and the DCG of the evaluated ranking sum these gains over the first {@code min(l, size)}
 * positions, discounted by {@code log2(position + 2)}. The result is {@code idealDCG / DCG}. This is
 * 0 when no position is evaluated ({@code l <= 0}).
 *
 * <p>NOTE(review): all gains are negative. As a result, the value returned by
 * {@link #loss(IRanking, IRanking)} is 1.0 when the evaluated prefix equals the reference prefix and
 * at most 1.0 otherwise. It is therefore the NDCG score itself, not {@code 1 - NDCG}. Confirm this is
 * what callers expect from a "loss".
 */
public class NDCGLoss extends ARankingPredictionPerformanceMeasure implements IRankingPredictionPerformanceMeasure {

    /** Cut-off: number of leading ranking positions over which the cumulative gain is accumulated. */
    private int l;

    /**
     * Creates the measure with the given cut-off.
     *
     * @param i number of leading positions to evaluate. Values larger than a ranking's length evaluate
     *     the whole ranking; values {@code <= 0} make {@link #loss(IRanking, IRanking)} return 0.
     */
    public NDCGLoss(int i) {
        setL(i);
    }

    /**
     * Averages {@link #loss(IRanking, IRanking)} over index-aligned pairs of rankings.
     *
     * @param list the reference (expected) rankings
     * @param list2 the evaluated (actual) rankings; {@code list2.get(i)} is scored against {@code list.get(i)}
     * @return the mean of the per-pair values
     * @throws IllegalArgumentException if the two lists have different sizes, or if a pair is invalid
     * @throws IllegalStateException if the lists are empty, so no average can be formed
     */
    @Override
    public double loss(List<? extends IRanking<?>> list, List<? extends IRanking<?>> list2) {
        if (list.size() != list2.size()) {
            throw new IllegalArgumentException("Expected and actual ranking lists must have equal size, but got " + list.size() + " and " + list2.size() + ".");
        }
        // Score the i-th evaluated ranking against the i-th reference ranking.
        OptionalDouble average = IntStream.range(0, list.size())
                .mapToDouble(i -> loss((IRanking<?>) list.get(i), (IRanking<?>) list2.get(i)))
                .average();
        if (average.isPresent()) {
            return average.getAsDouble();
        }
        throw new IllegalStateException("Could not aggregate NDCG over an empty list of rankings");
    }

    /**
     * Computes the truncated NDCG of {@code iRanking2} with respect to the reference {@code iRanking}.
     *
     * @param iRanking the reference ranking; the item positions in it define the relevance values
     * @param iRanking2 the evaluated ranking; it must contain only items of {@code iRanking}
     * @return {@code idealDCG / DCG} over the first {@code min(l, size)} positions, or 0 if the DCG is 0
     * @throws IllegalArgumentException if the reference ranking has at most one element, if the
     *     rankings differ in length, or if an evaluated item is absent from the reference ranking
     */
    @Override
    public double loss(IRanking<?> iRanking, IRanking<?> iRanking2) {
        if (iRanking.size() <= 1) {
            throw new IllegalArgumentException("Dyad rankings must have length greater than 1.");
        }
        if (iRanking.size() != iRanking2.size()) {
            throw new IllegalArgumentException("Dyad rankings must have equal length.");
        }
        int size = iRanking.size();
        // Assign relevance to every reference item, not only to the top l. Otherwise items outside
        // the top l would have no relevance, and l > size would index past the end of the ranking.
        Map<Object, Integer> relevance = new HashMap<>();
        for (int i = 0; i < size; i++) {
            relevance.put(iRanking.get(i), -(i + 1));
        }
        int cutoff = Math.min(this.l, size);
        double dcg = computeDCG(iRanking2, relevance, cutoff);
        double idealDcg = computeDCG(iRanking, relevance, cutoff);
        // A DCG of exactly 0 only arises when no position is evaluated (cutoff <= 0).
        if (dcg != 0.0d) {
            return idealDcg / dcg;
        }
        return 0.0d;
    }

    /**
     * Sums the discounted gains {@code (2^rel - 1) / log2(position + 2)} over the first {@code cutoff}
     * positions of the given ranking.
     *
     * @param iRanking the ranking to score
     * @param map relevance of each item
     * @param cutoff number of leading positions to include; callers guarantee it is at most the ranking size
     * @return the discounted cumulative gain
     * @throws IllegalArgumentException if an item in the evaluated prefix has no relevance entry
     */
    private double computeDCG(IRanking<?> iRanking, Map<Object, Integer> map, int cutoff) {
        double dcg = 0.0d;
        for (int i = 0; i < cutoff; i++) {
            Object item = iRanking.get(i);
            Integer rel = map.get(item);
            if (rel == null) {
                throw new IllegalArgumentException("Item " + item + " of the evaluated ranking does not occur in the expected ranking.");
            }
            dcg += (Math.pow(2.0d, rel) - 1.0d) / log2(i + 2.0d);
        }
        return dcg;
    }

    /** Returns the base-2 logarithm of {@code d}. */
    private double log2(double d) {
        return Math.log(d) / Math.log(2.0d);
    }

    /**
     * Returns the cut-off.
     *
     * @return the number of leading positions evaluated
     */
    public int getL() {
        return this.l;
    }

    /**
     * Sets the cut-off.
     *
     * @param i the number of leading positions to evaluate
     */
    public void setL(int i) {
        this.l = i;
    }
}
