/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.arbiter.scoring.multilayer;

import java.util.Map;
import org.deeplearning4j.arbiter.optimize.api.data.DataProvider;
import org.deeplearning4j.arbiter.optimize.api.score.ScoreFunction;
import org.deeplearning4j.arbiter.scoring.RegressionValue;
import org.deeplearning4j.eval.RegressionEvaluation;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;

/**
 * Score function that runs a {@link MultiLayerNetwork} over the test data obtained from a
 * {@link DataProvider} and condenses the resulting {@link RegressionEvaluation} into one number.
 *
 * <p>The metric is chosen by the {@link RegressionValue} passed to the constructor. For
 * {@code MSE}, {@code MAE}, {@code RMSE} and {@code RSE} the score is the <em>sum</em> of the
 * per-column metric; for {@code CorrCoeff} it is the <em>mean</em> of the per-column
 * {@code correlationR2} values.
 */
public class TestSetRegressionScoreFunction
implements ScoreFunction<MultiLayerNetwork, DataSetIterator> {
    private final RegressionValue regressionValue;

    /**
     * Creates a score function for the given metric.
     *
     * @param regressionValue the regression metric to report; it is stored as-is (no null check
     *                        is performed here)
     */
    public TestSetRegressionScoreFunction(RegressionValue regressionValue) {
        this.regressionValue = regressionValue;
    }

    /**
     * Evaluates {@code model} on the provider's test data and returns the configured metric.
     *
     * @param model          the network whose predictions are evaluated
     * @param dataProvider   source of the test iterator
     * @param dataParameters parameters forwarded unchanged to {@code dataProvider.testData(...)}
     * @return the aggregated metric (summed over columns, or averaged for {@code CorrCoeff})
     * @throws IllegalStateException if the test iterator yields no data at all
     */
    public double score(MultiLayerNetwork model, DataProvider<DataSetIterator> dataProvider, Map<String, Object> dataParameters) {
        DataSetIterator testIterator = (DataSetIterator) dataProvider.testData(dataParameters);
        RegressionEvaluation evaluation = evaluate(model, testIterator);
        if (evaluation == null) {
            throw new IllegalStateException("test iterator is empty");
        }
        return aggregate(evaluation);
    }

    /**
     * Feeds every batch of {@code testIterator} through {@code model} and accumulates the
     * predictions into a single evaluation.
     *
     * @return the populated evaluation, or {@code null} when the iterator had no batches
     */
    private static RegressionEvaluation evaluate(MultiLayerNetwork model, DataSetIterator testIterator) {
        RegressionEvaluation evaluation = null;
        while (testIterator.hasNext()) {
            DataSet batch = (DataSet) testIterator.next();
            // The evaluation is sized lazily from the label column count of the first batch.
            if (evaluation == null) {
                evaluation = new RegressionEvaluation(batch.getLabels().columns());
            }

            INDArray predictions;
            if (batch.hasMaskArrays()) {
                predictions = model.output(batch.getFeatures(), false, batch.getFeaturesMaskArray(), batch.getLabelsMaskArray());
            } else {
                predictions = model.output(batch.getFeatures(), false);
            }

            // Rank-3 output goes through the time-series path, with the label mask if one exists.
            if (predictions.rank() != 3) {
                evaluation.eval(batch.getLabels(), predictions);
            } else if (batch.getLabelsMaskArray() == null) {
                evaluation.evalTimeSeries(batch.getLabels(), predictions);
            } else {
                evaluation.evalTimeSeries(batch.getLabels(), predictions, batch.getLabelsMaskArray());
            }
        }
        return evaluation;
    }

    /**
     * Reduces the per-column values of the configured metric to one number: a sum for the error
     * metrics, a mean for {@code CorrCoeff}. A metric not listed below yields {@code 0.0}.
     */
    private double aggregate(RegressionEvaluation evaluation) {
        int columnCount = evaluation.numColumns();
        double total = 0.0;
        switch (this.regressionValue) {
            case MSE:
                for (int col = 0; col < columnCount; col++) {
                    total += evaluation.meanSquaredError(col);
                }
                return total;
            case MAE:
                for (int col = 0; col < columnCount; col++) {
                    total += evaluation.meanAbsoluteError(col);
                }
                return total;
            case RMSE:
                for (int col = 0; col < columnCount; col++) {
                    total += evaluation.rootMeanSquaredError(col);
                }
                return total;
            case RSE:
                for (int col = 0; col < columnCount; col++) {
                    total += evaluation.relativeSquaredError(col);
                }
                return total;
            case CorrCoeff:
                for (int col = 0; col < columnCount; col++) {
                    total += evaluation.correlationR2(col);
                }
                // Unlike the error metrics, this one is averaged over the columns.
                return total / (double) columnCount;
            default:
                return total;
        }
    }

    /**
     * Indicates the optimisation direction for this score.
     *
     * @return {@code false} for {@code CorrCoeff}, {@code true} for every other metric
     */
    public boolean minimize() {
        return this.regressionValue != RegressionValue.CorrCoeff;
    }

    /** Returns a short description that includes the configured metric. */
    @Override
    public String toString() {
        return "TestSetRegressionScoreFunction(type=" + this.regressionValue + ")";
    }
}

