package com.gengoai.apollo.ml.evaluation;

import com.gengoai.LogUtils;
import com.gengoai.apollo.math.linalg.NDArray;
import com.gengoai.apollo.ml.DataSet;
import com.gengoai.apollo.ml.observation.Observation;
import com.gengoai.io.resource.StringResource;
import java.io.IOException;
import java.io.PrintStream;
import java.io.Serializable;
import java.util.logging.Level;
import java.util.logging.Logger;
import lombok.NonNull;

/* loaded from: input_file:com/gengoai/apollo/ml/evaluation/ClassifierEvaluation.class */
/**
 * Base class for classifier evaluations.
 *
 * <p>Derives the standard binary-classification rates from four abstract counts:
 * {@link #truePositives()}, {@link #falsePositives()}, {@link #trueNegatives()} and
 * {@link #falseNegatives()}. Subclasses supply those counts, {@link #accuracy()}, how single
 * results are recorded ({@link #entry(double, NDArray)}), how two evaluations are combined
 * ({@link #merge(ClassifierEvaluation)}) and how a report is printed
 * ({@link #output(PrintStream, boolean)}).
 */
public abstract class ClassifierEvaluation implements Evaluation, Serializable {
    // NOTE(review): this private logger is not referenced anywhere in this class.
    private static final Logger log = Logger.getLogger(ClassifierEvaluation.class.getName());
    private static final long serialVersionUID = 1;
    /** Name of the output whose dataset metadata/encoder is used to map variable labels to ints. */
    protected final String outputName;

    /**
     * Creates an evaluation bound to the given output name.
     *
     * @param outputName the output name used by {@link #getIntegerLabelFor(Observation, DataSet)}
     *                   to look up the label encoder
     */
    // NOTE(review): the decompiler reports this constructor was originally protected.
    public ClassifierEvaluation(String outputName) {
        this.outputName = outputName;
    }

    /**
     * Returns the accuracy as computed by the concrete subclass.
     *
     * @return the accuracy
     */
    public abstract double accuracy();

    /**
     * Returns {@code positiveLikelihoodRatio() / negativeLikelihoodRatio()}.
     *
     * <p>No zero-denominator guard is applied, so the result may be infinite or NaN.
     *
     * @return the diagnostic odds ratio
     */
    public double diagnosticOddsRatio() {
        return positiveLikelihoodRatio() / negativeLikelihoodRatio();
    }

    /**
     * Records a single evaluation entry. The semantics are defined by subclasses.
     *
     * @param d       NOTE(review): presumably the gold label as a double. Confirm in subclasses.
     * @param nDArray NOTE(review): presumably the model's predicted output. Confirm in subclasses.
     */
    public abstract void entry(double d, @NonNull NDArray nDArray);

    /**
     * Returns {@code FN / (FN + TP)}.
     *
     * @return the false negative rate, or {@code 0.0} when {@code FN + TP == 0}
     */
    public double falseNegativeRate() {
        double falseNegatives = falseNegatives();
        double truePositives = truePositives();
        if (truePositives + falseNegatives == 0.0d) {
            return 0.0d;
        }
        return falseNegatives / (falseNegatives + truePositives);
    }

    /**
     * Returns the false-negative count supplied by the subclass.
     *
     * @return the number of false negatives
     */
    public abstract double falseNegatives();

    /**
     * Returns {@code FN / (FN + TN)}.
     *
     * @return the false omission rate, or {@code 0.0} when {@code FN + TN == 0}
     */
    public double falseOmissionRate() {
        double falseNegatives = falseNegatives();
        double trueNegatives = trueNegatives();
        if (trueNegatives + falseNegatives == 0.0d) {
            return 0.0d;
        }
        return falseNegatives / (falseNegatives + trueNegatives);
    }

    /**
     * Returns {@code FP / (TN + FP)}.
     *
     * @return the false positive rate, or {@code 0.0} when {@code TN + FP == 0}
     */
    public double falsePositiveRate() {
        double trueNegatives = trueNegatives();
        double falsePositives = falsePositives();
        if (trueNegatives + falsePositives == 0.0d) {
            return 0.0d;
        }
        return falsePositives / (trueNegatives + falsePositives);
    }

    /**
     * Returns the false-positive count supplied by the subclass.
     *
     * @return the number of false positives
     */
    public abstract double falsePositives();

    /**
     * Converts an observed label into an integer class index.
     *
     * <ul>
     *   <li>NDArray or classification observations: a scalar array yields its element at index 0
     *       cast to {@code int}; any other shape yields the result of {@code argmax()} cast to
     *       {@code int}.</li>
     *   <li>Variable observations: the variable's name, encoded with the encoder from
     *       {@code dataSet}'s metadata for {@link #outputName}.</li>
     * </ul>
     *
     * @param observation the observation to convert
     * @param dataSet     the dataset whose metadata is consulted for variable observations
     * @return the integer label
     * @throws IllegalArgumentException if the observation is not an NDArray, classification or
     *                                  variable
     */
    // NOTE(review): the decompiler reports this method was originally protected.
    public int getIntegerLabelFor(Observation observation, DataSet dataSet) {
        if (observation.isNDArray() || observation.isClassification()) {
            NDArray asNDArray = observation.asNDArray();
            return asNDArray.shape().isScalar() ? (int) asNDArray.get(0L) : (int) asNDArray.argmax();
        }
        if (observation.isVariable()) {
            return dataSet.getMetadata(this.outputName).getEncoder().encode(observation.asVariable().getName());
        }
        throw new IllegalArgumentException("Unable to process output of type: " + observation.getClass());
    }

