package org.linqs.psl.evaluation.statistics;

import java.util.Locale;
import java.util.Map;
import org.linqs.psl.application.learning.weight.TrainingMap;
import org.linqs.psl.config.Config;
import org.linqs.psl.model.atom.GroundAtom;
import org.linqs.psl.model.predicate.StandardPredicate;
import org.linqs.psl.util.MathUtils;

/* loaded from: input_file:org/linqs/psl/evaluation/statistics/DiscreteEvaluator.class */
public class DiscreteEvaluator extends Evaluator {
    public static final String CONFIG_PREFIX = "discreteevaluator";

    /** Config key for the cutoff at or above which a value is treated as positive. */
    public static final String THRESHOLD_KEY = "discreteevaluator.threshold";
    public static final double DEFAULT_THRESHOLD = 0.5d;

    /** Config key naming the {@link RepresentativeMetric} (case-insensitive). */
    public static final String REPRESENTATIVE_KEY = "discreteevaluator.representative";
    public static final String DEFAULT_REPRESENTATIVE = "F1";

    private double threshold;
    private RepresentativeMetric representative;

    // Confusion-matrix counts, rebuilt from scratch by every compute() call.
    private int tp;
    private int fn;
    private int tn;
    private int fp;

    /** The statistic reported by {@link #getRepresentativeMetric()}. */
    public enum RepresentativeMetric {
        F1,
        POSITIVE_PRECISION,
        NEGATIVE_PRECISION,
        POSITIVE_RECALL,
        NEGATIVE_RECALL,
        ACCURACY
    }

    /** Builds an evaluator whose threshold and representative metric are read from {@link Config}. */
    public DiscreteEvaluator() {
        this(Config.getDouble(THRESHOLD_KEY, DEFAULT_THRESHOLD),
                Config.getString(REPRESENTATIVE_KEY, DEFAULT_REPRESENTATIVE));
    }

    /**
     * Builds an evaluator with the given threshold and F1 as the representative metric.
     *
     * @param threshold cutoff in [0, 1]; values at or above it count as positive
     */
    public DiscreteEvaluator(double threshold) {
        this(threshold, DEFAULT_REPRESENTATIVE);
    }

    /**
     * Builds an evaluator with the given threshold and named representative metric.
     *
     * @param threshold cutoff in [0, 1]; values at or above it count as positive
     * @param representative name of a {@link RepresentativeMetric} constant, case-insensitive
     * @throws IllegalArgumentException if the threshold is out of range or the name is unknown
     */
    public DiscreteEvaluator(double threshold, String representative) {
        // Locale.ROOT: default-locale upper-casing breaks on e.g. Turkish dotless/dotted i.
        this(threshold, RepresentativeMetric.valueOf(representative.toUpperCase(Locale.ROOT)));
    }

    /**
     * Builds an evaluator with the given threshold and representative metric.
     *
     * @param threshold cutoff in [0, 1]; values at or above it count as positive
     * @param representative statistic returned by {@link #getRepresentativeMetric()}
     * @throws IllegalArgumentException if {@code threshold} is NaN or outside [0, 1]
     */
    public DiscreteEvaluator(double threshold, RepresentativeMetric representative) {
        // NaN fails every comparison, so it has to be rejected explicitly.
        if (Double.isNaN(threshold) || threshold < 0.0d || threshold > 1.0d) {
            throw new IllegalArgumentException("Threshold must be in [0, 1]. Found: " + threshold);
        }
        this.threshold = threshold;
        this.representative = representative;
        this.tp = 0;
        this.fn = 0;
        this.tn = 0;
        this.fp = 0;
    }

    /** Computes the confusion matrix over every atom in the training map. */
    @Override
    public void compute(TrainingMap trainingMap) {
        compute(trainingMap, null);
    }

