/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.arbiter.scoring.graph;

import java.util.Map;
import org.deeplearning4j.arbiter.optimize.api.data.DataProvider;
import org.deeplearning4j.arbiter.optimize.api.score.ScoreFunction;
import org.deeplearning4j.arbiter.scoring.RegressionValue;
import org.deeplearning4j.eval.RegressionEvaluation;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;

public class GraphTestSetRegressionScoreFunction
implements ScoreFunction<ComputationGraph, MultiDataSetIterator> {
    private final RegressionValue regressionValue;

    public GraphTestSetRegressionScoreFunction(RegressionValue regressionValue) {
        this.regressionValue = regressionValue;
    }

    public double score(ComputationGraph model, DataProvider<MultiDataSetIterator> dataProvider, Map<String, Object> dataParameters) {
        MultiDataSetIterator testSet = (MultiDataSetIterator)dataProvider.testData(dataParameters);
        int nOutputs = model.getNumOutputArrays();
        RegressionEvaluation[] evaluations = new RegressionEvaluation[nOutputs];
        for (int i = 0; i < evaluations.length; ++i) {
            evaluations[i] = new RegressionEvaluation(new String[0]);
        }
        while (testSet.hasNext()) {
            MultiDataSet next = (MultiDataSet)testSet.next();
            INDArray[] labels = next.getLabels();
            if (next.hasMaskArrays()) {
                INDArray[] fMasks = next.getFeaturesMaskArrays();
                INDArray[] lMasks = next.getLabelsMaskArrays();
                model.setLayerMaskArrays(fMasks, lMasks);
                INDArray[] outputs = model.output(false, next.getFeatures());
                for (int i = 0; i < evaluations.length; ++i) {
                    if (lMasks != null && lMasks[i] != null) {
                        evaluations[i].evalTimeSeries(labels[i], outputs[i], lMasks[i]);
                        continue;
                    }
                    evaluations[i].evalTimeSeries(labels[i], outputs[i]);
                }
                model.clearLayerMaskArrays();
                continue;
            }
            INDArray[] outputs = model.output(false, next.getFeatures());
            for (int i = 0; i < evaluations.length; ++i) {
                if (labels[i].rank() == 3) {
                    evaluations[i].evalTimeSeries(labels[i], outputs[i]);
                    continue;
                }
                evaluations[i].eval(labels[i], outputs[i]);
            }
        }
        double sum = 0.0;
        int totalColumns = 0;
        block11: for (int i = 0; i < evaluations.length; ++i) {
            int nColumns = evaluations[i].numColumns();
            totalColumns += nColumns;
            switch (this.regressionValue) {
                case MSE: {
                    int j;
                    for (j = 0; j < nColumns; ++j) {
                        sum += evaluations[i].meanSquaredError(j);
                    }
                    continue block11;
                }
                case MAE: {
                    int j;
                    for (j = 0; j < nColumns; ++j) {
                        sum += evaluations[i].meanAbsoluteError(j);
                    }
                    continue block11;
                }
                case RMSE: {
                    int j;
                    for (j = 0; j < nColumns; ++j) {
                        sum += evaluations[i].rootMeanSquaredError(j);
                    }
                    continue block11;
                }
                case RSE: {
                    int j;
                    for (j = 0; j < nColumns; ++j) {
                        sum += evaluations[i].relativeSquaredError(j);
                    }
                    continue block11;
                }
                case CorrCoeff: {
                    int j;
                    for (j = 0; j < nColumns; ++j) {
                        sum += evaluations[i].correlationR2(j);
                    }
                    continue block11;
                }
            }
        }
        if (this.regressionValue == RegressionValue.CorrCoeff) {
            sum /= (double)totalColumns;
        }
        return sum;
    }

    public boolean minimize() {
        return this.regressionValue != RegressionValue.CorrCoeff;
    }

    public String toString() {
        return "GraphTestSetRegressionScoreFunction(type=" + (Object)((Object)this.regressionValue) + ")";
    }
}

