package org.linqs.psl.evaluation.statistics;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;
import java.util.Map;

import org.linqs.psl.application.learning.weight.TrainingMap;
import org.linqs.psl.config.Config;
import org.linqs.psl.model.atom.GroundAtom;
import org.linqs.psl.model.predicate.StandardPredicate;

/* loaded from: input_file:org/linqs/psl/evaluation/statistics/RankingEvaluator.class */
/**
 * Evaluator that computes ranking statistics over the atoms gathered by the most
 * recent call to {@link #compute(TrainingMap, StandardPredicate)}: the area under
 * the ROC curve and the area under the precision-recall curve for the positive and
 * for the negative class.
 *
 * <p>A truth atom counts as positive when its value is strictly greater than the
 * configured threshold. All areas are accumulated with the trapezoid rule while
 * walking the predicted atoms in the order produced by {@code Collections.sort}.
 * NOTE(review): the curve construction presumably relies on {@code GroundAtom}'s
 * natural ordering ranking atoms by value - confirm against {@code GroundAtom.compareTo}.
 */
public class RankingEvaluator extends Evaluator {
    public static final String CONFIG_PREFIX = "rankingevaluator";
    public static final String THRESHOLD_KEY = "rankingevaluator.threshold";
    public static final double DEFAULT_THRESHOLD = 0.5d;
    public static final String REPRESENTATIVE_KEY = "rankingevaluator.representative";
    public static final String DEFAULT_REPRESENTATIVE = "AUROC";

    private final double threshold;
    private final RepresentativeMetric representative;
    private List<GroundAtom> truth;
    private List<GroundAtom> predicted;

    /**
     * The statistic reported by {@link #getRepresentativeMetric()}.
     */
    public enum RepresentativeMetric {
        /** Report {@link #auroc()}. */
        AUROC,
        /** Report {@link #positiveAUPRC()}. */
        POSITIVE_AUPRC,
        /** Report {@link #negativeAUPRC()}. */
        NEGATIVE_AUPRC
    }

    /**
     * Builds an evaluator from configuration, reading {@link #THRESHOLD_KEY} and
     * {@link #REPRESENTATIVE_KEY} and falling back to the declared defaults.
     */
    public RankingEvaluator() {
        this(Config.getDouble(THRESHOLD_KEY, DEFAULT_THRESHOLD),
                Config.getString(REPRESENTATIVE_KEY, DEFAULT_REPRESENTATIVE));
    }

    /**
     * Builds an evaluator with the given threshold and the default representative metric.
     *
     * @param threshold truth values strictly above this are treated as positive
     * @throws IllegalArgumentException if the threshold is outside [0, 1]
     */
    public RankingEvaluator(double threshold) {
        this(threshold, DEFAULT_REPRESENTATIVE);
    }

    /**
     * Builds an evaluator with the given threshold and a representative metric given by name.
     *
     * @param threshold truth values strictly above this are treated as positive
     * @param representative case-insensitive name of a {@link RepresentativeMetric} constant
     * @throws IllegalArgumentException if the threshold is outside [0, 1] or the name
     *         does not match any {@link RepresentativeMetric}
     */
    public RankingEvaluator(double threshold, String representative) {
        // Locale.ROOT keeps the upper-casing independent of the JVM default locale
        // (e.g. a Turkish locale would map 'i' to a dotted capital and break valueOf).
        this(threshold, RepresentativeMetric.valueOf(representative.toUpperCase(Locale.ROOT)));
    }

    /**
     * Builds an evaluator with the given threshold and representative metric.
     *
     * @param threshold truth values strictly above this are treated as positive
     * @param representativeMetric the statistic returned by {@link #getRepresentativeMetric()}
     * @throws IllegalArgumentException if the threshold is outside [0, 1]
     */
    public RankingEvaluator(double threshold, RepresentativeMetric representativeMetric) {
        if (threshold < 0.0d || threshold > 1.0d) {
            throw new IllegalArgumentException("Threshold must be in [0, 1]. Found: " + threshold);
        }

        this.threshold = threshold;
        this.representative = representativeMetric;
        this.truth = new ArrayList<GroundAtom>();
        this.predicted = new ArrayList<GroundAtom>();
    }

