package org.campagnelab.dl.framework.performance;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.commons.lang3.tuple.Pair;
import org.campagnelab.dl.framework.domains.prediction.TimeSeriesPrediction;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;

/**
 * Accumulates per-time-step predictions from {@link TimeSeriesPrediction}s and derives
 * multi-class metrics over a fixed list of labels: accuracy, macro-averaged precision and
 * recall, the F1 score of those two averages, and a confusion matrix.
 *
 * <p>Call {@link #addTimeSeries} for each series, then {@link #eval()} before querying any
 * metric; the query methods assert (when assertions are enabled) that {@code eval()} was
 * called. Instances hold unsynchronized mutable state.
 */
public class TimeSeriesPerformanceCalculator {
    /** Per-label count of time steps where the prediction equalled the true label. */
    private final Counter truePositives = new Counter();
    /** Per-label count of time steps where this label was predicted but was not the true label. */
    private final Counter falsePositives = new Counter();
    /** Per-label count of time steps where this label was the true label but was not predicted. */
    private final Counter falseNegatives = new Counter();
    /** Labels with no true or false positives (filled by {@link #eval()}). */
    private final Set<Integer> neverPredictedLabels = new HashSet<>();
    /** Labels with no true positives or false negatives (filled by {@link #eval()}). */
    private final Set<Integer> neverAppearedLabels = new HashSet<>();
    /** Per-label precision, only for labels that were predicted at least once. */
    private final Map<Integer, Double> labelPrecisions = new HashMap<>();
    /** Per-label recall, only for labels that appeared at least once. */
    private final Map<Integer, Double> labelRecalls = new HashMap<>();
    private final ConfusionMatrix confusionMatrix;
    /** Read-only copy of the labels this calculator reports on. */
    private final List<Integer> allLabels;
    private int correctPredictions;
    private int totalPredictions;
    private double mcPrecision;
    private double mcRecall;
    private double mcAccuracy;
    private double mcF1Score;
    private boolean evalCalled;

    /** Confusion counts indexed by true label, then predicted label. */
    private static final class ConfusionMatrix {
        private final Map<Integer, Counter> rows = new HashMap<>();
        private final List<Integer> allLabels;

        ConfusionMatrix(List<Integer> allLabels) {
            this.allLabels = allLabels;
            for (Integer label : allLabels) {
                rows.put(label, new Counter());
            }
        }

        /** Records one time step whose true label was {@code trueLabel}. */
        void increment(int trueLabel, int predictedLabel) {
            // computeIfAbsent: a true label outside allLabels previously caused a NullPointerException.
            rows.computeIfAbsent(trueLabel, key -> new Counter()).increment(predictedLabel);
        }

        /** Returns how often {@code trueLabel} was predicted as {@code predictedLabel}; 0 if never. */
        int count(int trueLabel, int predictedLabel) {
            Counter row = rows.get(trueLabel);
            return row == null ? 0 : row.count(predictedLabel);
        }

        /** Returns every (true, predicted) pair over {@code allLabels} mapped to its count. */
        Map<Pair<Integer, Integer>, Integer> toMap() {
            Map<Pair<Integer, Integer>, Integer> result = new HashMap<>();
            for (Integer trueLabel : allLabels) {
                for (Integer predictedLabel : allLabels) {
                    result.put(new ImmutablePair<>(trueLabel, predictedLabel),
                            count(trueLabel, predictedLabel));
                }
            }
            return result;
        }
    }

    /** Sparse per-label counter; labels never incremented report a count of zero. */
    private static final class Counter {
        private final Map<Integer, Integer> backingMap = new HashMap<>();

        void increment(int label) {
            backingMap.merge(label, 1, Integer::sum);
        }

        int count(int label) {
            return backingMap.getOrDefault(label, 0);
        }
    }

    /**
     * Creates a calculator that reports on the given labels.
     *
     * @param list labels to report per-label metrics and confusion counts for; copied, so later
     *     changes to the caller's list have no effect
     */
    public TimeSeriesPerformanceCalculator(List<Integer> list) {
        this.allLabels = Collections.unmodifiableList(new ArrayList<>(list));
        this.confusionMatrix = new ConfusionMatrix(this.allLabels);
    }

