package com.gengoai.apollo.ml.evaluation;

import com.gengoai.Validation;
import com.gengoai.apollo.math.linalg.NDArray;
import com.gengoai.apollo.ml.DataSet;
import com.gengoai.apollo.ml.Datum;
import com.gengoai.apollo.ml.Split;
import com.gengoai.apollo.ml.encoder.Encoder;
import com.gengoai.apollo.ml.model.Model;
import com.gengoai.apollo.ml.observation.Observation;
import com.gengoai.collection.counter.Counter;
import com.gengoai.collection.counter.Counters;
import com.gengoai.collection.counter.MultiCounter;
import com.gengoai.collection.counter.MultiCounters;
import com.gengoai.conversion.Cast;
import com.gengoai.string.TableFormatter;
import java.io.PrintStream;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeSet;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import lombok.NonNull;

/* loaded from: input_file:com/gengoai/apollo/ml/evaluation/MultiClassEvaluation.class */
public class MultiClassEvaluation extends ClassifierEvaluation {
    private static final long serialVersionUID = 1;
    private final MultiCounter<String, String> confusionMatrix;
    private Encoder encoder;
    private double total;

    public static MultiClassEvaluation crossvalidation(DataSet dataSet, Model model, int i, String str) {
        MultiClassEvaluation multiClassEvaluation = new MultiClassEvaluation(str);
        for (Split split : Split.createFolds(dataSet.shuffle(), i)) {
            model.estimate(split.train);
            multiClassEvaluation.evaluate(model, split.test);
        }
        return multiClassEvaluation;
    }

    public static MultiClassEvaluation evaluate(@NonNull Model model, @NonNull DataSet dataSet, @NonNull String str) {
        if (model == null) {
            throw new NullPointerException("model is marked non-null but is null");
        }
        if (dataSet == null) {
            throw new NullPointerException("testingData is marked non-null but is null");
        }
        if (str == null) {
            throw new NullPointerException("outputSourceName is marked non-null but is null");
        }
        MultiClassEvaluation multiClassEvaluation = new MultiClassEvaluation(str);
        multiClassEvaluation.evaluate(model, dataSet);
        return multiClassEvaluation;
    }

    public MultiClassEvaluation(@NonNull String str) {
        super(str);
        this.confusionMatrix = MultiCounters.newMultiCounter(new Map.Entry[0]);
        this.total = 0.0d;
        if (str == null) {
            throw new NullPointerException("outputName is marked non-null but is null");
        }
        this.encoder = null;
    }

    public MultiClassEvaluation(@NonNull String str, @NonNull Encoder encoder) {
        super(str);
        this.confusionMatrix = MultiCounters.newMultiCounter(new Map.Entry[0]);
        this.total = 0.0d;
        if (str == null) {
            throw new NullPointerException("outputName is marked non-null but is null");
        }
        if (encoder == null) {
            throw new NullPointerException("encoder is marked non-null but is null");
        }
        this.encoder = encoder;
    }

    @Override // com.gengoai.apollo.ml.evaluation.ClassifierEvaluation
    public double accuracy() {
        return this.confusionMatrix.firstKeys().stream().mapToDouble(str -> {
            return this.confusionMatrix.get(str, str);
        }).sum() / this.total;
    }

    public double diagnosticOddsRatio(String str) {
        return positiveLikelihoodRatio(str) / negativeLikelihoodRatio(str);
    }

    @Override // com.gengoai.apollo.ml.evaluation.ClassifierEvaluation
    public void entry(double d, @NonNull NDArray nDArray) {
        String decode;
        String decode2;
        if (nDArray == null) {
            throw new NullPointerException("predicted is marked non-null but is null");
        }
        if (this.encoder == null) {
            decode = Integer.toString((int) d);
            decode2 = Long.toString(nDArray.argmax());
        } else {
            decode = this.encoder.decode(d);
            decode2 = this.encoder.decode(nDArray.argmax());
        }
        this.confusionMatrix.increment(decode, decode2);
    }

    public void entry(@NonNull String str, @NonNull String str2) {
        if (str == null) {
            throw new NullPointerException("gold is marked non-null but is null");
        }
        if (str2 == null) {
            throw new NullPointerException("predicted is marked non-null but is null");
        }
        this.confusionMatrix.increment(str, str2);
        this.total += 1.0d;
    }

    @Override // com.gengoai.apollo.ml.evaluation.Evaluation
    public void evaluate(@NonNull Model model, @NonNull DataSet dataSet) {
        if (model == null) {
            throw new NullPointerException("model is marked non-null but is null");
        }
        if (dataSet == null) {
            throw new NullPointerException("dataset is marked non-null but is null");
        }
        Iterator<Datum> it = dataSet.iterator();
        while (it.hasNext()) {
            Datum next = it.next();
            Observation observation = next.get(this.outputName);
            Observation observation2 = model.transform(next).get(this.outputName);
            if (observation2.isClassification()) {
                entry(observation.asVariable().getName(), observation2.asClassification().getResult());
            } else {
                entry(getIntegerLabelFor(observation, dataSet), observation2.asNDArray());
            }
        }
    }

