/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.arbiter.scoring.graph;

import java.util.Map;
import org.deeplearning4j.arbiter.optimize.api.data.DataProvider;
import org.deeplearning4j.arbiter.optimize.api.score.ScoreFunction;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;

public class GraphTestSetLossScoreFunctionDataSet
implements ScoreFunction<ComputationGraph, DataSetIterator> {
    private final boolean average;

    public GraphTestSetLossScoreFunctionDataSet() {
        this(false);
    }

    public GraphTestSetLossScoreFunctionDataSet(boolean average) {
        this.average = average;
    }

    public double score(ComputationGraph model, DataProvider<DataSetIterator> dataProvider, Map<String, Object> dataParameters) {
        DataSetIterator testData = (DataSetIterator)dataProvider.testData(dataParameters);
        double sumScore = 0.0;
        int totalExamples = 0;
        while (testData.hasNext()) {
            org.nd4j.linalg.dataset.DataSet ds = (org.nd4j.linalg.dataset.DataSet)testData.next();
            int numExamples = testData.numExamples();
            sumScore += (double)numExamples * model.score((DataSet)ds);
            totalExamples += numExamples;
        }
        if (!this.average) {
            return sumScore;
        }
        return sumScore / (double)totalExamples;
    }

    public boolean minimize() {
        return true;
    }

    public String toString() {
        return "GraphTestSetLossScoreFunctionDataSet()";
    }
}

