package weka.distributed;

import java.io.BufferedReader;
import java.io.StringReader;
import java.util.ArrayList;
import java.util.List;
import org.junit.Assert;
import org.junit.Test;
import weka.classifiers.bayes.NaiveBayesUpdateable;
import weka.classifiers.evaluation.Evaluation;
import weka.classifiers.trees.J48;
import weka.core.Instances;
import weka.core.Utils;

/* loaded from: input_file:weka/distributed/WekaClassifierEvaluationTest.class */
/**
 * Tests for {@link WekaClassifierEvaluationMapTask} and
 * {@link WekaClassifierEvaluationReduceTask}.
 *
 * <p>Every test loads the dataset held in {@link CorrelationMatrixMapTaskTest#IRIS},
 * then, for each of {@link #NUM_FOLDS} folds, trains a classifier through a
 * {@link WekaClassifierMapTask} and evaluates it through a
 * {@link WekaClassifierEvaluationMapTask}. A single trainer task and a single
 * evaluator task are reused across all folds of a test.
 */
public class WekaClassifierEvaluationTest {

    /** Total number of folds configured on both the training and evaluation tasks. */
    private static final int NUM_FOLDS = 10;

    /** Instance count each per-fold {@link Evaluation} is expected to report. */
    private static final long EXPECTED_FOLD_SIZE = 15L;

    /** Instance count the aggregated {@link Evaluation} is expected to report. */
    private static final long EXPECTED_TOTAL_SIZE = 150L;

    /** Seed value passed to the evaluation task's setup call. */
    private static final long SEED = 1L;

    /**
     * Batch-trained J48, prediction fraction 0.0: per-fold AUC for class 0 must be
     * a missing value.
     */
    @Test
    public void testCrossValidateBatchMapOnly() throws Exception {
        Instances data = loadIris();
        WekaClassifierEvaluationMapTask evaluator = new WekaClassifierEvaluationMapTask();
        WekaClassifierMapTask trainer = new WekaClassifierMapTask();
        trainer.setClassifier(new J48());
        trainer.setTotalNumFolds(NUM_FOLDS);

        for (int fold = 1; fold <= NUM_FOLDS; fold++) {
            trainFold(trainer, data, fold, false);
            Evaluation evaluation = evaluateFold(evaluator, trainer, data, fold, 0.0d);
            Assert.assertTrue(Utils.isMissingValue(evaluation.areaUnderROC(0)));
        }
    }

    /**
     * Batch-trained J48, prediction fraction 0.5: each fold must retain 7
     * predictions and yield a non-missing AUC for class 0.
     */
    @Test
    public void testCrossValidateBatchMapOnlyRetainPredsForAUC() throws Exception {
        Instances data = loadIris();
        WekaClassifierEvaluationMapTask evaluator = new WekaClassifierEvaluationMapTask();
        WekaClassifierMapTask trainer = new WekaClassifierMapTask();
        trainer.setClassifier(new J48());
        trainer.setTotalNumFolds(NUM_FOLDS);

        for (int fold = 1; fold <= NUM_FOLDS; fold++) {
            trainFold(trainer, data, fold, false);
            Evaluation evaluation = evaluateFold(evaluator, trainer, data, fold, 0.5d);
            Assert.assertEquals(7L, evaluation.predictions().size());
            Assert.assertFalse(Utils.isMissingValue(evaluation.areaUnderROC(0)));
        }
    }

    /**
     * Incrementally trained NaiveBayesUpdateable, prediction fraction 0.0: per-fold
     * AUC for class 0 must be a missing value.
     */
    @Test
    public void testCrossValidateIncrementalMapOnly() throws Exception {
        Instances data = loadIris();
        WekaClassifierEvaluationMapTask evaluator = new WekaClassifierEvaluationMapTask();
        WekaClassifierMapTask trainer = new WekaClassifierMapTask();
        trainer.setClassifier(new NaiveBayesUpdateable());
        trainer.setTotalNumFolds(NUM_FOLDS);

        for (int fold = 1; fold <= NUM_FOLDS; fold++) {
            trainFold(trainer, data, fold, true);
            Evaluation evaluation = evaluateFold(evaluator, trainer, data, fold, 0.0d);
            Assert.assertTrue(Utils.isMissingValue(evaluation.areaUnderROC(0)));
        }
    }

    /**
     * Incrementally trained NaiveBayesUpdateable, prediction fraction 1.0: each
     * fold must retain 15 predictions and yield a non-missing AUC for class 0.
     */
    @Test
    public void testCrossValidateIncrementalMapOnlyRetainPredsForAUC() throws Exception {
        Instances data = loadIris();
        WekaClassifierEvaluationMapTask evaluator = new WekaClassifierEvaluationMapTask();
        WekaClassifierMapTask trainer = new WekaClassifierMapTask();
        trainer.setClassifier(new NaiveBayesUpdateable());
        trainer.setTotalNumFolds(NUM_FOLDS);

        for (int fold = 1; fold <= NUM_FOLDS; fold++) {
            trainFold(trainer, data, fold, true);
            Evaluation evaluation = evaluateFold(evaluator, trainer, data, fold, 1.0d);
            Assert.assertEquals(15L, evaluation.predictions().size());
            Assert.assertFalse(Utils.isMissingValue(evaluation.areaUnderROC(0)));
        }
    }