    private double f1(double d, double d2) {
        if (d + d2 == 0.0d) {
            return 0.0d;
        }
        return ((2.0d * d) * d2) / (d + d2);
    }

    public double f1(String str) {
        return f1(precision(str), recall(str));
    }

    public Counter<String> f1PerClass() {
        Counter<String> newCounter = Counters.newCounter(new String[0]);
        Counter<String> precisionPerClass = precisionPerClass();
        Counter<String> recallPerClass = recallPerClass();
        this.confusionMatrix.firstKeys().forEach(str -> {
            newCounter.set(str, f1(precisionPerClass.get(str), recallPerClass.get(str)));
        });
        return newCounter;
    }

    public double falseNegativeRate(String str) {
        double truePositives = truePositives(str);
        double falseNegatives = falseNegatives(str);
        if (truePositives + falseNegatives == 0.0d) {
            return 0.0d;
        }
        return falseNegatives / (truePositives + falseNegatives);
    }

    @Override // com.gengoai.apollo.ml.evaluation.ClassifierEvaluation
    public double falseNegatives() {
        return this.confusionMatrix.firstKeys().stream().mapToDouble(str -> {
            return this.confusionMatrix.get(str).sum() - this.confusionMatrix.get(str, str);
        }).sum();
    }

    public double falseNegatives(String str) {
        return this.confusionMatrix.get(str).sum() - this.confusionMatrix.get(str, str);
    }

    public double falsePositiveRate(String str) {
        double truePositives = truePositives(str);
        double falseNegatives = falseNegatives(str);
        if (truePositives + falseNegatives == 0.0d) {
            return 0.0d;
        }
        return falseNegatives / (truePositives + falseNegatives);
    }

    @Override // com.gengoai.apollo.ml.evaluation.ClassifierEvaluation
    public double falsePositives() {
        return this.confusionMatrix.firstKeys().stream().mapToDouble(str -> {
            double d = 0.0d;
            for (String str : this.confusionMatrix.firstKeys()) {
                if (!str.equals(str)) {
                    d += this.confusionMatrix.get(str, str);
                }
            }
            return d;
        }).sum();
    }

    public double falsePositives(String str) {
        double d = 0.0d;
        for (String str2 : this.confusionMatrix.firstKeys()) {
            if (!str2.equals(str)) {
                d += this.confusionMatrix.get(str2, str);
            }
        }
        return d;
    }

    public double macroF1() {
        return f1(macroPrecision(), macroRecall());
    }

    public double macroPrecision() {
        return precisionPerClass().average();
    }

    public double macroRecall() {
        return recallPerClass().average();
    }

    @Override // com.gengoai.apollo.ml.evaluation.ClassifierEvaluation
    public void merge(ClassifierEvaluation classifierEvaluation) {
        Validation.checkArgument(classifierEvaluation instanceof MultiClassEvaluation, "Can only merge with other ClassifierEvaluation.");
        MultiClassEvaluation multiClassEvaluation = (MultiClassEvaluation) Cast.as(classifierEvaluation);
        this.confusionMatrix.merge(multiClassEvaluation.confusionMatrix);
        this.total += multiClassEvaluation.total;
    }

    public double microF1() {
        return f1(microPrecision(), microRecall());
    }

    public double microPrecision() {
        double truePositives = truePositives();
        double falsePositives = falsePositives();
        if (truePositives + falsePositives == 0.0d) {
            return 1.0d;
        }
        return truePositives / (truePositives + falsePositives);
    }

    public double microRecall() {
        double truePositives = truePositives();
        double falseNegatives = falseNegatives();
        if (truePositives + falseNegatives == 0.0d) {
            return 1.0d;
        }
        return truePositives / (truePositives + falseNegatives);
    }

    public double negativeLikelihoodRatio(String str) {
        return falseNegativeRate(str) / specificity(str);
    }

