package org.linqs.psl.evaluation.statistics;

import java.util.ArrayList;
import java.util.Collections;
import java.util.IdentityHashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import org.linqs.psl.application.learning.weight.TrainingMap;
import org.linqs.psl.config.Option;
import org.linqs.psl.config.Options;
import org.linqs.psl.model.atom.GroundAtom;
import org.linqs.psl.model.predicate.StandardPredicate;

/* loaded from: input_file:org/linqs/psl/evaluation/statistics/AUCEvaluator.class */
/**
 * Evaluator computing ranking metrics over the target/truth pairs of a {@link TrainingMap}.
 *
 * <p>Reported metrics:
 * <ul>
 *   <li>area under the ROC curve (AUROC);</li>
 *   <li>area under the precision-recall curve for the positive class;</li>
 *   <li>area under the precision-recall curve for the negative class.</li>
 * </ul>
 *
 * <p>A truth atom belongs to the positive class when its value is strictly greater than the
 * configured threshold; otherwise it belongs to the negative class.
 */
public class AUCEvaluator extends Evaluator {
    private double threshold;
    private RepresentativeMetric representative;

    // Truth (label) atoms and predicted (target) atoms collected by compute(); both are sorted.
    private List<GroundAtom> truth;
    private List<GroundAtom> predicted;

    // Pairs every predicted atom instance with the truth atom it came with in compute().
    // Identity-keyed on purpose: lookups are only made with the exact instances held in
    // `predicted`, so this does not depend on GroundAtom's equals/hashCode.
    private Map<GroundAtom, GroundAtom> predictedToTruth;

    /** Metric returned by {@link #getRepMetric()}. */
    public enum RepresentativeMetric {
        AUROC,
        POSITIVE_AUPRC,
        NEGATIVE_AUPRC
    }

    /** Creates an evaluator using the configured threshold and representative metric. */
    public AUCEvaluator() {
        this(Options.EVAL_AUC_THRESHOLD.getDouble());
    }

    /**
     * Creates an evaluator with the given threshold and the configured representative metric.
     *
     * @param threshold truth values strictly above this are positive; must be in [0, 1]
     */
    public AUCEvaluator(double threshold) {
        this(threshold, Options.EVAL_AUC_REPRESENTATIVE.getString());
    }

    /**
     * Creates an evaluator, parsing the representative metric by name (case-insensitive).
     *
     * @param threshold truth values strictly above this are positive; must be in [0, 1]
     * @param representative name of a {@link RepresentativeMetric} constant, any case
     * @throws IllegalArgumentException if the name matches no metric or the threshold is invalid
     */
    public AUCEvaluator(double threshold, String representative) {
        // Locale.ROOT: default-locale upper-casing breaks names containing 'i' (e.g. Turkish).
        this(threshold, RepresentativeMetric.valueOf(representative.toUpperCase(Locale.ROOT)));
    }

    /**
     * Creates an evaluator.
     *
     * @param threshold truth values strictly above this are positive; must be in [0, 1]
     * @param representative metric reported by {@link #getRepMetric()}
     * @throws IllegalArgumentException if the threshold is outside [0, 1] or is NaN
     */
    public AUCEvaluator(double threshold, RepresentativeMetric representative) {
        // Written as a positive range test so that NaN is rejected as well.
        if (!(threshold >= 0.0 && threshold <= 1.0)) {
            throw new IllegalArgumentException("Threshold must be in [0, 1]. Found: " + threshold);
        }

        this.threshold = threshold;
        this.representative = representative;
        this.truth = new ArrayList<>();
        this.predicted = new ArrayList<>();
        this.predictedToTruth = new IdentityHashMap<>();
    }

    @Override
    public void compute(TrainingMap trainingMap) {
        compute(trainingMap, null);
    }

    /**
     * Collects the target/truth pairs to evaluate, optionally restricted to one predicate.
     *
     * @param trainingMap source of target/truth atom pairs
     * @param predicate if non-null, only pairs whose target atom has this predicate are used
     */
    @Override
    public void compute(TrainingMap trainingMap, StandardPredicate predicate) {
        int expectedSize = trainingMap.getLabelMap().size();
        truth = new ArrayList<>(expectedSize);
        predicted = new ArrayList<>(expectedSize);
        predictedToTruth = new IdentityHashMap<>(expectedSize);

        for (Map.Entry<GroundAtom, GroundAtom> entry : getMap(trainingMap)) {
            if (predicate != null && entry.getKey().getPredicate() != predicate) {
                continue;
            }

            truth.add(entry.getValue());
            predicted.add(entry.getKey());
            predictedToTruth.put(entry.getKey(), entry.getValue());
        }

        // NOTE(review): the sweeps below assume this sort puts the most-positive prediction
        // first (auroc() and the positive AUPRC already relied on that) -- confirm against
        // GroundAtom.compareTo.
        Collections.sort(truth);
        Collections.sort(predicted);
    }