    /**
     * Collects atoms for every predicate; equivalent to passing a null predicate to
     * {@link #compute(TrainingMap, StandardPredicate)}.
     */
    @Override
    public void compute(TrainingMap trainingMap) {
        compute(trainingMap, null);
    }

    /**
     * Replaces the stored truth and predicted atoms with the entries of the full map
     * of {@code trainingMap}, then sorts both lists.
     *
     * @param trainingMap source of (key, value) atom pairs; each key is stored as a
     *        predicted atom and each value as the matching truth atom
     * @param standardPredicate when non-null, only entries whose key has exactly this
     *        predicate (reference comparison) are kept
     */
    @Override
    public void compute(TrainingMap trainingMap, StandardPredicate standardPredicate) {
        // Only an initial capacity hint; the lists grow if the full map is larger.
        int capacity = trainingMap.getTrainingMap().size();
        truth = new ArrayList<GroundAtom>(capacity);
        predicted = new ArrayList<GroundAtom>(capacity);

        for (Map.Entry<GroundAtom, GroundAtom> entry : trainingMap.getFullMap()) {
            if (standardPredicate != null && entry.getKey().getPredicate() != standardPredicate) {
                continue;
            }

            truth.add(entry.getValue());
            predicted.add(entry.getKey());
        }

        Collections.sort(truth);
        Collections.sort(predicted);
    }

    /**
     * Returns the statistic selected at construction time.
     *
     * @throws IllegalStateException if the configured metric is not one of the known constants
     */
    @Override
    public double getRepresentativeMetric() {
        switch (representative) {
            case AUROC:
                return auroc();
            case POSITIVE_AUPRC:
                return positiveAUPRC();
            case NEGATIVE_AUPRC:
                return negativeAUPRC();
            default:
                throw new IllegalStateException("Unknown representative metric: " + representative);
        }
    }

    /**
     * Every metric this evaluator can report is "higher is better".
     */
    @Override
    public boolean isHigherRepresentativeBetter() {
        return true;
    }

    /**
     * @return the threshold above which a truth value counts as positive
     */
    public double getThreshold() {
        return threshold;
    }

    /**
     * Area under the precision-recall curve for the positive class.
     *
     * @return the trapezoid-rule area, or 0 when no truth atom is positive
     */
    public double positiveAUPRC() {
        int totalPositives = countTruthPositives();
        if (totalPositives == 0) {
            return 0.0d;
        }

        double area = 0.0d;
        int truePositives = 0;
        int falsePositives = 0;

        // The curve starts at (recall = 0, precision = 1).
        double previousPrecision = 1.0d;
        double previousRecall = 0.0d;

        for (GroundAtom atom : predicted) {
            Boolean label = getLabel(atom);
            if (label == null) {
                // Atoms without a truth counterpart do not move the curve.
                continue;
            }

            if (label.booleanValue()) {
                truePositives++;
            } else {
                falsePositives++;
            }

            // Floating-point division: integer division would truncate the ratios to 0 or 1.
            double precision = (double) truePositives / (truePositives + falsePositives);
            double recall = (double) truePositives / totalPositives;

            area += 0.5d * (recall - previousRecall) * (precision + previousPrecision);
            previousPrecision = precision;
            previousRecall = recall;
        }

        // Close the curve at (recall = 1, precision = 0).
        return area + (0.5d * (1.0d - previousRecall) * (0.0d + previousPrecision));
    }