    @Override // com.gengoai.apollo.ml.evaluation.ClassifierEvaluation
    public void output(PrintStream printStream, boolean z) {
        Set set = (Set) this.confusionMatrix.entries().stream().flatMap(tuple3 -> {
            return Stream.of((Object[]) new String[]{(String) tuple3.v1, (String) tuple3.v2});
        }).collect(Collectors.toCollection(TreeSet::new));
        TreeSet treeSet = new TreeSet(this.confusionMatrix.firstKeys());
        TableFormatter tableFormatter = new TableFormatter();
        if (z) {
            tableFormatter.title("Confusion Matrix");
            tableFormatter.header(Collections.singleton(""));
            tableFormatter.header(set);
            tableFormatter.header(Collections.singleton("Total"));
            treeSet.forEach(str -> {
                ArrayList arrayList = new ArrayList();
                arrayList.add(str);
                set.forEach(str -> {
                    arrayList.add(Long.valueOf((long) this.confusionMatrix.get(str, str)));
                });
                arrayList.add(Long.valueOf((long) this.confusionMatrix.get(str).sum()));
                tableFormatter.content(arrayList);
            });
            ArrayList arrayList = new ArrayList();
            arrayList.add("Total");
            set.forEach(str2 -> {
                arrayList.add(Long.valueOf((long) this.confusionMatrix.firstKeys().stream().mapToDouble(str2 -> {
                    return this.confusionMatrix.get(str2, str2);
                }).sum()));
            });
            arrayList.add(Long.valueOf((long) this.confusionMatrix.sum()));
            tableFormatter.content(arrayList);
            tableFormatter.print(printStream);
            printStream.println();
        }
        tableFormatter.clear();
        tableFormatter.title("Classification Metrics").header(Arrays.asList("", "Precision", "Recall", "F1-Measure", "Correct", "Incorrect", "Missed", "Total"));
        treeSet.forEach(str3 -> {
            tableFormatter.content(Arrays.asList(str3, Double.valueOf(precision(str3)), Double.valueOf(recall(str3)), Double.valueOf(f1(str3)), Long.valueOf((long) truePositives(str3)), Long.valueOf((long) falsePositives(str3)), Long.valueOf((long) falseNegatives(str3)), Long.valueOf((long) this.confusionMatrix.get(str3).sum())));
        });
        tableFormatter.footer(Arrays.asList("micro", Double.valueOf(microPrecision()), Double.valueOf(microRecall()), Double.valueOf(microF1()), Long.valueOf((long) truePositives()), Long.valueOf((long) falsePositives()), Long.valueOf((long) falseNegatives()), Long.valueOf((long) this.total)));
        tableFormatter.footer(Arrays.asList("macro", Double.valueOf(macroPrecision()), Double.valueOf(macroRecall()), Double.valueOf(macroF1()), "-", "-", "-", "-"));
        tableFormatter.print(printStream);
    }

    public double positiveLikelihoodRatio(String str) {
        return truePositiveRate(str) / falsePositiveRate(str);
    }

    public double precision(String str) {
        double truePositives = truePositives(str);
        double falsePositives = falsePositives(str);
        if (truePositives + falsePositives == 0.0d) {
            return 1.0d;
        }
        return truePositives / (truePositives + falsePositives);
    }

    public Counter<String> precisionPerClass() {
        Counter<String> newCounter = Counters.newCounter(new String[0]);
        this.confusionMatrix.firstKeys().forEach(str -> {
            newCounter.set(str, precision(str));
        });
        return newCounter;
    }

    public double recall(String str) {
        double truePositives = truePositives(str);
        double falseNegatives = falseNegatives(str);
        if (truePositives + falseNegatives == 0.0d) {
            return 1.0d;
        }
        return truePositives / (truePositives + falseNegatives);
    }

    public Counter<String> recallPerClass() {
        Counter<String> newCounter = Counters.newCounter(new String[0]);
        this.confusionMatrix.firstKeys().forEach(str -> {
            newCounter.set(str, recall(str));
        });
        return newCounter;
    }

    public double sensitivity(String str) {
        return recall(str);
    }

    public double specificity(String str) {
        double trueNegatives = trueNegatives(str);
        double falsePositives = falsePositives(str);
        if (trueNegatives + falsePositives == 0.0d) {
            return 1.0d;
        }
        return trueNegatives / (trueNegatives + falsePositives);
    }

    public double trueNegativeRate(String str) {
        return specificity(str);
    }

    @Override // com.gengoai.apollo.ml.evaluation.ClassifierEvaluation
    public double trueNegatives() {
        return this.confusionMatrix.firstKeys().stream().mapToDouble(str -> {
            double d = 0.0d;
            for (String str : this.confusionMatrix.firstKeys()) {
                if (!str.equals(str)) {
                    d += this.confusionMatrix.get(str).sum() - this.confusionMatrix.get(str, str);
                }
            }
            return d;
        }).sum();
    }

    public double trueNegatives(String str) {
        double d = 0.0d;
        for (String str2 : this.confusionMatrix.firstKeys()) {
            if (!str2.equals(str)) {
                d += this.confusionMatrix.get(str2).sum() - this.confusionMatrix.get(str2, str);
            }
        }
        return d;
    }

    public double truePositiveRate(String str) {
        return recall(str);
    }

    @Override // com.gengoai.apollo.ml.evaluation.ClassifierEvaluation
    public double truePositives() {
        return this.confusionMatrix.firstKeys().stream().mapToDouble(str -> {
            return this.confusionMatrix.get(str, str);
        }).sum();
    }

    public double truePositives(String str) {
        return this.confusionMatrix.get(str, str);
    }
}
