package org.tribuo.classification.sequence;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.logging.Logger;
import org.tribuo.Prediction;
import org.tribuo.classification.Label;
import org.tribuo.classification.evaluation.ConfusionMatrix;
import org.tribuo.classification.evaluation.LabelMetric;
import org.tribuo.classification.evaluation.LabelMetrics;
import org.tribuo.evaluation.metrics.EvaluationMetric;
import org.tribuo.evaluation.metrics.MetricID;
import org.tribuo.evaluation.metrics.MetricTarget;
import org.tribuo.provenance.EvaluationProvenance;
import org.tribuo.sequence.SequenceEvaluation;

/* loaded from: input_file:org/tribuo/classification/sequence/LabelSequenceEvaluation.class */
public class LabelSequenceEvaluation implements SequenceEvaluation<Label> {
    private static final Logger logger = Logger.getLogger(LabelSequenceEvaluation.class.getName());

    /** Metric values keyed by metric id; stored as passed to the constructor (not copied). */
    private final Map<MetricID<Label>, Double> results;
    /** Metric context the results were computed from; supplies the predictions and confusion matrix. */
    private final LabelMetric.Context ctx;
    /** Confusion matrix captured from {@link #ctx} at construction time. */
    private final ConfusionMatrix<Label> cm;
    /** Provenance of this evaluation. */
    private final EvaluationProvenance provenance;

    /**
     * Creates a sequence evaluation from precomputed metric values.
     * <p>
     * NOTE(review): the decompiler reports this constructor was originally {@code protected}; it is
     * kept {@code public} here so that any existing callers continue to compile.
     *
     * @param map                  the computed metric values; the map is retained, not copied
     * @param context              the metric context; its confusion matrix is captured eagerly
     * @param evaluationProvenance the provenance describing this evaluation
     */
    public LabelSequenceEvaluation(Map<MetricID<Label>, Double> map, LabelMetric.Context context, EvaluationProvenance evaluationProvenance) {
        this.results = map;
        this.ctx = context;
        this.cm = context.getCM();
        this.provenance = evaluationProvenance;
    }

    /**
     * Returns the predictions held by the metric context this evaluation was built from.
     *
     * @return the context's prediction list
     */
    public List<Prediction<Label>> getPredictions() {
        return this.ctx.getPredictions();
    }

    /**
     * Returns the confusion matrix backing the per-label counts.
     *
     * @return the confusion matrix
     */
    public ConfusionMatrix<Label> getConfusionMatrix() {
        return this.cm;
    }

    /**
     * Returns a read-only view of all computed metric values.
     *
     * @return an unmodifiable view of the metric map
     */
    public Map<MetricID<Label>, Double> asMap() {
        return Collections.unmodifiableMap(this.results);
    }

    /**
     * Returns the confusion-matrix entry for the given pair of labels.
     *
     * @param label  the first label, passed as the first argument to {@code ConfusionMatrix.confusion}
     * @param label2 the second label, passed as the second argument to {@code ConfusionMatrix.confusion}
     * @return the confusion-matrix value for the pair
     */
    public double confusion(Label label, Label label2) {
        return this.cm.confusion(label, label2);
    }

    /** Returns the {@link LabelMetrics#TP} value computed for {@code label}. */
    public double tp(Label label) {
        return get(label, LabelMetrics.TP);
    }

    /** Returns the micro-averaged {@link LabelMetrics#TP} value. */
    public double tp() {
        return get(EvaluationMetric.Average.MICRO, LabelMetrics.TP);
    }

    /** Returns the macro-averaged {@link LabelMetrics#TP} value. */
    public double macroTP() {
        return get(EvaluationMetric.Average.MACRO, LabelMetrics.TP);
    }

    /** Returns the {@link LabelMetrics#FP} value computed for {@code label}. */
    public double fp(Label label) {
        return get(label, LabelMetrics.FP);
    }

    /** Returns the micro-averaged {@link LabelMetrics#FP} value. */
    public double fp() {
        return get(EvaluationMetric.Average.MICRO, LabelMetrics.FP);
    }

