package org.tribuo.classification.evaluation;

import java.util.Collections;
import java.util.List;
import java.util.Map;
import org.tribuo.Prediction;
import org.tribuo.classification.Label;
import org.tribuo.classification.evaluation.LabelEvaluationUtil;
import org.tribuo.classification.evaluation.LabelMetric;
import org.tribuo.evaluation.metrics.EvaluationMetric;
import org.tribuo.evaluation.metrics.MetricID;
import org.tribuo.evaluation.metrics.MetricTarget;
import org.tribuo.provenance.EvaluationProvenance;

/**
 * Standard implementation of {@link LabelEvaluation}.
 * <p>
 * Metric values are read from a precomputed map keyed by {@link MetricID}. The predictions, the
 * confusion matrix, and whether the evaluated model generates probabilities are all taken from the
 * supplied {@link LabelMetric.Context} at construction time.
 * <p>
 * NOTE(review): this source was recovered by a decompiler, which reported that the class and its
 * constructor were originally package-private. They are left {@code public} here so that existing
 * callers keep compiling.
 */
public final class LabelEvaluationImpl implements LabelEvaluation {
    /** Precomputed metric values; stored by reference (not copied) from the constructor argument. */
    private final Map<MetricID<Label>, Double> results;
    /** Metric context holding the model, the predictions and the confusion matrix. */
    private final LabelMetric.Context context;
    /** Cached {@code context.getModel().generatesProbabilities()}, used to guard score-based metrics. */
    private final boolean modelGeneratesProbabilities;
    /** Cached confusion matrix from {@code context.getCM()}. */
    private final ConfusionMatrix<Label> cm;
    /** Provenance describing how this evaluation was produced. */
    private final EvaluationProvenance provenance;

    /**
     * Builds an evaluation from precomputed metric values.
     *
     * @param map                  the computed metric values, keyed by metric id; the map is stored
     *                             by reference, so later mutation by the caller is visible here
     * @param context              the metric context supplying the model, predictions and confusion matrix
     * @param evaluationProvenance the provenance of this evaluation
     */
    public LabelEvaluationImpl(Map<MetricID<Label>, Double> map, LabelMetric.Context context, EvaluationProvenance evaluationProvenance) {
        this.results = map;
        this.context = context;
        this.provenance = evaluationProvenance;
        this.modelGeneratesProbabilities = context.getModel().generatesProbabilities();
        this.cm = context.getCM();
    }

    /**
     * Returns the predictions this evaluation was computed from.
     *
     * @return the predictions held by the metric context
     */
    @Override
    public List<Prediction<Label>> getPredictions() {
        return this.context.getPredictions();
    }

    /**
     * Returns a read-only view of all computed metric values.
     *
     * @return an unmodifiable view of the metric map
     */
    @Override
    public Map<MetricID<Label>, Double> asMap() {
        return Collections.unmodifiableMap(this.results);
    }

    /**
     * Returns the averaged precision for a single label.
     *
     * @param label the label to query
     * @return the averaged precision score
     * @throws UnsupportedOperationException if the model does not generate probabilities
     */
    @Override // org.tribuo.classification.evaluation.LabelEvaluation
    public double averagedPrecision(Label label) {
        requireProbabilities("averaged precision score");
        return get(label, LabelMetrics.AVERAGED_PRECISION);
    }

    /**
     * Computes the precision-recall curve for a single label from the stored predictions.
     *
     * @param label the label to compute the curve for
     * @return the precision-recall curve
     * @throws UnsupportedOperationException if the model does not generate probabilities; the curve is
     *                                       built from prediction scores, as AUCROC and averaged precision are
     */
    @Override // org.tribuo.classification.evaluation.LabelEvaluation
    public LabelEvaluationUtil.PRCurve precisionRecallCurve(Label label) {
        requireProbabilities("precision recall curve");
        return LabelMetrics.precisionRecallCurve(label, this.context.getPredictions());
    }

    /**
     * Returns the area under the ROC curve for a single label.
     *
     * @param label the label to query
     * @return the AUCROC score
     * @throws UnsupportedOperationException if the model does not generate probabilities
     */
    @Override // org.tribuo.classification.evaluation.LabelEvaluation
    public double AUCROC(Label label) {
        requireProbabilities("AUCROC score");
        return get(label, LabelMetrics.AUCROC);
    }

