/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.arbiter.evaluator.graph;

import org.deeplearning4j.arbiter.optimize.api.data.DataProvider;
import org.deeplearning4j.arbiter.optimize.api.evaluation.ModelEvaluator;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;

public class GraphClassificationDataSetEvaluator
implements ModelEvaluator<ComputationGraph, DataSetIterator, Evaluation> {
    public Evaluation evaluateModel(ComputationGraph model, DataProvider<DataSetIterator> dataProvider) {
        DataSetIterator iterator = (DataSetIterator)dataProvider.testData(null);
        Evaluation eval = new Evaluation();
        while (iterator.hasNext()) {
            DataSet next = (DataSet)iterator.next();
            INDArray features = next.getFeatures();
            INDArray labels = next.getLabels();
            if (next.hasMaskArrays()) {
                INDArray[] iNDArrayArray;
                INDArray[] iNDArrayArray2;
                INDArray fMask = next.getFeaturesMaskArray();
                INDArray lMask = next.getLabelsMaskArray();
                if (fMask == null) {
                    iNDArrayArray2 = null;
                } else {
                    INDArray[] iNDArrayArray3 = new INDArray[1];
                    iNDArrayArray2 = iNDArrayArray3;
                    iNDArrayArray3[0] = fMask;
                }
                INDArray[] fMasks = iNDArrayArray2;
                if (lMask == null) {
                    iNDArrayArray = null;
                } else {
                    INDArray[] iNDArrayArray4 = new INDArray[1];
                    iNDArrayArray = iNDArrayArray4;
                    iNDArrayArray4[0] = lMask;
                }
                INDArray[] lMasks = iNDArrayArray;
                model.setLayerMaskArrays(fMasks, lMasks);
                INDArray out = model.output(new INDArray[]{next.getFeatures()})[0];
                if (lMask != null) {
                    eval.evalTimeSeries(next.getLabels(), out, lMask);
                } else {
                    eval.evalTimeSeries(next.getLabels(), out);
                }
                model.clearLayerMaskArrays();
                continue;
            }
            INDArray out = model.output(new INDArray[]{features})[0];
            if (out.rank() == 3) {
                eval.evalTimeSeries(labels, out);
                continue;
            }
            eval.eval(labels, out);
        }
        return eval;
    }
}

