package org.tribuo.classification.evaluation;

import java.util.List;
import java.util.function.ToDoubleBiFunction;
import org.tribuo.Prediction;
import org.tribuo.classification.Label;
import org.tribuo.classification.evaluation.LabelEvaluationUtil;
import org.tribuo.classification.evaluation.LabelMetric;
import org.tribuo.evaluation.metrics.MetricTarget;

/* loaded from: input_file:org/tribuo/classification/evaluation/LabelMetrics.class */
/**
 * Standard label (classification) metrics. Each constant wraps a function of a
 * {@link MetricTarget} and a {@link LabelMetric.Context} that produces a {@code double}.
 *
 * <p>The confusion-matrix based metrics delegate to {@code ConfusionMetrics} using the
 * context's confusion matrix. {@code AUCROC} and {@code AVERAGED_PRECISION} are computed
 * from the predictions' probability scores for the single label named by the target.
 */
public enum LabelMetrics {
    /** True positives, via {@code ConfusionMetrics.tp} on the context's confusion matrix. */
    TP((target, context) -> ConfusionMetrics.tp(target, context.getCM())),
    /** False positives, via {@code ConfusionMetrics.fp} on the context's confusion matrix. */
    FP((target, context) -> ConfusionMetrics.fp(target, context.getCM())),
    /** True negatives, via {@code ConfusionMetrics.tn} on the context's confusion matrix. */
    TN((target, context) -> ConfusionMetrics.tn(target, context.getCM())),
    /** False negatives, via {@code ConfusionMetrics.fn} on the context's confusion matrix. */
    FN((target, context) -> ConfusionMetrics.fn(target, context.getCM())),
    /** Precision, via {@code ConfusionMetrics.precision} on the context's confusion matrix. */
    PRECISION((target, context) -> ConfusionMetrics.precision(target, context.getCM())),
    /** Recall, via {@code ConfusionMetrics.recall} on the context's confusion matrix. */
    RECALL((target, context) -> ConfusionMetrics.recall(target, context.getCM())),
    /** F1, via {@code ConfusionMetrics.f1} on the context's confusion matrix. */
    F1((target, context) -> ConfusionMetrics.f1(target, context.getCM())),
    /** Accuracy, via {@code ConfusionMetrics.accuracy} on the context's confusion matrix. */
    ACCURACY((target, context) -> ConfusionMetrics.accuracy(target, context.getCM())),
    /**
     * Balanced error rate, via {@code ConfusionMetrics.balancedErrorRate}. The metric target
     * is not consulted; only the context's confusion matrix is used.
     */
    BALANCED_ERROR_RATE((target, context) -> ConfusionMetrics.balancedErrorRate(context.getCM())),
    /**
     * Area under the ROC curve for the target's label. Requires the target to name a single
     * label and every prediction to carry probability scores.
     */
    AUCROC((target, context) -> AUCROC(target, context.getPredictions())),
    /**
     * Averaged precision for the target's label. Requires the target to name a single label
     * and every prediction to carry probability scores.
     */
    AVERAGED_PRECISION((target, context) -> averagedPrecision(target, context.getPredictions()));

    private final ToDoubleBiFunction<MetricTarget<Label>, LabelMetric.Context> impl;

    /**
     * Per-prediction inputs for the probability-based metrics, computed for a single label:
     * whether each prediction's ground truth equals that label, and the score each prediction
     * assigned to it.
     */
    public static final class PredictionProbabilities {
        /** {@code ypos[i]} is true when prediction {@code i}'s example output equals the label. */
        final boolean[] ypos;
        /** {@code yscore[i]} is the score prediction {@code i} assigned to the label. */
        final double[] yscore;

        /**
         * Extracts ground-truth indicators and label scores from the predictions.
         *
         * @param label the label to score against.
         * @param list the predictions to read.
         * @throws UnsupportedOperationException if a prediction has no probability scores.
         * @throws IllegalArgumentException if a prediction has no score for {@code label}.
         */
        PredictionProbabilities(Label label, List<Prediction<Label>> list) {
            final int size = list.size();
            this.ypos = new boolean[size];
            this.yscore = new double[size];
            // Enhanced for avoids O(n^2) behaviour on non-random-access lists.
            int i = 0;
            for (Prediction<Label> prediction : list) {
                if (!prediction.hasProbabilities()) {
                    throw new UnsupportedOperationException(String.format("Invalid prediction at index %d: has no probability score.", i));
                }
                this.ypos[i] = prediction.getExample().getOutput().equals(label);
                Label scored = prediction.getOutputScores().get(label.getLabel());
                if (scored == null) {
                    // Previously this surfaced as a bare NullPointerException.
                    throw new IllegalArgumentException(String.format("Invalid prediction at index %d: has no score for label '%s'.", i, label.getLabel()));
                }
                this.yscore[i] = scored.getScore();
                i++;
            }
        }
    }

