package org.deeplearning4j.arbiter.scoring.util;

import org.deeplearning4j.arbiter.scoring.RegressionValue;
import org.deeplearning4j.datasets.iterator.impl.MultiDataSetIteratorAdapter;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.eval.RegressionEvaluation;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.DataSetIteratorFactory;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIteratorFactory;

/* loaded from: input_file:org/deeplearning4j/arbiter/scoring/util/ScoreUtil.class */
/**
 * Static helpers used by Arbiter score functions: resolving data sources into iterators, and
 * computing loss, classification-evaluation and regression scores for {@link MultiLayerNetwork}
 * and {@link ComputationGraph} models.
 */
public class ScoreUtil {

    /**
     * Resolves a data source into a {@link MultiDataSetIterator}.
     *
     * <p>A {@link MultiDataSetIterator} is returned as-is, a {@link MultiDataSetIteratorFactory}
     * creates a new iterator, a {@link DataSetIterator} is wrapped in a
     * {@link MultiDataSetIteratorAdapter}, and a {@link DataSetIteratorFactory} creates a new
     * iterator which is then wrapped.
     *
     * @param source an iterator or iterator factory of one of the four accepted types
     * @return a multi-dataset iterator over the source's data
     * @throws IllegalArgumentException if {@code source} is {@code null} or of any other type
     */
    public static MultiDataSetIterator getMultiIterator(Object source) {
        if (source instanceof MultiDataSetIterator) {
            return (MultiDataSetIterator) source;
        }
        if (source instanceof MultiDataSetIteratorFactory) {
            return ((MultiDataSetIteratorFactory) source).create();
        }
        if (source instanceof DataSetIterator) {
            return new MultiDataSetIteratorAdapter((DataSetIterator) source);
        }
        if (source instanceof DataSetIteratorFactory) {
            return new MultiDataSetIteratorAdapter(((DataSetIteratorFactory) source).create());
        }
        // The message lists every type accepted above, plus what was actually received.
        throw new IllegalArgumentException("Type must be one of MultiDataSetIterator, MultiDataSetIteratorFactory, "
                + "DataSetIterator or DataSetIteratorFactory; got "
                + (source == null ? "null" : source.getClass().getName()));
    }

    /**
     * Resolves a data source into a {@link DataSetIterator}.
     *
     * @param source a {@link DataSetIterator} (returned as-is) or a {@link DataSetIteratorFactory}
     *               (a new iterator is created)
     * @return a dataset iterator over the source's data
     * @throws IllegalArgumentException if {@code source} is {@code null} or of any other type
     */
    public static DataSetIterator getIterator(Object source) {
        if (source instanceof DataSetIterator) {
            return (DataSetIterator) source;
        }
        if (source instanceof DataSetIteratorFactory) {
            return ((DataSetIteratorFactory) source).create();
        }
        throw new IllegalArgumentException("Type must either be DataSetIterator or DataSetIteratorFactory");
    }

    /**
     * Evaluates a {@link MultiLayerNetwork} on the given data via {@code MultiLayerNetwork.evaluate}.
     *
     * @param multiLayerNetwork the network to evaluate
     * @param dataSetIterator   the data to evaluate on
     * @return the resulting classification evaluation
     */
    public static Evaluation getEvaluation(MultiLayerNetwork multiLayerNetwork, DataSetIterator dataSetIterator) {
        return multiLayerNetwork.evaluate(dataSetIterator);
    }

    /**
     * Evaluates a single-output {@link ComputationGraph} on multi-dataset data.
     *
     * @param computationGraph     the graph to evaluate; must have exactly one output array
     * @param multiDataSetIterator the data to evaluate on
     * @return the resulting classification evaluation
     * @throws IllegalStateException if the graph does not have exactly one output array
     */
    public static Evaluation getEvaluation(ComputationGraph computationGraph, MultiDataSetIterator multiDataSetIterator) {
        if (computationGraph.getNumOutputArrays() != 1) {
            throw new IllegalStateException("GraphSetSetAccuracyScoreFunction cannot be applied to ComputationGraphs with more than one output. NumOutputs = " + computationGraph.getNumOutputArrays());
        }
        return computationGraph.evaluate(multiDataSetIterator);
    }

    /**
     * Evaluates a single-output {@link ComputationGraph} on dataset data.
     *
     * @param computationGraph the graph to evaluate; must have exactly one output array
     * @param dataSetIterator  the data to evaluate on
     * @return the resulting classification evaluation
     * @throws IllegalStateException if the graph does not have exactly one output array
     */
    public static Evaluation getEvaluation(ComputationGraph computationGraph, DataSetIterator dataSetIterator) {
        if (computationGraph.getNumOutputArrays() != 1) {
            throw new IllegalStateException("GraphSetSetAccuracyScoreFunctionDataSet cannot be applied to ComputationGraphs with more than one output. NumOutputs = " + computationGraph.getNumOutputArrays());
        }
        return computationGraph.evaluate(dataSetIterator);
    }