    /**
     * Area under the precision-recall curve for the negative class.
     *
     * <p>The number of negatives is taken as the number of predicted atoms minus the
     * number of positive truth atoms. The walk starts with every atom classified as
     * negative and removes atoms one at a time.
     *
     * @return the trapezoid-rule area, or 0 when there are no negatives
     */
    public double negativeAUPRC() {
        int totalPositives = countTruthPositives();
        int totalNegatives = predicted.size() - totalPositives;
        if (totalNegatives == 0) {
            return 0.0d;
        }

        double area = 0.0d;
        int falseNegatives = totalPositives;
        int trueNegatives = totalNegatives;

        // Starting point: everything predicted negative, so recall is 1.
        double previousPrecision = (double) trueNegatives / (trueNegatives + falseNegatives);
        double previousRecall = 1.0d;

        for (GroundAtom atom : predicted) {
            Boolean label = getLabel(atom);
            if (label == null) {
                continue;
            }

            if (label.booleanValue()) {
                falseNegatives--;
            } else {
                trueNegatives--;
            }

            // Precision is defined as 0 once nothing is left classified as negative.
            double precision = 0.0d;
            if (trueNegatives + falseNegatives > 0) {
                precision = (double) trueNegatives / (trueNegatives + falseNegatives);
            }
            double recall = (double) trueNegatives / totalNegatives;

            area += 0.5d * (previousRecall - recall) * (previousPrecision + precision);
            previousPrecision = precision;
            previousRecall = recall;
        }

        return area;
    }

    /**
     * Area under the ROC curve (true-positive rate against false-positive rate).
     *
     * <p>No guard is applied for a missing class: when labelled atoms exist but there
     * are no positives or no negatives, a rate becomes 0/0 and the result is NaN.
     *
     * @return the trapezoid-rule area
     */
    public double auroc() {
        int totalPositives = countTruthPositives();
        int totalNegatives = predicted.size() - totalPositives;

        double area = 0.0d;
        int truePositives = 0;
        int falsePositives = 0;

        // The curve starts at the origin.
        double previousTruePositiveRate = 0.0d;
        double previousFalsePositiveRate = 0.0d;

        for (GroundAtom atom : predicted) {
            Boolean label = getLabel(atom);
            if (label == null) {
                continue;
            }

            if (label.booleanValue()) {
                truePositives++;
            } else {
                falsePositives++;
            }

            // Floating-point division: integer division would truncate the rates and
            // throw ArithmeticException on an empty class.
            double truePositiveRate = (double) truePositives / totalPositives;
            double falsePositiveRate = (double) falsePositives / totalNegatives;

            area += 0.5d * (falsePositiveRate - previousFalsePositiveRate)
                    * (truePositiveRate + previousTruePositiveRate);
            previousTruePositiveRate = truePositiveRate;
            previousFalsePositiveRate = falsePositiveRate;
        }

        // Close the curve at (1, 1).
        return area + (0.5d * (1.0d - previousFalsePositiveRate) * (1.0d + previousTruePositiveRate));
    }

    /**
     * @return a single line listing the AUROC and both AUPRC values
     */
    @Override
    public String getAllStats() {
        return String.format("AUROC: %f, Positive Class AUPRC: %f, Negative Class AUPRC: %f",
                auroc(), positiveAUPRC(), negativeAUPRC());
    }

    /**
     * Counts the truth atoms whose value is strictly above the threshold.
     */
    private int countTruthPositives() {
        int count = 0;
        for (GroundAtom atom : truth) {
            if (atom.getValue() > threshold) {
                count++;
            }
        }

        return count;
    }

    /**
     * Looks up the truth label for an atom.
     *
     * @return TRUE or FALSE depending on whether the equal truth atom's value exceeds
     *         the threshold, or null when no equal atom is present in the truth list
     */
    private Boolean getLabel(GroundAtom groundAtom) {
        // indexOf is a linear scan (equals-based), so each curve walk is quadratic in
        // the number of atoms.
        int index = truth.indexOf(groundAtom);
        if (index == -1) {
            return null;
        }

        return Boolean.valueOf(((double) truth.get(index).getValue()) > threshold);
    }
}