    /**
     * Rebuilds the confusion matrix from the training map.
     *
     * <p>Each entry's key is treated as the predicted atom and its value as the observed (truth)
     * atom. Both are binarized with {@code value >= threshold}.
     *
     * @param trainingMap source of (predicted, observed) atom pairs
     * @param standardPredicate if non-null, only entries whose key has this exact predicate are counted
     */
    @Override
    public void compute(TrainingMap trainingMap, StandardPredicate standardPredicate) {
        this.tp = 0;
        this.fn = 0;
        this.tn = 0;
        this.fp = 0;

        for (Map.Entry<GroundAtom, GroundAtom> entry : trainingMap.getFullMap()) {
            if (standardPredicate != null && entry.getKey().getPredicate() != standardPredicate) {
                continue;
            }

            boolean predictedPositive = entry.getKey().getValue() >= this.threshold;
            boolean actuallyPositive = entry.getValue().getValue() >= this.threshold;

            if (predictedPositive) {
                if (actuallyPositive) {
                    this.tp++;
                } else {
                    this.fp++;
                }
            } else {
                if (actuallyPositive) {
                    this.fn++;
                } else {
                    this.tn++;
                }
            }
        }
    }

    /**
     * Returns the statistic selected at construction time.
     *
     * @throws IllegalStateException if the representative metric is not handled
     */
    @Override
    public double getRepresentativeMetric() {
        switch (this.representative) {
            case F1:
                return f1();
            case POSITIVE_PRECISION:
                return positivePrecision();
            case NEGATIVE_PRECISION:
                return negativePrecision();
            case POSITIVE_RECALL:
                return positiveRecall();
            case NEGATIVE_RECALL:
                return negativeRecall();
            case ACCURACY:
                return accuracy();
            default:
                throw new IllegalStateException("Unknown representative metric: " + this.representative);
        }
    }

    /** Every selectable metric is a precision/recall/F/accuracy ratio, so larger is better. */
    @Override
    public boolean isHigherRepresentativeBetter() {
        return true;
    }

    public double getThreshold() {
        return this.threshold;
    }

    /** Returns TP / (TP + FP), or 0 when nothing was predicted positive. */
    public double positivePrecision() {
        return ratio(this.tp, this.tp + this.fp);
    }

    /** Returns TN / (TN + FN), or 0 when nothing was predicted negative. */
    public double negativePrecision() {
        return ratio(this.tn, this.tn + this.fn);
    }

    /** Returns TP / (TP + FN), or 0 when there are no actual positives. */
    public double positiveRecall() {
        return ratio(this.tp, this.tp + this.fn);
    }

    /** Returns TN / (TN + FP), or 0 when there are no actual negatives. */
    public double negativeRecall() {
        return ratio(this.tn, this.tn + this.fp);
    }

    /** Returns the F-score with beta = 1 (the harmonic mean of positive precision and recall). */
    public double f1() {
        return fScore(1.0d);
    }

    /**
     * Returns the F-beta score over the positive class.
     *
     * @param beta relative weight of recall versus precision
     * @return (1 + beta^2) * P * R / (beta^2 * P + R), or 0 when the denominator is zero
     */
    public double fScore(double beta) {
        double precision = positivePrecision();
        double recall = positiveRecall();
        double betaSquared = beta * beta;

        double denominator = (betaSquared * precision) + recall;
        if (MathUtils.isZero(denominator)) {
            return 0.0d;
        }
        return ((1.0d + betaSquared) * (precision * recall)) / denominator;
    }

    /** Returns (TP + TN) / total, or 0 when no atoms were counted. */
    public double accuracy() {
        return ratio(this.tp + this.tn, this.tp + this.tn + this.fp + this.fn);
    }

    @Override
    public String getAllStats() {
        return String.format("Accuracy: %f, F1: %f, Positive Class Precision: %f, Positive Class Recall: %f, Negative Class Precision: %f, Negative Class Recall: %f", Double.valueOf(accuracy()), Double.valueOf(f1()), Double.valueOf(positivePrecision()), Double.valueOf(positiveRecall()), Double.valueOf(negativePrecision()), Double.valueOf(negativeRecall()));
    }

    /**
     * Divides two counts in floating point, returning 0 for an empty denominator.
     * The cast matters: int / int would truncate every ratio to 0 or 1.
     */
    private static double ratio(int numerator, int denominator) {
        if (denominator == 0) {
            return 0.0d;
        }
        return (double) numerator / denominator;
    }
}