    /**
     * Computes the example-weighted score of a {@link ComputationGraph} over all batches of a
     * multi-dataset iterator.
     *
     * <p>Each batch's {@code computationGraph.score(batch)} is multiplied by the batch's example
     * count, taken as {@code size(0)} of the first feature array.
     *
     * @param computationGraph     the graph to score
     * @param multiDataSetIterator the data to score on; consumed from its current position
     * @param average              if {@code true}, divide the weighted sum by the total number of
     *                             examples; if {@code false}, return the weighted sum itself
     * @return the weighted sum or per-example average. When averaging over an empty iterator the
     *         result is {@code NaN} (0.0 / 0).
     */
    public static double score(ComputationGraph computationGraph, MultiDataSetIterator multiDataSetIterator, boolean average) {
        double weightedSum = 0.0d;
        long totalExamples = 0L; // long: avoids the int truncation/overflow of very large datasets
        while (multiDataSetIterator.hasNext()) {
            MultiDataSet batch = multiDataSetIterator.next();
            long batchSize = batch.getFeatures(0).size(0);
            weightedSum += batchSize * computationGraph.score(batch);
            totalExamples += batchSize;
        }
        return average ? weightedSum / totalExamples : weightedSum;
    }

    /**
     * Computes the example-weighted score of a {@link ComputationGraph} over all batches of a
     * dataset iterator.
     *
     * @param computationGraph the graph to score
     * @param dataSetIterator  the data to score on; consumed from its current position
     * @param average          if {@code true}, divide the weighted sum by the total number of
     *                         examples; if {@code false}, return the weighted sum itself
     * @return the weighted sum or per-example average. When averaging over an empty iterator the
     *         result is {@code NaN} (0.0 / 0).
     */
    public static double score(ComputationGraph computationGraph, DataSetIterator dataSetIterator, boolean average) {
        double weightedSum = 0.0d;
        long totalExamples = 0L;
        while (dataSetIterator.hasNext()) {
            DataSet batch = dataSetIterator.next();
            int batchSize = batch.numExamples();
            weightedSum += batchSize * computationGraph.score(batch);
            totalExamples += batchSize;
        }
        return average ? weightedSum / totalExamples : weightedSum;
    }

    /**
     * Computes a regression score for a (possibly multi-output) {@link ComputationGraph}.
     *
     * <p>One {@link RegressionEvaluation} is kept per graph output. Batches with mask arrays are
     * evaluated as time series, applying the label mask where present. Unmasked batches use
     * time-series evaluation for rank-3 labels and standard evaluation otherwise. For MSE, MAE,
     * RMSE and RSE, the per-output sums from {@link #getScoreFromRegressionEval} are added together.
     * For CorrCoeff, the result is the mean over all columns of all outputs.
     *
     * @param computationGraph     the graph to evaluate
     * @param multiDataSetIterator the data to evaluate on; consumed from its current position
     * @param regressionValue      which regression metric to compute
     * @return the aggregated regression score
     */
    public static double score(ComputationGraph computationGraph, MultiDataSetIterator multiDataSetIterator, RegressionValue regressionValue) {
        RegressionEvaluation[] evaluations = new RegressionEvaluation[computationGraph.getNumOutputArrays()];
        for (int i = 0; i < evaluations.length; i++) {
            evaluations[i] = new RegressionEvaluation();
        }

        while (multiDataSetIterator.hasNext()) {
            MultiDataSet batch = multiDataSetIterator.next();
            INDArray[] labels = batch.getLabels();
            if (batch.hasMaskArrays()) {
                INDArray[] labelsMasks = batch.getLabelsMaskArrays();
                computationGraph.setLayerMaskArrays(batch.getFeaturesMaskArrays(), labelsMasks);
                try {
                    // train = false: presumably inference mode (no dropout etc.) — per DL4J output(boolean, INDArray...)
                    INDArray[] outputs = computationGraph.output(false, batch.getFeatures());
                    for (int i = 0; i < evaluations.length; i++) {
                        if (labelsMasks == null || labelsMasks[i] == null) {
                            evaluations[i].evalTimeSeries(labels[i], outputs[i]);
                        } else {
                            evaluations[i].evalTimeSeries(labels[i], outputs[i], labelsMasks[i]);
                        }
                    }
                } finally {
                    // Clear the masks even if output/evaluation throws, so they are not left set on the graph.
                    computationGraph.clearLayerMaskArrays();
                }
            } else {
                INDArray[] outputs = computationGraph.output(false, batch.getFeatures());
                for (int i = 0; i < evaluations.length; i++) {
                    if (labels[i].rank() == 3) {
                        evaluations[i].evalTimeSeries(labels[i], outputs[i]);
                    } else {
                        evaluations[i].eval(labels[i], outputs[i]);
                    }
                }
            }
        }

        boolean isCorrCoeff = regressionValue == RegressionValue.CorrCoeff;
        double total = 0.0d;
        long totalColumns = 0L;
        for (RegressionEvaluation evaluation : evaluations) {
            int numColumns = evaluation.numColumns();
            double outputScore = getScoreFromRegressionEval(evaluation, regressionValue);
            if (isCorrCoeff) {
                // getScoreFromRegressionEval already returns the per-column MEAN for CorrCoeff.
                // Re-weight by column count so the division below yields the mean over all columns,
                // instead of dividing an already-averaged value a second time.
                outputScore *= numColumns;
            }
            total += outputScore;
            totalColumns += numColumns;
        }
        if (isCorrCoeff) {
            total /= totalColumns;
        }
        return total;
    }