    @Override
    public double getRepMetric() {
        switch (representative) {
            case AUROC:
                return auroc();
            case POSITIVE_AUPRC:
                return positiveAUPRC();
            case NEGATIVE_AUPRC:
                return negativeAUPRC();
            default:
                throw new IllegalStateException("Unknown representative metric: " + representative);
        }
    }

    @Override
    public double getBestRepScore() {
        switch (representative) {
            case AUROC:
            case POSITIVE_AUPRC:
            case NEGATIVE_AUPRC:
                return 1.0;
            default:
                throw new IllegalStateException("Unknown representative metric: " + representative);
        }
    }

    @Override
    public boolean isHigherRepBetter() {
        return true;
    }

    /** Returns the threshold above which a truth value counts as positive. */
    public double getThreshold() {
        return threshold;
    }

    /** Returns the area under the precision-recall curve for the positive class. */
    public double positiveAUPRC() {
        return auprc(true);
    }

    /** Returns the area under the precision-recall curve for the negative class. */
    public double negativeAUPRC() {
        return auprc(false);
    }

    /**
     * Computes the area under the precision-recall curve for one class, using trapezoids
     * between consecutive curve points starting from (recall 0, precision 1).
     *
     * @param positiveClass true for the positive class (standard AUPRC), false for the negative
     * @return the area, or 0 if no truth atom belongs to the requested class
     */
    private double auprc(boolean positiveClass) {
        int classSize = 0;
        for (GroundAtom atom : truth) {
            if (isPositive(atom) == positiveClass) {
                classSize++;
            }
        }

        if (classSize == 0) {
            return 0.0;
        }

        double area = 0.0;
        int truePositives = 0;
        int falsePositives = 0;
        double prevPrecision = 1.0;
        double prevRecall = 0.0;

        // The positive class is ranked front-to-back; the negative class must be ranked from the
        // most-negative prediction, i.e. back-to-front.
        int count = predicted.size();
        for (int rank = 0; rank < count; rank++) {
            GroundAtom atom = predicted.get(positiveClass ? rank : count - 1 - rank);

            Boolean label = getLabel(atom);
            if (label == null) {
                continue;
            }

            if (label.booleanValue() == positiveClass) {
                truePositives++;
            } else {
                falsePositives++;
            }

            // Explicit double division: int / int would truncate every point to 0 or 1.
            double precision = (double) truePositives / (truePositives + falsePositives);
            double recall = (double) truePositives / classSize;

            area += (recall - prevRecall) * (prevPrecision + precision) / 2.0;

            prevPrecision = precision;
            prevRecall = recall;
        }

        return area;
    }

    /**
     * Computes the area under the ROC curve (true-positive rate vs. false-positive rate),
     * sweeping predictions front-to-back.
     *
     * @return the area; 0 if there are no positive truth atoms, 1 if there are no negatives
     */
    public double auroc() {
        int positives = 0;
        for (GroundAtom atom : truth) {
            if (isPositive(atom)) {
                positives++;
            }
        }

        int negatives = predicted.size() - positives;
        if (positives == 0) {
            return 0.0;
        }

        if (negatives == 0) {
            return 1.0;
        }

        double area = 0.0;
        int truePositives = 0;
        int falsePositives = 0;
        double prevTpr = 0.0;
        double prevFpr = 0.0;

        for (GroundAtom atom : predicted) {
            Boolean label = getLabel(atom);
            if (label == null) {
                continue;
            }

            if (label.booleanValue()) {
                truePositives++;
            } else {
                falsePositives++;
            }

            // Explicit double division: int / int would truncate every rate to 0 or 1.
            double tpr = (double) truePositives / positives;
            double fpr = (double) falsePositives / negatives;

            area += (fpr - prevFpr) * (prevTpr + tpr) / 2.0;

            prevTpr = tpr;
            prevFpr = fpr;
        }

        // Close the curve at (1, 1); this segment is non-empty only if some predictions were
        // skipped for lacking a truth label.
        area += (1.0 - prevFpr) * (prevTpr + 1.0) / 2.0;
        return area;
    }

    @Override
    public String getAllStats() {
        return String.format("AUROC: %f, Positive Class AUPRC: %f, Negative Class AUPRC: %f",
                auroc(), positiveAUPRC(), negativeAUPRC());
    }

    // A truth atom is positive when its value strictly exceeds the threshold.
    private boolean isPositive(GroundAtom truthAtom) {
        return truthAtom.getValue() > threshold;
    }

    /**
     * Returns the truth label paired with a predicted atom from {@code predicted}: TRUE if
     * positive, FALSE if negative, or null if the atom has no recorded truth partner.
     */
    private Boolean getLabel(GroundAtom predictedAtom) {
        GroundAtom truthAtom = predictedToTruth.get(predictedAtom);
        if (truthAtom == null) {
            return null;
        }

        return Boolean.valueOf(isPositive(truthAtom));
    }
}