    /**
     * Creates a calculator for the labels {@code 0 .. i-1}.
     *
     * @param i number of labels
     */
    public TimeSeriesPerformanceCalculator(int i) {
        this(IntStream.range(0, i).boxed().collect(Collectors.toList()));
    }

    /**
     * Adds every time step of one series: each step updates the confusion matrix, the
     * accuracy counts, and the per-label true/false positive/negative counts.
     *
     * <p>Labels outside the configured list still count toward accuracy, but they are not
     * reported in per-label metrics or in the confusion map.
     *
     * @param timeSeriesPrediction series whose {@code trueLabels()} and {@code predictedLabels()}
     *     are compared index by index
     * @return this calculator, for chaining
     * @throws IllegalArgumentException if fewer predicted labels than true labels are supplied
     *     (checked before any counts change)
     */
    public TimeSeriesPerformanceCalculator addTimeSeries(TimeSeriesPrediction timeSeriesPrediction) {
        int[] trueLabels = timeSeriesPrediction.trueLabels();
        int[] predictedLabels = timeSeriesPrediction.predictedLabels();
        if (predictedLabels.length < trueLabels.length) {
            throw new IllegalArgumentException("Expected at least " + trueLabels.length
                    + " predicted labels but got " + predictedLabels.length);
        }
        for (int step = 0; step < trueLabels.length; step++) {
            int actual = trueLabels[step];
            int predicted = predictedLabels[step];
            totalPredictions++;
            confusionMatrix.increment(actual, predicted);
            if (actual == predicted) {
                truePositives.increment(actual);
                correctPredictions++;
            } else {
                falseNegatives.increment(actual);
                falsePositives.increment(predicted);
            }
        }
        return this;
    }

    /**
     * Computes all metrics from the counts accumulated so far. It is safe to call again after
     * adding more series, because derived state is rebuilt from scratch.
     *
     * <p>Precision is averaged only over labels that were ever predicted, and recall only over
     * labels that ever appeared. Either average is NaN when no such label exists. Accuracy is
     * NaN when nothing was added. F1 is 0 when precision and recall are both 0.
     *
     * @return this calculator, for chaining
     */
    public TimeSeriesPerformanceCalculator eval() {
        neverPredictedLabels.clear();
        neverAppearedLabels.clear();
        labelPrecisions.clear();
        labelRecalls.clear();
        for (Integer label : allLabels) {
            int tp = truePositives.count(label);
            int fp = falsePositives.count(label);
            int fn = falseNegatives.count(label);
            if (tp + fp == 0) {
                neverPredictedLabels.add(label);
            } else {
                // Floating-point division; integer division would truncate to 0 or 1.
                labelPrecisions.put(label, (double) tp / (tp + fp));
            }
            if (tp + fn == 0) {
                neverAppearedLabels.add(label);
            } else {
                labelRecalls.put(label, (double) tp / (tp + fn));
            }
        }
        mcPrecision = labelPrecisions.values().stream()
                .mapToDouble(Double::doubleValue).average().orElse(Double.NaN);
        mcRecall = labelRecalls.values().stream()
                .mapToDouble(Double::doubleValue).average().orElse(Double.NaN);
        // 0.0 / 0 yields NaN instead of the ArithmeticException integer division threw.
        mcAccuracy = (double) correctPredictions / totalPredictions;
        double precisionPlusRecall = mcPrecision + mcRecall;
        mcF1Score = precisionPlusRecall == 0.0
                ? 0.0
                : 2.0 * mcPrecision * mcRecall / precisionPlusRecall;
        evalCalled = true;
        return this;
    }

    /**
     * Returns a metric computed by the last {@link #eval()}.
     *
     * @param str one of {@code "accuracy"}, {@code "precision"}, {@code "recall"}, {@code "f1"}
     * @return the metric value
     * @throws UnsupportedOperationException for any other name
     */
    public double getMetric(String str) {
        assert evalCalled : "eval() should be called first";
        switch (str) {
            case "f1":
                return mcF1Score;
            case "accuracy":
                return mcAccuracy;
            case "precision":
                return mcPrecision;
            case "recall":
                return mcRecall;
            default:
                throw new UnsupportedOperationException("Unknown metric name: " + str);
        }
    }