    /** Returns the macro-averaged {@link LabelMetrics#FP} value. */
    public double macroFP() {
        return get(EvaluationMetric.Average.MACRO, LabelMetrics.FP);
    }

    /** Returns the {@link LabelMetrics#TN} value computed for {@code label}. */
    public double tn(Label label) {
        return get(label, LabelMetrics.TN);
    }

    /** Returns the micro-averaged {@link LabelMetrics#TN} value. */
    public double tn() {
        return get(EvaluationMetric.Average.MICRO, LabelMetrics.TN);
    }

    /** Returns the macro-averaged {@link LabelMetrics#TN} value. */
    public double macroTN() {
        return get(EvaluationMetric.Average.MACRO, LabelMetrics.TN);
    }

    /** Returns the {@link LabelMetrics#FN} value computed for {@code label}. */
    public double fn(Label label) {
        return get(label, LabelMetrics.FN);
    }

    /** Returns the micro-averaged {@link LabelMetrics#FN} value. */
    public double fn() {
        return get(EvaluationMetric.Average.MICRO, LabelMetrics.FN);
    }

    /** Returns the macro-averaged {@link LabelMetrics#FN} value. */
    public double macroFN() {
        return get(EvaluationMetric.Average.MACRO, LabelMetrics.FN);
    }

    /** Returns the {@link LabelMetrics#PRECISION} value computed for {@code label}. */
    public double precision(Label label) {
        return get(label, LabelMetrics.PRECISION);
    }

    /** Returns the micro-averaged {@link LabelMetrics#PRECISION} value. */
    public double microAveragedPrecision() {
        return get(EvaluationMetric.Average.MICRO, LabelMetrics.PRECISION);
    }

    /** Returns the macro-averaged {@link LabelMetrics#PRECISION} value. */
    public double macroAveragedPrecision() {
        return get(EvaluationMetric.Average.MACRO, LabelMetrics.PRECISION);
    }

    /** Returns the {@link LabelMetrics#RECALL} value computed for {@code label}. */
    public double recall(Label label) {
        return get(label, LabelMetrics.RECALL);
    }

    /** Returns the micro-averaged {@link LabelMetrics#RECALL} value. */
    public double microAveragedRecall() {
        return get(EvaluationMetric.Average.MICRO, LabelMetrics.RECALL);
    }

    /** Returns the macro-averaged {@link LabelMetrics#RECALL} value. */
    public double macroAveragedRecall() {
        return get(EvaluationMetric.Average.MACRO, LabelMetrics.RECALL);
    }

    /**
     * Returns the {@link LabelMetrics#F1} value computed for {@code label}.
     * <p>
     * (Previously this looked up {@link LabelMetrics#RECALL} by mistake.)
     */
    public double f1(Label label) {
        return get(label, LabelMetrics.F1);
    }

    /** Returns the micro-averaged {@link LabelMetrics#F1} value. */
    public double microAveragedF1() {
        return get(EvaluationMetric.Average.MICRO, LabelMetrics.F1);
    }

    /** Returns the macro-averaged {@link LabelMetrics#F1} value. */
    public double macroAveragedF1() {
        return get(EvaluationMetric.Average.MACRO, LabelMetrics.F1);
    }

    /** Returns the micro-averaged {@link LabelMetrics#ACCURACY} value. */
    public double accuracy() {
        return get(EvaluationMetric.Average.MICRO, LabelMetrics.ACCURACY);
    }

    /** Returns the {@link LabelMetrics#ACCURACY} value computed for {@code label}. */
    public double accuracy(Label label) {
        return get(label, LabelMetrics.ACCURACY);
    }

    /** Returns the macro-averaged {@link LabelMetrics#BALANCED_ERROR_RATE} value. */
    public double balancedErrorRate() {
        return get(EvaluationMetric.Average.MACRO, LabelMetrics.BALANCED_ERROR_RATE);
    }

    /**
     * Returns the provenance of this evaluation.
     *
     * @return the evaluation provenance supplied at construction
     */
    public EvaluationProvenance getProvenance() {
        return this.provenance;
    }

    /**
     * Decompiler-generated alias of {@link #getProvenance()}, kept for compatibility.
     *
     * @return the evaluation provenance supplied at construction
     * @deprecated use {@link #getProvenance()} instead.
     */
    @Deprecated
    public EvaluationProvenance m30getProvenance() {
        return getProvenance();
    }