    /**
     * @param impl the function computing this metric from a target and an evaluation context.
     */
    LabelMetrics(ToDoubleBiFunction<MetricTarget<Label>, LabelMetric.Context> impl) {
        this.impl = impl;
    }

    /**
     * Returns the function that computes this metric.
     *
     * @return the metric implementation.
     */
    public ToDoubleBiFunction<MetricTarget<Label>, LabelMetric.Context> getImpl() {
        return this.impl;
    }

    /**
     * Binds this metric to a target, naming it after this constant.
     *
     * @param metricTarget the target to compute the metric for.
     * @return a {@link LabelMetric} for the given target.
     */
    public LabelMetric forTarget(MetricTarget<Label> metricTarget) {
        return new LabelMetric(metricTarget, name(), getImpl());
    }

    /**
     * Computes averaged precision for the single label named by the target.
     *
     * @param metricTarget a target that names a single output label.
     * @param list predictions carrying probability scores.
     * @return the averaged precision.
     * @throws IllegalStateException if the target does not name a single label.
     */
    public static double averagedPrecision(MetricTarget<Label> metricTarget, List<Prediction<Label>> list) {
        return averagedPrecision(targetLabel(metricTarget, "averagedPrecision"), list);
    }

    /**
     * Computes averaged precision for {@code label} via
     * {@code LabelEvaluationUtil.averagedPrecision}.
     *
     * @param label the positive label.
     * @param list predictions carrying probability scores.
     * @return the averaged precision.
     */
    public static double averagedPrecision(Label label, List<Prediction<Label>> list) {
        PredictionProbabilities probs = new PredictionProbabilities(label, list);
        return LabelEvaluationUtil.averagedPrecision(probs.ypos, probs.yscore);
    }

    /**
     * Builds a precision-recall curve for {@code label} via
     * {@code LabelEvaluationUtil.generatePRCurve}.
     *
     * @param label the positive label.
     * @param list predictions carrying probability scores.
     * @return the precision-recall curve.
     */
    public static LabelEvaluationUtil.PRCurve precisionRecallCurve(Label label, List<Prediction<Label>> list) {
        PredictionProbabilities probs = new PredictionProbabilities(label, list);
        return LabelEvaluationUtil.generatePRCurve(probs.ypos, probs.yscore);
    }

    /**
     * Computes the area under the ROC curve for {@code label} via
     * {@code LabelEvaluationUtil.binaryAUCROC}.
     *
     * @param label the positive label.
     * @param list predictions carrying probability scores.
     * @return the ROC AUC.
     */
    public static double AUCROC(Label label, List<Prediction<Label>> list) {
        PredictionProbabilities probs = new PredictionProbabilities(label, list);
        return LabelEvaluationUtil.binaryAUCROC(probs.ypos, probs.yscore);
    }

    /**
     * Computes the area under the ROC curve for the single label named by the target.
     *
     * @param metricTarget a target that names a single output label.
     * @param list predictions carrying probability scores.
     * @return the ROC AUC.
     * @throws IllegalStateException if the target does not name a single label.
     */
    public static double AUCROC(MetricTarget<Label> metricTarget, List<Prediction<Label>> list) {
        return AUCROC(targetLabel(metricTarget, "AUCROC"), list);
    }

    /**
     * Extracts the single output label from a target, or fails with a metric-specific message.
     *
     * @param metricTarget the target to inspect.
     * @param metricName the metric name used in the error message.
     * @return the target's label.
     * @throws IllegalStateException if the target has no output label.
     */
    private static Label targetLabel(MetricTarget<Label> metricTarget, String metricName) {
        return metricTarget.getOutputTarget()
                .orElseThrow(() -> new IllegalStateException("Unsupported MetricTarget for " + metricName));
    }
}