    /**
     * Computes a regression score for a {@link ComputationGraph} using
     * {@code ComputationGraph.evaluateRegression}.
     *
     * @param computationGraph the graph to evaluate
     * @param dataSetIterator  the data to evaluate on
     * @param regressionValue  which regression metric to compute
     * @return the score as computed by {@link #getScoreFromRegressionEval}
     */
    public static double score(ComputationGraph computationGraph, DataSetIterator dataSetIterator, RegressionValue regressionValue) {
        return getScoreFromRegressionEval(computationGraph.evaluateRegression(dataSetIterator), regressionValue);
    }

    /**
     * Computes the example-weighted score of a {@link MultiLayerNetwork} over all batches of a
     * dataset iterator.
     *
     * @param multiLayerNetwork the network to score
     * @param dataSetIterator   the data to score on; consumed from its current position
     * @param average           if {@code true}, divide the weighted sum by the total number of
     *                          examples; if {@code false}, return the weighted sum itself
     * @return the weighted sum or per-example average. When averaging over an empty iterator the
     *         result is {@code NaN} (0.0 / 0).
     */
    public static double score(MultiLayerNetwork multiLayerNetwork, DataSetIterator dataSetIterator, boolean average) {
        double weightedSum = 0.0d;
        long totalExamples = 0L;
        while (dataSetIterator.hasNext()) {
            DataSet batch = dataSetIterator.next();
            int batchSize = batch.numExamples();
            weightedSum += batchSize * multiLayerNetwork.score(batch);
            totalExamples += batchSize;
        }
        return average ? weightedSum / totalExamples : weightedSum;
    }

    /**
     * Computes a regression score for a {@link MultiLayerNetwork} using
     * {@code MultiLayerNetwork.evaluateRegression}.
     *
     * @param multiLayerNetwork the network to evaluate
     * @param dataSetIterator   the data to evaluate on
     * @param regressionValue   which regression metric to compute
     * @return the score as computed by {@link #getScoreFromRegressionEval}
     */
    public static double score(MultiLayerNetwork multiLayerNetwork, DataSetIterator dataSetIterator, RegressionValue regressionValue) {
        return getScoreFromRegressionEval(multiLayerNetwork.evaluateRegression(dataSetIterator), regressionValue);
    }

    /**
     * Reduces a {@link RegressionEvaluation} to a single number.
     *
     * <p>MSE, MAE, RMSE and RSE return the SUM of the per-column metric over all columns.
     * CorrCoeff returns the MEAN of the per-column {@code correlationR2} values; the result is
     * {@code NaN} when there are no columns. Any other value returns 0.0.
     *
     * @param regressionEvaluation an evaluation that has already accumulated data
     * @param regressionValue      which metric to compute
     * @return the aggregated metric
     * @deprecated marked deprecated in this API; the intended replacement is not visible here — TODO confirm.
     */
    @Deprecated
    public static double getScoreFromRegressionEval(RegressionEvaluation regressionEvaluation, RegressionValue regressionValue) {
        int numColumns = regressionEvaluation.numColumns();
        double result = 0.0d;
        switch (regressionValue) {
            case MSE:
                for (int col = 0; col < numColumns; col++) {
                    result += regressionEvaluation.meanSquaredError(col);
                }
                break;
            case MAE:
                for (int col = 0; col < numColumns; col++) {
                    result += regressionEvaluation.meanAbsoluteError(col);
                }
                break;
            case RMSE:
                for (int col = 0; col < numColumns; col++) {
                    result += regressionEvaluation.rootMeanSquaredError(col);
                }
                break;
            case RSE:
                for (int col = 0; col < numColumns; col++) {
                    result += regressionEvaluation.relativeSquaredError(col);
                }
                break;
            case CorrCoeff:
                for (int col = 0; col < numColumns; col++) {
                    result += regressionEvaluation.correlationR2(col);
                }
                result /= numColumns;
                break;
        }
        return result;
    }
}