    /**
     * Renders a per-label table followed by aggregate rows.
     * <p>
     * Per-label columns are support, tp, fn, fp, recall, precision and f1. Labels with zero support
     * are omitted. The aggregate rows are Total, Accuracy, Micro Average, Macro Average and
     * Balanced Error Rate.
     */
    @Override
    public String toString() {
        List<Label> labels = new ArrayList<>(this.cm.getDomain().getDomain());
        StringBuilder sb = new StringBuilder();

        // The row-header column must fit both the longest label name and the longest fixed row title.
        int headerWidth = "Balanced Error Rate".length();
        for (Label label : labels) {
            headerWidth = Math.max(headerWidth, label.getLabel().length());
        }
        String headerFormat = String.format("%%-%ds", headerWidth + 2);

        sb.append(String.format(headerFormat, "Class"));
        sb.append(String.format("%12s%12s%12s%12s", "n", "tp", "fn", "fp"));
        sb.append(String.format("%12s%12s%12s%n", "recall", "prec", "f1"));

        int totalSupport = 0;
        int totalTP = 0;
        int totalFN = 0;
        int totalFP = 0;
        for (Label label : labels) {
            double support = this.cm.support(label);
            if (support == 0.0d) {
                continue; // labels with zero support are not listed
            }
            double labelTP = this.cm.tp(label);
            double labelFN = this.cm.fn(label);
            double labelFP = this.cm.fp(label);
            totalSupport = (int) (totalSupport + support);
            totalTP = (int) (totalTP + labelTP);
            totalFN = (int) (totalFN + labelFN);
            totalFP = (int) (totalFP + labelFP);
            // Use the label name, matching how the header width was computed above.
            sb.append(String.format(headerFormat, label.getLabel()));
            sb.append(String.format("%,12d%,12d%,12d%,12d", (int) support, (int) labelTP, (int) labelFN, (int) labelFP));
            sb.append(String.format("%12.3f%12.3f%12.3f%n", recall(label), precision(label), f1(label)));
        }

        sb.append(String.format(headerFormat, "Total"));
        sb.append(String.format("%,12d%,12d%,12d%,12d%n", totalSupport, totalTP, totalFN, totalFP));
        sb.append(String.format(headerFormat, "Accuracy"));
        // Floating-point division: integer division truncated this to 0 for any imperfect result
        // and threw ArithmeticException when the total support was 0 (now prints NaN instead).
        sb.append(String.format("%60.3f%n", (double) totalTP / totalSupport));
        sb.append(String.format(headerFormat, "Micro Average"));
        sb.append(String.format("%60.3f%12.3f%12.3f%n", microAveragedRecall(), microAveragedPrecision(), microAveragedF1()));
        sb.append(String.format(headerFormat, "Macro Average"));
        sb.append(String.format("%60.3f%12.3f%12.3f%n", macroAveragedRecall(), macroAveragedPrecision(), macroAveragedF1()));
        sb.append(String.format(headerFormat, "Balanced Error Rate"));
        sb.append(String.format("%60.3f", balancedErrorRate()));
        return sb.toString();
    }

    /**
     * Looks up {@code labelMetrics} evaluated for {@code metricTarget}.
     * <p>
     * NOTE(review): the single-argument {@code get(MetricID)} is not declared in this class and is
     * presumably a default method inherited from {@code SequenceEvaluation}. Confirm how it handles
     * ids missing from {@link #results}.
     */
    private double get(MetricTarget<Label> metricTarget, LabelMetrics labelMetrics) {
        return get(labelMetrics.forTarget(metricTarget).getID());
    }

    /** Looks up {@code labelMetrics} evaluated for a single label. */
    private double get(Label label, LabelMetrics labelMetrics) {
        return get(new MetricTarget<>(label), labelMetrics);
    }

    /** Looks up {@code labelMetrics} under the given averaging scheme. */
    private double get(EvaluationMetric.Average average, LabelMetrics labelMetrics) {
        return get(new MetricTarget<>(average), labelMetrics);
    }
}