    /**
     * Merges another evaluation into this one. The semantics are defined by subclasses.
     *
     * @param classifierEvaluation the evaluation to merge
     */
    public abstract void merge(ClassifierEvaluation classifierEvaluation);

    /**
     * Returns {@code falseNegativeRate() / specificity()}.
     *
     * <p>No zero-denominator guard is applied. {@link #specificity()} is {@code 0.0} when
     * {@code TN == 0} and {@code FP > 0}, in which case the result is infinite or NaN.
     *
     * @return the negative likelihood ratio
     */
    public double negativeLikelihoodRatio() {
        return falseNegativeRate() / specificity();
    }

    /**
     * Returns {@code TN / (TN + FN)}.
     *
     * @return the negative predictive value, or {@code 0.0} when {@code TN + FN == 0}
     */
    public double negativePredictiveValue() {
        double trueNegatives = trueNegatives();
        double falseNegatives = falseNegatives();
        if (trueNegatives + falseNegatives == 0.0d) {
            return 0.0d;
        }
        return trueNegatives / (trueNegatives + falseNegatives);
    }

    /**
     * Writes the report to the given stream. Equivalent to {@code output(printStream, false)}.
     *
     * @param printStream the destination stream
     * @throws NullPointerException if {@code printStream} is null
     */
    @Override
    public void output(@NonNull PrintStream printStream) {
        if (printStream == null) {
            throw new NullPointerException("printStream is marked non-null but is null");
        }
        output(printStream, false);
    }

    /**
     * Writes the report to the given stream.
     *
     * @param printStream the destination stream
     * @param z           subclass-defined flag. NOTE(review): presumably toggles additional
     *                    detail in the report. Confirm in subclasses.
     */
    public abstract void output(@NonNull PrintStream printStream, boolean z);

    /**
     * Writes the report to {@link System#out}.
     *
     * @param z flag forwarded unchanged to {@link #output(PrintStream, boolean)}
     */
    public void output(boolean z) {
        output(System.out, z);
    }

    /** Writes the report to {@link System#out}. Equivalent to {@code output(System.out, false)}. */
    @Override
    public void output() {
        output(System.out, false);
    }

    /**
     * Returns {@code truePositiveRate() / falsePositiveRate()}.
     *
     * <p>{@link #falsePositiveRate()} returns {@code 0.0} when there are no false positives or no
     * negatives, in which case the result is infinite (or NaN if the true positive rate is also 0).
     *
     * @return the positive likelihood ratio
     */
    public double positiveLikelihoodRatio() {
        return truePositiveRate() / falsePositiveRate();
    }

    /**
     * Renders this evaluation into an in-memory buffer and logs the text.
     *
     * <p>The report is rendered via {@link #output(PrintStream, boolean)} and logged through
     * {@code LogUtils.log} with the given logger and level.
     *
     * @param logger the logger to write to
     * @param level  the level at which to log
     * @param z      flag forwarded unchanged to {@link #output(PrintStream, boolean)}
     * @throws NullPointerException if {@code logger} or {@code level} is null
     * @throws RuntimeException     wrapping any {@link IOException} raised while buffering or
     *                              reading back the report
     */
    public void report(@NonNull Logger logger, @NonNull Level level, boolean z) {
        if (logger == null) {
            throw new NullPointerException("logger is marked non-null but is null");
        }
        if (level == null) {
            throw new NullPointerException("logLevel is marked non-null but is null");
        }
        StringResource stringResource = new StringResource();
        // try-with-resources closes (and so flushes) the stream even if output() throws. The
        // stream must be closed before the buffered text is read back below.
        try (PrintStream printStream = new PrintStream(stringResource.outputStream())) {
            output(printStream, z);
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
        try {
            LogUtils.log(logger, level, stringResource.readToString(), new Object[0]);
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Returns {@code TP / (TP + FN)}.
     *
     * @return the sensitivity (recall), or {@code 0.0} when {@code TP + FN == 0}
     */
    public double sensitivity() {
        double truePositives = truePositives();
        double falseNegatives = falseNegatives();
        if (truePositives + falseNegatives == 0.0d) {
            return 0.0d;
        }
        return truePositives / (truePositives + falseNegatives);
    }

    /**
     * Returns {@code TN / (TN + FP)}.
     *
     * @return the specificity, or {@code 1.0} when {@code TN + FP == 0}
     */
    // NOTE(review): unlike the other rates, this returns 1.0 rather than 0.0 when there are no
    // negatives. Confirm that this is intentional.
    public double specificity() {
        double trueNegatives = trueNegatives();
        double falsePositives = falsePositives();
        if (trueNegatives + falsePositives == 0.0d) {
            return 1.0d;
        }
        return trueNegatives / (trueNegatives + falsePositives);
    }

    /**
     * Alias for {@link #specificity()}.
     *
     * @return the true negative rate
     */
    public double trueNegativeRate() {
        return specificity();
    }

    /**
     * Returns the true-negative count supplied by the subclass.
     *
     * @return the number of true negatives
     */
    public abstract double trueNegatives();

    /**
     * Alias for {@link #sensitivity()}.
     *
     * @return the true positive rate
     */
    public double truePositiveRate() {
        return sensitivity();
    }

    /**
     * Returns the true-positive count supplied by the subclass.
     *
     * @return the number of true positives
     */
    public abstract double truePositives();
}
