package org.tribuo.clustering.evaluation;

import com.oracle.labs.mlrg.olcut.util.MutableLong;
import java.util.List;
import java.util.Map;
import java.util.function.BiFunction;
import org.apache.commons.math3.special.Gamma;
import org.tribuo.clustering.ClusterID;
import org.tribuo.clustering.evaluation.ClusteringMetric;
import org.tribuo.evaluation.metrics.MetricTarget;
import org.tribuo.util.infotheory.InformationTheory;
import org.tribuo.util.infotheory.impl.PairDistribution;

/* loaded from: input_file:org/tribuo/clustering/evaluation/ClusteringMetrics.class */
public enum ClusteringMetrics {

    /**
     * Mutual information between the predicted and true cluster assignments, normalized by the
     * smaller of the two labelling entropies. See {@link #normalizedMI(ClusteringMetric.Context)}.
     */
    NORMALIZED_MI((target, context) -> normalizedMI(context)),
    /**
     * Mutual information adjusted for chance agreement, normalized by the smaller of the two
     * labelling entropies. See {@link #adjustedMI(ClusteringMetric.Context)}.
     */
    ADJUSTED_MI((target, context) -> adjustedMI(context));

    /**
     * Double-precision machine epsilon (2^-52). Used as the minimum magnitude of the adjusted-MI
     * denominator, matching the cutoff scikit-learn applies in {@code adjusted_mutual_info_score}.
     */
    private static final double EPSILON = Math.ulp(1.0d);

    private final BiFunction<MetricTarget<ClusterID>, ClusteringMetric.Context, Double> impl;

    ClusteringMetrics(BiFunction<MetricTarget<ClusterID>, ClusteringMetric.Context, Double> impl) {
        this.impl = impl;
    }

    /**
     * Returns the function that computes this metric.
     *
     * @return the metric implementation, taking a target and a clustering context.
     */
    public BiFunction<MetricTarget<ClusterID>, ClusteringMetric.Context, Double> getImpl() {
        return this.impl;
    }

    /**
     * Builds a {@link ClusteringMetric} for this metric, named after this enum constant.
     *
     * @param metricTarget the target to compute the metric for.
     * @return a metric wrapping {@link #getImpl()} with {@link #name()} as its name.
     */
    public ClusteringMetric forTarget(MetricTarget<ClusterID> metricTarget) {
        return new ClusteringMetric(metricTarget, name(), getImpl());
    }

    /**
     * Computes the mutual information adjusted for chance:
     * {@code (MI - E[MI]) / (min(H(predicted), H(true)) - E[MI])}.
     * <p>
     * The denominator is clamped to a magnitude of at least {@link #EPSILON} (keeping its sign),
     * so degenerate inputs such as a single cluster on both sides give a finite value instead of
     * NaN from 0/0.
     * <p>
     * NOTE(review): {@link #expectedMI} works in natural log (nats). Confirm that
     * {@code InformationTheory.mi}/{@code entropy} use the same base; otherwise this ratio mixes
     * units.
     *
     * @param context the clustering context supplying the predicted and true cluster ids.
     * @return the adjusted mutual information.
     */
    public static double adjustedMI(ClusteringMetric.Context context) {
        List<Integer> predicted = context.getPredictedIDs();
        List<Integer> truth = context.getTrueIDs();
        double mi = InformationTheory.mi(predicted, truth);
        double predictedEntropy = InformationTheory.entropy(predicted);
        double trueEntropy = InformationTheory.entropy(truth);
        double expected = expectedMI(predicted, truth);

        double denominator = Math.min(predictedEntropy, trueEntropy) - expected;
        // Keep the denominator away from zero without flipping its sign.
        if (denominator < 0) {
            denominator = Math.min(denominator, -EPSILON);
        } else {
            denominator = Math.max(denominator, EPSILON);
        }
        return (mi - expected) / denominator;
    }

    /**
     * Computes the mutual information normalized by the smaller of the two labelling entropies:
     * {@code MI / min(H(predicted), H(true))}.
     * <p>
     * No zero guard is applied. If the smaller entropy is zero, the result is NaN (0/0) or
     * infinite, depending on the mutual information value.
     *
     * @param context the clustering context supplying the predicted and true cluster ids.
     * @return the normalized mutual information.
     */
    public static double normalizedMI(ClusteringMetric.Context context) {
        List<Integer> predicted = context.getPredictedIDs();
        List<Integer> truth = context.getTrueIDs();
        double mi = InformationTheory.mi(predicted, truth);
        double minEntropy = Math.min(InformationTheory.entropy(predicted), InformationTheory.entropy(truth));
        return mi / minEntropy;
    }

    /**
     * Computes the expected mutual information (in nats) between two labellings.
     * <p>
     * For every pair of cluster sizes {@code (a, b)} taken from the first and second labelling,
     * it sums over the possible overlap sizes {@code nij}. Each term is
     * {@code nij/N * log(N*nij / (a*b))} weighted by the hypergeometric probability of that
     * overlap, with {@code nij} running from {@code max(1, a+b-N)} to {@code min(a, b)} inclusive.
     *
     * @param first  the first labelling.
     * @param second the second labelling.
     * @return the expected mutual information; 0.0 when there are no cluster pairs.
     */
    private static double expectedMI(List<Integer> first, List<Integer> second) {
        PairDistribution<Integer, Integer> pd = PairDistribution.constructFromLists(first, second);
        Map<Integer, MutableLong> firstCounts = pd.firstCount;
        Map<Integer, MutableLong> secondCounts = pd.secondCount;
        long n = pd.count;
        double logNFactorial = Gamma.logGamma(n + 1);

        double total = 0.0d;
        for (MutableLong firstCount : firstCounts.values()) {
            long a = firstCount.longValue();
            for (MutableLong secondCount : secondCounts.values()) {
                long b = secondCount.longValue();
                total += expectedMITerm(a, b, n, logNFactorial);
            }
        }
        return total;
    }

    /**
     * Computes the expected-MI contribution of one pair of clusters with sizes {@code a} and
     * {@code b}, out of {@code n} samples in total.
     *
     * @param a             size of the cluster from the first labelling.
     * @param b             size of the cluster from the second labelling.
     * @param n             total number of samples.
     * @param logNFactorial {@code log(n!)}, precomputed by the caller.
     * @return the summed contribution over all feasible overlap sizes.
     */
    private static double expectedMITerm(long a, long b, long n, double logNFactorial) {
        // Log of the hypergeometric numerator a! b! (n-a)! (n-b)! / n!; it does not depend on nij.
        double logNumerator = Gamma.logGamma(a + 1) + Gamma.logGamma(b + 1)
                + Gamma.logGamma(n - a + 1) + Gamma.logGamma(n - b + 1) - logNFactorial;
        long start = Math.max(1L, a + b - n);
        long end = Math.min(a, b);

        double sum = 0.0d;
        // Inclusive upper bound: the overlap can be as large as the smaller cluster.
        for (long nij = start; nij <= end; nij++) {
            double logProbability = logNumerator
                    - Gamma.logGamma(nij + 1)
                    - Gamma.logGamma(a - nij + 1)
                    - Gamma.logGamma(b - nij + 1)
                    - Gamma.logGamma(n - a - b + nij + 1);
            // Floating-point arithmetic: long division here would truncate to zero.
            double ratio = ((double) n * nij) / ((double) a * b);
            sum += ((double) nij / n) * Math.log(ratio) * Math.exp(logProbability);
        }
        return sum;
    }
}