    /**
     * Averages the per-label AUCROC over every label in the confusion matrix's domain.
     * <p>
     * The result is NaN if the total weight is zero (for example, an empty domain), since the sum is
     * divided by the total weight without a check.
     *
     * @param weighted if {@code true}, each label is weighted by its support in the confusion matrix;
     *                 otherwise every label has weight 1
     * @return the (optionally support-weighted) mean AUCROC
     * @throws UnsupportedOperationException if the model does not generate probabilities
     */
    @Override // org.tribuo.classification.evaluation.LabelEvaluation
    public double averageAUCROC(boolean weighted) {
        requireProbabilities("AUCROC score");
        double weightedSum = 0.0d;
        double totalWeight = 0.0d;
        for (Label label : this.cm.getDomain().getDomain()) {
            double score = get(label, LabelMetrics.AUCROC);
            double weight = weighted ? this.cm.support(label) : 1.0d;
            weightedSum += weight * score;
            totalWeight += weight;
        }
        return weightedSum / totalWeight;
    }

    /**
     * Returns the confusion matrix entry for a pair of labels.
     * <p>
     * The argument order is passed straight through to {@code ConfusionMatrix.confusion}.
     */
    @Override // org.tribuo.classification.evaluation.ClassifierEvaluation
    public double confusion(Label label, Label label2) {
        return this.cm.confusion(label, label2);
    }

    /** Returns the true-positive count for {@code label}. */
    @Override // org.tribuo.classification.evaluation.ClassifierEvaluation
    public double tp(Label label) {
        return get(label, LabelMetrics.TP);
    }

    /** Returns the micro-averaged true-positive value. */
    @Override // org.tribuo.classification.evaluation.ClassifierEvaluation
    public double tp() {
        return get(EvaluationMetric.Average.MICRO, LabelMetrics.TP);
    }

    /** Returns the macro-averaged true-positive value. */
    @Override // org.tribuo.classification.evaluation.ClassifierEvaluation
    public double macroTP() {
        return get(EvaluationMetric.Average.MACRO, LabelMetrics.TP);
    }

    /** Returns the false-positive count for {@code label}. */
    @Override // org.tribuo.classification.evaluation.ClassifierEvaluation
    public double fp(Label label) {
        return get(label, LabelMetrics.FP);
    }

    /** Returns the micro-averaged false-positive value. */
    @Override // org.tribuo.classification.evaluation.ClassifierEvaluation
    public double fp() {
        return get(EvaluationMetric.Average.MICRO, LabelMetrics.FP);
    }

    /** Returns the macro-averaged false-positive value. */
    @Override // org.tribuo.classification.evaluation.ClassifierEvaluation
    public double macroFP() {
        return get(EvaluationMetric.Average.MACRO, LabelMetrics.FP);
    }

    /** Returns the true-negative count for {@code label}. */
    @Override // org.tribuo.classification.evaluation.ClassifierEvaluation
    public double tn(Label label) {
        return get(label, LabelMetrics.TN);
    }

    /** Returns the micro-averaged true-negative value. */
    @Override // org.tribuo.classification.evaluation.ClassifierEvaluation
    public double tn() {
        return get(EvaluationMetric.Average.MICRO, LabelMetrics.TN);
    }

    /** Returns the macro-averaged true-negative value. */
    @Override // org.tribuo.classification.evaluation.ClassifierEvaluation
    public double macroTN() {
        return get(EvaluationMetric.Average.MACRO, LabelMetrics.TN);
    }

    /** Returns the false-negative count for {@code label}. */
    @Override // org.tribuo.classification.evaluation.ClassifierEvaluation
    public double fn(Label label) {
        return get(label, LabelMetrics.FN);
    }

    /** Returns the micro-averaged false-negative value. */
    @Override // org.tribuo.classification.evaluation.ClassifierEvaluation
    public double fn() {
        return get(EvaluationMetric.Average.MICRO, LabelMetrics.FN);
    }

    /** Returns the macro-averaged false-negative value. */
    @Override // org.tribuo.classification.evaluation.ClassifierEvaluation
    public double macroFN() {
        return get(EvaluationMetric.Average.MACRO, LabelMetrics.FN);
    }

    /** Returns the precision for {@code label}. */
    @Override // org.tribuo.classification.evaluation.ClassifierEvaluation
    public double precision(Label label) {
        return get(label, LabelMetrics.PRECISION);
    }

    /** Returns the micro-averaged precision. */
    @Override // org.tribuo.classification.evaluation.ClassifierEvaluation
    public double microAveragedPrecision() {
        return get(EvaluationMetric.Average.MICRO, LabelMetrics.PRECISION);
    }

