package com.gengoai.apollo.ml.evaluation;

import com.gengoai.apollo.math.linalg.NDArray;
import com.gengoai.apollo.ml.DataSet;
import com.gengoai.apollo.ml.Split;
import com.gengoai.apollo.ml.model.Model;
import com.gengoai.conversion.Cast;
import com.gengoai.math.Math2;
import com.gengoai.string.TableFormatter;
import java.io.PrintStream;
import java.util.Arrays;
import lombok.NonNull;
import org.apache.mahout.math.list.DoubleArrayList;

/* loaded from: input_file:com/gengoai/apollo/ml/evaluation/BinaryEvaluation.class */
/**
 * Evaluation for binary classifiers.
 * <p>
 * For every evaluated example this class records:
 * <ul>
 *   <li>a confusion-matrix count (true/false positives and negatives), where a gold label of {@code 1} is
 *   positive, {@code 0} is negative, and a prediction is positive when the arg-max of the prediction vector
 *   is {@code 1};</li>
 *   <li>the score at index {@code 1} of the prediction vector, grouped by gold label, which is later handed to
 *   {@link Math2#auc} to compute AUC.</li>
 * </ul>
 */
public class BinaryEvaluation extends ClassifierEvaluation {
    private static final long serialVersionUID = 1;
    /**
     * Index-1 prediction scores grouped by gold label: {@code prob[0]} for gold-negative examples,
     * {@code prob[1]} for gold-positive examples.
     */
    private final DoubleArrayList[] prob = {new DoubleArrayList(), new DoubleArrayList()};
    private double fn;
    private double fp;
    private double negative;
    private double positive;
    private double tn;
    private double tp;

    /**
     * Performs k-fold cross-validation. The dataset is shuffled and split into {@code nFolds} folds; for each fold
     * the (single, reused) model is re-estimated on the fold's training split and evaluated on its test split. All
     * folds are accumulated into one evaluation.
     *
     * @param dataSet    the data to cross-validate over
     * @param model      the model to estimate and evaluate (re-estimated once per fold)
     * @param nFolds     the number of folds passed to {@link Split#createFolds}
     * @param outputName the name of the model output holding the predictions
     * @return the accumulated evaluation over all folds
     * @throws NullPointerException if any reference argument is null
     */
    public static BinaryEvaluation crossvalidation(@NonNull DataSet dataSet,
                                                   @NonNull Model model,
                                                   int nFolds,
                                                   @NonNull String outputName) {
        if (dataSet == null) {
            throw new NullPointerException("dataSet is marked non-null but is null");
        }
        if (model == null) {
            throw new NullPointerException("model is marked non-null but is null");
        }
        if (outputName == null) {
            throw new NullPointerException("outputName is marked non-null but is null");
        }
        BinaryEvaluation evaluation = new BinaryEvaluation(outputName);
        for (Split split : Split.createFolds(dataSet.shuffle(), nFolds)) {
            model.estimate(split.train);
            evaluation.evaluate(model, split.test);
        }
        return evaluation;
    }

    /**
     * Creates a new evaluation for the given output and evaluates the model on the given data.
     *
     * @param model            the model to evaluate
     * @param testingData      the data to evaluate on
     * @param outputSourceName the name of the model output holding the predictions
     * @return the populated evaluation
     * @throws NullPointerException if any argument is null
     */
    public static BinaryEvaluation evaluate(@NonNull Model model,
                                            @NonNull DataSet testingData,
                                            @NonNull String outputSourceName) {
        if (model == null) {
            throw new NullPointerException("model is marked non-null but is null");
        }
        if (testingData == null) {
            throw new NullPointerException("testingData is marked non-null but is null");
        }
        if (outputSourceName == null) {
            throw new NullPointerException("outputSourceName is marked non-null but is null");
        }
        BinaryEvaluation evaluation = new BinaryEvaluation(outputSourceName);
        evaluation.evaluate(model, testingData);
        return evaluation;
    }

    /**
     * Creates an empty evaluation.
     *
     * @param outputName the name of the model output holding the predictions
     * @throws NullPointerException if {@code outputName} is null
     */
    public BinaryEvaluation(@NonNull String outputName) {
        super(outputName);
        if (outputName == null) {
            throw new NullPointerException("outputName is marked non-null but is null");
        }
    }

    /**
     * @return the fraction of examples classified correctly; {@code NaN} if nothing has been evaluated (0/0)
     */
    @Override
    public double accuracy() {
        return (this.tp + this.tn) / (this.positive + this.negative);
    }

    /**
     * @return the AUC computed by {@link Math2#auc} from the recorded index-1 scores of the gold-negative and
     * gold-positive examples
     */
    public double auc() {
        return Math2.auc(validScores(this.prob[0]), validScores(this.prob[1]));
    }

    /**
     * Copies only the populated prefix of a score list. {@code DoubleArrayList.elements()} returns the raw backing
     * array, which may contain unused capacity past {@code size()}; passing it directly would feed padding zeros
     * into the AUC computation.
     */
    private static double[] validScores(DoubleArrayList scores) {
        return Arrays.copyOf(scores.elements(), scores.size());
    }

    /**
     * @return the accuracy of always predicting the majority gold class; {@code NaN} if nothing has been evaluated
     */
    public double baseline() {
        return Math.max(this.positive, this.negative) / (this.positive + this.negative);
    }

