package org.deeplearning4j.arbiter.evaluator.multilayer;

import java.util.Arrays;
import java.util.List;
import java.util.Map;
import org.deeplearning4j.arbiter.optimize.api.data.DataProvider;
import org.deeplearning4j.arbiter.optimize.api.evaluation.ModelEvaluator;
import org.deeplearning4j.arbiter.scoring.util.ScoreUtil;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;

/* loaded from: input_file:org/deeplearning4j/arbiter/evaluator/multilayer/ClassificationEvaluator.class */
/**
 * {@link ModelEvaluator} that produces an {@link Evaluation} for a
 * {@link MultiLayerNetwork} or a {@link ComputationGraph}, using the test data
 * supplied by a {@link DataProvider}.
 */
public class ClassificationEvaluator implements ModelEvaluator {
    /** Parameters forwarded to {@code DataProvider.testData(...)}; may be {@code null}. */
    private final Map<String, Object> params;

    /** Creates an evaluator that passes {@code null} parameters to the data provider. */
    public ClassificationEvaluator() {
        this(null);
    }

    /**
     * Creates an evaluator that passes the given parameters to the data provider.
     *
     * @param map parameters forwarded to {@code DataProvider.testData(...)}; may be {@code null}
     */
    public ClassificationEvaluator(Map<String, Object> map) {
        this.params = map;
    }

    /**
     * Evaluates the given model on the data provider's test data.
     *
     * <p>The test data is requested with this evaluator's parameters, converted by
     * {@code ScoreUtil.getIterator(...)} and handed to {@code ScoreUtil.getEvaluation(...)}.
     *
     * @param obj          the model; handled as a {@link MultiLayerNetwork} when it is one,
     *                     otherwise cast to {@link ComputationGraph}
     * @param dataProvider source of the test data
     * @return the evaluation returned by {@code ScoreUtil.getEvaluation(...)}
     * @throws ClassCastException if {@code obj} is non-null and is neither a
     *                            {@link MultiLayerNetwork} nor a {@link ComputationGraph}
     */
    /* renamed from: evaluateModel, reason: merged with bridge method [inline-methods] */
    public Evaluation m24evaluateModel(Object obj, DataProvider dataProvider) {
        if (obj instanceof MultiLayerNetwork) {
            MultiLayerNetwork network = (MultiLayerNetwork) obj;
            return ScoreUtil.getEvaluation(network, ScoreUtil.getIterator(dataProvider.testData(this.params)));
        }
        // Anything that is not a MultiLayerNetwork is assumed to be a ComputationGraph.
        // The cast happens before the test data is requested, so an unsupported model
        // fails without touching the data provider.
        ComputationGraph graph = (ComputationGraph) obj;
        return ScoreUtil.getEvaluation(graph, ScoreUtil.getIterator(dataProvider.testData(this.params)));
    }

    /**
     * Lists the model classes this evaluator reports as supported.
     *
     * @return a fixed-size list holding {@link MultiLayerNetwork} and {@link ComputationGraph}
     */
    public List<Class<?>> getSupportedModelTypes() {
        return Arrays.asList(MultiLayerNetwork.class, ComputationGraph.class);
    }

    /**
     * Lists the data classes this evaluator reports as supported.
     *
     * @return a fixed-size list holding {@link DataSetIterator} and {@link MultiDataSetIterator}
     */
    public List<Class<?>> getSupportedDataTypes() {
        return Arrays.asList(DataSetIterator.class, MultiDataSetIterator.class);
    }
}