    /**
     * Collects the per-fold evaluations of a batch-trained J48 and checks that
     * {@link WekaClassifierEvaluationReduceTask#aggregate} produces an evaluation
     * reporting 150 instances and a missing AUC for class 0.
     */
    @Test
    public void testReduceOverFolds() throws Exception {
        Instances data = loadIris();
        WekaClassifierEvaluationMapTask evaluator = new WekaClassifierEvaluationMapTask();
        WekaClassifierMapTask trainer = new WekaClassifierMapTask();
        trainer.setClassifier(new J48());
        trainer.setTotalNumFolds(NUM_FOLDS);

        List<Evaluation> foldEvaluations = new ArrayList<Evaluation>();
        for (int fold = 1; fold <= NUM_FOLDS; fold++) {
            trainFold(trainer, data, fold, false);
            Evaluation evaluation = evaluateFold(evaluator, trainer, data, fold, 0.0d);
            Assert.assertTrue(Utils.isMissingValue(evaluation.areaUnderROC(0)));
            foldEvaluations.add(evaluation);
        }

        Evaluation aggregated = new WekaClassifierEvaluationReduceTask().aggregate(foldEvaluations);
        Assert.assertEquals(EXPECTED_TOTAL_SIZE, (long) aggregated.numInstances());
        Assert.assertTrue(Utils.isMissingValue(aggregated.areaUnderROC(0)));
    }

    /**
     * Parses the dataset text in {@link CorrelationMatrixMapTaskTest#IRIS} and marks
     * its last attribute as the class attribute.
     *
     * @return the loaded dataset with its class index set
     * @throws Exception if the dataset text cannot be parsed
     */
    private static Instances loadIris() throws Exception {
        Instances data =
            new Instances(new BufferedReader(new StringReader(CorrelationMatrixMapTaskTest.IRIS)));
        data.setClassIndex(data.numAttributes() - 1);
        return data;
    }

    /**
     * Runs one training pass of {@code trainer} for the given fold.
     *
     * @param trainer the training task; its classifier and total fold count must
     *     already be configured by the caller
     * @param data the full dataset
     * @param foldNumber the 1-based fold number to set on the task
     * @param incremental if {@code true}, instances are fed one at a time through
     *     {@code processInstance}; otherwise the whole dataset is handed over in a
     *     single {@code addToTrainingHeader} call
     * @throws Exception if the training task fails
     */
    private static void trainFold(WekaClassifierMapTask trainer, Instances data, int foldNumber,
        boolean incremental) throws Exception {
        trainer.setFoldNumber(foldNumber);
        // The task is set up with a header-only (zero-instance) copy of the dataset.
        trainer.setup(new Instances(data, 0));
        if (incremental) {
            for (int i = 0; i < data.numInstances(); i++) {
                trainer.processInstance(data.instance(i));
            }
        } else {
            trainer.addToTrainingHeader(data);
        }
        trainer.finalizeTask();
    }

    /**
     * Evaluates the classifier currently held by {@code trainer} on the given fold
     * and performs the checks common to every test: the resulting evaluation is
     * non-null and reports {@link #EXPECTED_FOLD_SIZE} instances.
     *
     * @param evaluator the evaluation task to (re)configure and run
     * @param trainer the training task whose classifier is evaluated
     * @param data the full dataset; every instance is passed to the evaluator
     * @param foldNumber the 1-based fold number to set on the evaluator
     * @param predFrac the last argument of the evaluator's setup call; the tests
     *     only get a non-missing AUC when it is greater than zero (presumably the
     *     fraction of predictions to retain; confirm against the task's docs)
     * @return the evaluation produced for this fold
     * @throws Exception if the evaluation task fails
     */
    private static Evaluation evaluateFold(WekaClassifierEvaluationMapTask evaluator,
        WekaClassifierMapTask trainer, Instances data, int foldNumber, double predFrac)
        throws Exception {
        evaluator.setClassifier(trainer.getClassifier());
        evaluator.setTotalNumFolds(NUM_FOLDS);
        evaluator.setFoldNumber(foldNumber);
        // NOTE(review): {50, 50, 50} and 150 look like the per-class counts and the
        // total count of the dataset; confirm against the setup signature. A fresh
        // array is created per call, in case the task keeps a reference to it.
        evaluator.setup(new Instances(data, 0), new double[] {50.0d, 50.0d, 50.0d}, 150.0d, SEED,
            predFrac);
        for (int i = 0; i < data.numInstances(); i++) {
            evaluator.processInstance(data.instance(i));
        }
        evaluator.finalizeTask();

        Evaluation evaluation = evaluator.getEvaluation();
        Assert.assertNotNull(evaluation);
        Assert.assertEquals(EXPECTED_FOLD_SIZE, (long) evaluation.numInstances());
        return evaluation;
    }

    /**
     * Runs every test in sequence without a JUnit runner.
     *
     * <p>Checked and runtime exceptions are printed rather than rethrown. Assertion
     * failures are {@link AssertionError}s, which this handler does not catch, so
     * they still propagate.
     *
     * @param strArr ignored
     */
    public static void main(String[] strArr) {
        try {
            WekaClassifierEvaluationTest test = new WekaClassifierEvaluationTest();
            test.testCrossValidateBatchMapOnly();
            test.testCrossValidateBatchMapOnlyRetainPredsForAUC();
            test.testCrossValidateIncrementalMapOnly();
            test.testCrossValidateIncrementalMapOnlyRetainPredsForAUC();
            test.testReduceOverFolds();
        } catch (Exception e) {
            e.printStackTrace();
        }
    }
}