    /** Returns the macro-averaged precision. */
    @Override // org.tribuo.classification.evaluation.ClassifierEvaluation
    public double macroAveragedPrecision() {
        return get(EvaluationMetric.Average.MACRO, LabelMetrics.PRECISION);
    }

    /** Returns the recall for {@code label}. */
    @Override // org.tribuo.classification.evaluation.ClassifierEvaluation
    public double recall(Label label) {
        return get(label, LabelMetrics.RECALL);
    }

    /** Returns the micro-averaged recall. */
    @Override // org.tribuo.classification.evaluation.ClassifierEvaluation
    public double microAveragedRecall() {
        return get(EvaluationMetric.Average.MICRO, LabelMetrics.RECALL);
    }

    /** Returns the macro-averaged recall. */
    @Override // org.tribuo.classification.evaluation.ClassifierEvaluation
    public double macroAveragedRecall() {
        return get(EvaluationMetric.Average.MACRO, LabelMetrics.RECALL);
    }

    /** Returns the F1 score for {@code label}. */
    @Override // org.tribuo.classification.evaluation.ClassifierEvaluation
    public double f1(Label label) {
        return get(label, LabelMetrics.F1);
    }

    /** Returns the micro-averaged F1 score. */
    @Override // org.tribuo.classification.evaluation.ClassifierEvaluation
    public double microAveragedF1() {
        return get(EvaluationMetric.Average.MICRO, LabelMetrics.F1);
    }

    /** Returns the macro-averaged F1 score. */
    @Override // org.tribuo.classification.evaluation.ClassifierEvaluation
    public double macroAveragedF1() {
        return get(EvaluationMetric.Average.MACRO, LabelMetrics.F1);
    }

    /** Returns the overall accuracy, stored under the micro-average target. */
    @Override // org.tribuo.classification.evaluation.LabelEvaluation
    public double accuracy() {
        return get(EvaluationMetric.Average.MICRO, LabelMetrics.ACCURACY);
    }

    /** Returns the accuracy for {@code label}. */
    @Override // org.tribuo.classification.evaluation.LabelEvaluation
    public double accuracy(Label label) {
        return get(label, LabelMetrics.ACCURACY);
    }

    /** Returns the balanced error rate, stored under the macro-average target. */
    @Override // org.tribuo.classification.evaluation.ClassifierEvaluation
    public double balancedErrorRate() {
        return get(EvaluationMetric.Average.MACRO, LabelMetrics.BALANCED_ERROR_RATE);
    }

    /** Returns the confusion matrix taken from the metric context at construction time. */
    @Override // org.tribuo.classification.evaluation.ClassifierEvaluation
    public ConfusionMatrix<Label> getConfusionMatrix() {
        return this.cm;
    }

    /**
     * Returns the provenance of this evaluation.
     *
     * @return the evaluation provenance supplied at construction time
     */
    @Override
    public EvaluationProvenance getProvenance() {
        return this.provenance;
    }

    /**
     * Decompiler-generated alias of {@link #getProvenance()}, kept only so that any source which
     * referenced this name still compiles.
     *
     * @return the evaluation provenance supplied at construction time
     * @deprecated use {@link #getProvenance()} instead
     */
    @Deprecated
    public EvaluationProvenance m24getProvenance() {
        return getProvenance();
    }

    /** Renders this evaluation using {@code LabelEvaluation.toFormattedString}. */
    @Override
    public String toString() {
        return LabelEvaluation.toFormattedString(this);
    }

    /**
     * Rejects score-based metrics when the evaluated model does not produce probabilities.
     *
     * @param metricName human-readable metric name used as the prefix of the exception message
     * @throws UnsupportedOperationException if the model does not generate probabilities
     */
    private void requireProbabilities(String metricName) {
        if (!this.modelGeneratesProbabilities) {
            throw new UnsupportedOperationException(metricName + " not available for models that do not generate probabilities");
        }
    }

    /** Looks up the stored value of {@code labelMetrics} for the given metric target. */
    private double get(MetricTarget<Label> metricTarget, LabelMetrics labelMetrics) {
        return get(labelMetrics.forTarget(metricTarget).getID());
    }

    /** Looks up the stored value of {@code labelMetrics} for a single label. */
    private double get(Label label, LabelMetrics labelMetrics) {
        return get(new MetricTarget<>(label), labelMetrics);
    }

    /** Looks up the stored value of {@code labelMetrics} for an averaging strategy. */
    private double get(EvaluationMetric.Average average, LabelMetrics labelMetrics) {
        return get(new MetricTarget<>(average), labelMetrics);
    }
}