    /**
     * Returns a read-only view of the labels that never appeared or were never predicted,
     * as determined by the last {@link #eval()}.
     *
     * @param str {@code "never_appeared"} or {@code "never_predicted"}
     * @return unmodifiable view of the requested label set
     * @throws UnsupportedOperationException for any other name
     */
    public Set<Integer> getNeverAppearedOrPredictedSet(String str) {
        assert evalCalled : "eval() should be called first";
        switch (str) {
            case "never_appeared":
                return Collections.unmodifiableSet(neverAppearedLabels);
            case "never_predicted":
                return Collections.unmodifiableSet(neverPredictedLabels);
            default:
                throw new UnsupportedOperationException("Unknown set name: " + str);
        }
    }

    /**
     * Returns the confusion matrix as a map from (true label, predicted label) to count,
     * with an entry for every pair of configured labels.
     */
    public Map<Pair<Integer, Integer>, Integer> getConfusionMatrix() {
        assert evalCalled : "eval() should be called first";
        return confusionMatrix.toMap();
    }

    /**
     * Returns how often true label {@code i} was predicted as {@code i2}.
     * Unlike the other queries, this does not require {@link #eval()}.
     */
    public int countConfusionMatrix(int i, int i2) {
        return confusionMatrix.count(i, i2);
    }

    /**
     * Formats the evaluated results as one tab-prefixed, tab-separated line with these fields:
     * accuracy, precision, recall, f1, the never-appeared set, the never-predicted set, then
     * every confusion count in row-major order (true label outer, predicted label inner).
     */
    public String evalString() {
        StringBuilder sb = new StringBuilder();
        for (String metric : new String[] {"accuracy", "precision", "recall", "f1"}) {
            sb.append(String.format("\t%f", getMetric(metric)));
        }
        sb.append(String.format("\t%s", getNeverAppearedOrPredictedSet("never_appeared")));
        sb.append(String.format("\t%s", getNeverAppearedOrPredictedSet("never_predicted")));
        for (int trueLabel : allLabels) {
            for (int predictedLabel : allLabels) {
                sb.append(String.format("\t%d", confusionMatrix.count(trueLabel, predictedLabel)));
            }
        }
        return sb.toString();
    }

    /**
     * Runs {@code computationGraph} over minibatches from {@code multiDataSetIterator} and
     * returns one metric computed on the results. The iterator is consumed from its current
     * position and is not reset.
     *
     * @param computationGraph graph to evaluate
     * @param multiDataSetIterator source of minibatches
     * @param numLabels number of labels (labels {@code 0 .. numLabels-1})
     * @param metricName metric to return, as accepted by {@link #getMetric(String)}
     * @param outputIndex index of the graph output and of the matching label array
     * @param scoreN stop after the first minibatch that brings the number of scored series
     *     above this value
     * @return the requested metric
     */
    public static double estimateFromGraph(ComputationGraph computationGraph,
            MultiDataSetIterator multiDataSetIterator, int numLabels, String metricName,
            int outputIndex, long scoreN) {
        TimeSeriesPerformanceCalculator calculator = new TimeSeriesPerformanceCalculator(numLabels);
        TimeSeriesPrediction prediction = new TimeSeriesPrediction();
        long scoredSeries = 0;
        while (multiDataSetIterator.hasNext()) {
            MultiDataSet dataSet = multiDataSetIterator.next();
            if (dataSet.hasMaskArrays()) {
                computationGraph.setLayerMaskArrays(
                        dataSet.getFeaturesMaskArrays(), dataSet.getLabelsMaskArrays());
            }
            INDArray output = computationGraph.output(dataSet.getFeatures())[outputIndex];
            INDArray labels = dataSet.getLabels(outputIndex);
            // Size the minibatch from the label array of this output. The previous code used the
            // feature array at the *output* index, which need not exist.
            int numExamples = labels.size(0);
            for (int example = 0; example < numExamples; example++) {
                // True labels come from the dataset and predictions from the network output.
                // They were previously swapped, which exchanged precision and recall.
                prediction.setTrueLabels(labels, example).setPredictedLabels(output, example);
                calculator.addTimeSeries(prediction);
                scoredSeries++;
            }
            if (scoredSeries > scoreN) {
                break;
            }
        }
        return calculator.eval().getMetric(metricName);
    }
}
