package org.campagnelab.dl.framework.performance;

import java.util.function.Consumer;
import java.util.function.Predicate;
import org.campagnelab.dl.framework.domains.prediction.BinaryClassPrediction;
import org.campagnelab.dl.framework.domains.prediction.PredictionInterpreter;
import org.deeplearning4j.datasets.iterator.impl.MultiDataSetIteratorAdapter;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;

/* loaded from: input_file:org/campagnelab/dl/framework/performance/AUCHelper.class */
/**
 * Helper that feeds model predictions into an {@link AreaUnderTheROCCurve} accumulator and returns
 * the statistic it evaluates, for either a {@link MultiLayerNetwork} or a {@link ComputationGraph}.
 */
public class AUCHelper {
    /**
     * Estimates the AUC statistic for {@code model}, dispatching on the concrete model type.
     *
     * <p>A {@link MultiLayerNetwork} is evaluated with {@link #estimateWithNet}; on that path
     * {@code predictionInterpreter} is not used. A {@link ComputationGraph} is evaluated with
     * {@link #estimateWithGraph}, wrapping the iterator in a {@link MultiDataSetIteratorAdapter}
     * and reading labels from label array 0.
     *
     * @param dataSetIterator source of the examples to evaluate
     * @param model the network to evaluate
     * @param i value passed to the {@link AreaUnderTheROCCurve} constructor
     *     (presumably a sample-size/capacity setting; confirm against that class)
     * @param consumer called once for every prediction produced
     * @param predicate tested after each mini-batch with the number of examples processed so far;
     *     evaluation stops as soon as it returns {@code true}
     * @param predictionInterpreter converts graph outputs into predictions (graph path only)
     * @return the value of {@link AreaUnderTheROCCurve#evaluateStatistic()}
     * @throws IllegalArgumentException if {@code model} is neither a {@link MultiLayerNetwork}
     *     nor a {@link ComputationGraph} (including {@code null})
     */
    public double estimate(DataSetIterator dataSetIterator, Model model, int i, Consumer<BinaryClassPrediction> consumer, Predicate<Integer> predicate, PredictionInterpreter predictionInterpreter) {
        if (model instanceof MultiLayerNetwork) {
            return estimateWithNet(dataSetIterator, (MultiLayerNetwork) model, i, consumer, predicate);
        }
        if (model instanceof ComputationGraph) {
            return estimateWithGraph(new MultiDataSetIteratorAdapter(dataSetIterator), (ComputationGraph) model, i, consumer, predicate, 0, predictionInterpreter);
        }
        // IllegalArgumentException is a RuntimeException, so existing catch sites keep working.
        throw new IllegalArgumentException("model type not recognized: " + (model == null ? "null" : model.getClass().getName()));
    }

    /**
     * Estimates the AUC statistic for a {@link MultiLayerNetwork}.
     *
     * <p>For every example, column 0 of the network output is stored as {@code predictedLabelNo},
     * column 1 as {@code predictedLabelYes}, and column 1 of the labels as {@code trueLabelYes}.
     *
     * <p>Note: a single {@link BinaryClassPrediction} instance is reused and overwritten for every
     * example, so {@code consumer} must copy any values it wants to keep beyond the call.
     *
     * @param dataSetIterator source of the examples to evaluate
     * @param multiLayerNetwork the network to evaluate
     * @param i value passed to the {@link AreaUnderTheROCCurve} constructor
     * @param consumer called once for every example with the (reused) prediction object
     * @param predicate tested after each mini-batch with the number of examples processed so far;
     *     evaluation stops as soon as it returns {@code true}
     * @return the value of {@link AreaUnderTheROCCurve#evaluateStatistic()}
     */
    public double estimateWithNet(DataSetIterator dataSetIterator, MultiLayerNetwork multiLayerNetwork, int i, Consumer<BinaryClassPrediction> consumer, Predicate<Integer> predicate) {
        AreaUnderTheROCCurve auc = new AreaUnderTheROCCurve(i);
        int nextIndex = 0;
        int processed = 0;
        BinaryClassPrediction prediction = new BinaryClassPrediction();
        while (dataSetIterator.hasNext()) {
            DataSet dataSet = (DataSet) dataSetIterator.next();
            INDArray output = multiLayerNetwork.output(dataSet.getFeatures());
            INDArray labels = dataSet.getLabels();
            int numExamples = dataSet.numExamples();
            for (int row = 0; row < numExamples; row++) {
                prediction.trueLabelYes = Double.valueOf(labels.getDouble(row, 1));
                prediction.predictedLabelNo = (float) output.getDouble(row, 0);
                prediction.predictedLabelYes = (float) output.getDouble(row, 1);
                // The label is shifted by 0.5 before being observed (a 0/1 label becomes -0.5/+0.5).
                auc.observe(prediction.predictedLabelYes, prediction.trueLabelYes.doubleValue() - 0.5d);
                prediction.index = nextIndex++;
                consumer.accept(prediction);
            }
            processed += numExamples;
            if (predicate.test(Integer.valueOf(processed))) {
                break;
            }
        }
        return auc.evaluateStatistic();
    }

    /**
     * Estimates the AUC statistic for a {@link ComputationGraph}.
     *
     * <p>Each example's prediction is obtained from
     * {@code predictionInterpreter.interpret(labels, outputs, exampleIndex)}; the result is cast to
     * {@link BinaryClassPrediction}. The number of examples in a mini-batch is taken from
     * dimension 0 of the first feature array.
     *
     * @param multiDataSetIterator source of the examples to evaluate
     * @param computationGraph the graph to evaluate
     * @param i value passed to the {@link AreaUnderTheROCCurve} constructor
     * @param consumer called once for every prediction returned by the interpreter
     * @param predicate tested after each mini-batch with the number of examples processed so far;
     *     evaluation stops as soon as it returns {@code true}
     * @param i2 index of the label array passed to the interpreter
     * @param predictionInterpreter converts labels and graph outputs into a prediction
     * @return the value of {@link AreaUnderTheROCCurve#evaluateStatistic()}
     */
    public double estimateWithGraph(MultiDataSetIterator multiDataSetIterator, ComputationGraph computationGraph, int i, Consumer<BinaryClassPrediction> consumer, Predicate<Integer> predicate, int i2, PredictionInterpreter predictionInterpreter) {
        AreaUnderTheROCCurve auc = new AreaUnderTheROCCurve(i);
        int nextIndex = 0;
        int processed = 0;
        while (multiDataSetIterator.hasNext()) {
            MultiDataSet multiDataSet = (MultiDataSet) multiDataSetIterator.next();
            INDArray[] output = computationGraph.output(multiDataSet.getFeatures());
            int size = multiDataSet.getFeatures(0).size(0);
            for (int row = 0; row < size; row++) {
                BinaryClassPrediction prediction = (BinaryClassPrediction) predictionInterpreter.interpret(multiDataSet.getLabels(i2), output, row);
                // The label is shifted by 0.5 before being observed (a 0/1 label becomes -0.5/+0.5).
                auc.observe(prediction.predictedLabelYes, prediction.trueLabelYes.doubleValue() - 0.5d);
                prediction.index = nextIndex++;
                consumer.accept(prediction);
            }
            processed += size;
            if (predicate.test(Integer.valueOf(processed))) {
                break;
            }
        }
        return auc.evaluateStatistic();
    }
}