    /**
     * Records a single example.
     *
     * @param gold      the gold label; truncated to an int, which must be {@code 0} (negative) or {@code 1} (positive)
     * @param predicted the prediction vector; its arg-max decides the predicted class and its value at index 1
     *                  is stored for AUC
     * @throws NullPointerException     if {@code predicted} is null
     * @throws IllegalArgumentException if the gold label is not 0 or 1
     */
    @Override
    public void entry(double gold, @NonNull NDArray predicted) {
        if (predicted == null) {
            throw new NullPointerException("predicted is marked non-null but is null");
        }
        int goldIndex = (int) gold;
        if (goldIndex != 0 && goldIndex != 1) {
            throw new IllegalArgumentException("Binary evaluation expects a gold label of 0 or 1, but got " + gold);
        }
        boolean predictedPositive = (int) predicted.argmax() == 1;
        // Score at index 1 of the prediction vector (the positive class); 1L keeps the original long overload.
        this.prob[goldIndex].add(predicted.get(1L));
        if (goldIndex == 1) {
            this.positive += 1.0d;
            if (predictedPositive) {
                this.tp += 1.0d;
            } else {
                this.fn += 1.0d;
            }
        } else {
            this.negative += 1.0d;
            if (predictedPositive) {
                this.fp += 1.0d;
            } else {
                this.tn += 1.0d;
            }
        }
    }

    /**
     * Evaluates the model on every example of the dataset. For each datum the gold value stored under
     * {@code outputName} is converted to an integer label via {@code getIntegerLabelFor}, and the model's
     * transformed output under the same name is read as an {@link NDArray}.
     *
     * @param model   the model to evaluate
     * @param dataset the data to evaluate on
     * @throws NullPointerException if either argument is null
     */
    @Override
    public void evaluate(@NonNull Model model, @NonNull DataSet dataset) {
        if (model == null) {
            throw new NullPointerException("model is marked non-null but is null");
        }
        if (dataset == null) {
            throw new NullPointerException("dataset is marked non-null but is null");
        }
        dataset.forEach(datum -> entry(getIntegerLabelFor(datum.get(this.outputName), dataset),
                                       model.transform(datum).get(this.outputName).asNDArray()));
    }

    @Override
    public double falseNegatives() {
        return this.fn;
    }

    @Override
    public double falsePositives() {
        return this.fp;
    }

    /**
     * Adds the scores and confusion-matrix counts of another binary evaluation into this one.
     *
     * @param classifierEvaluation the evaluation to merge; must be a {@code BinaryEvaluation}
     * @throws IllegalArgumentException if the argument is not a {@code BinaryEvaluation}
     */
    @Override
    public void merge(ClassifierEvaluation classifierEvaluation) {
        if (!(classifierEvaluation instanceof BinaryEvaluation)) {
            throw new IllegalArgumentException("Can only merge a BinaryEvaluation, but got "
                                                       + (classifierEvaluation == null
                                                          ? "null"
                                                          : classifierEvaluation.getClass().getName()));
        }
        BinaryEvaluation other = (BinaryEvaluation) Cast.as(classifierEvaluation);
        this.prob[0].addAllOf(other.prob[0]);
        this.prob[1].addAllOf(other.prob[1]);
        // Counts must be merged too, otherwise accuracy/rates/confusion matrix disagree with the merged AUC.
        this.tp += other.tp;
        this.tn += other.tn;
        this.fp += other.fp;
        this.fn += other.fn;
        this.positive += other.positive;
        this.negative += other.negative;
    }

    /**
     * Prints the evaluation as tables.
     *
     * @param printStream          where to print
     * @param printConfusionMatrix whether to print the confusion matrix (rows = predicted, columns = gold) before
     *                             the metric table
     */
    @Override
    public void output(PrintStream printStream, boolean printConfusionMatrix) {
        TableFormatter tableFormatter = new TableFormatter();
        if (printConfusionMatrix) {
            tableFormatter.header(Arrays.asList("Predicted / Gold", "TRUE", "FALSE", "TOTAL"));
            tableFormatter.content(Arrays.asList("TRUE",
                                                 truePositives(),
                                                 falsePositives(),
                                                 truePositives() + falsePositives()));
            tableFormatter.content(Arrays.asList("FALSE",
                                                 falseNegatives(),
                                                 trueNegatives(),
                                                 falseNegatives() + trueNegatives()));
            tableFormatter.footer(Arrays.asList("",
                                                truePositives() + falseNegatives(),
                                                falsePositives() + trueNegatives(),
                                                this.positive + this.negative));
            tableFormatter.print(printStream);
            tableFormatter = new TableFormatter();
        }
        tableFormatter.header(Arrays.asList("Metric", "Score"));
        tableFormatter.content(Arrays.asList("AUC", auc()));
        tableFormatter.content(Arrays.asList("Accuracy", accuracy()));
        tableFormatter.content(Arrays.asList("Baseline", baseline()));
        tableFormatter.content(Arrays.asList("TP Rate", truePositiveRate()));
        tableFormatter.content(Arrays.asList("FP Rate", falsePositiveRate()));
        tableFormatter.content(Arrays.asList("TN Rate", trueNegativeRate()));
        tableFormatter.content(Arrays.asList("FN Rate", falseNegativeRate()));
        tableFormatter.print(printStream);
    }

    @Override
    public double trueNegatives() {
        return this.tn;
    }

    @Override
    public double truePositives() {
        return this.tp;
    }
}
