package weka.classifiers.meta;

import java.io.BufferedReader;
import java.io.InputStream;
import java.io.InputStreamReader;
import junit.framework.Test;
import junit.framework.TestSuite;
import junit.textui.TestRunner;
import weka.classifiers.AbstractClassifierTest;
import weka.classifiers.Classifier;
import weka.classifiers.evaluation.EvaluationUtils;
import weka.classifiers.evaluation.NominalPrediction;
import weka.core.FastVector;
import weka.core.Instances;
import weka.core.NoSupportForMissingValuesException;
import weka.core.SelectedTag;
import weka.core.UnsupportedAttributeTypeException;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.RemoveType;
import weka.filters.unsupervised.attribute.ReplaceMissingValues;

/* loaded from: input_file:weka/classifiers/meta/ThresholdSelectorTest.class */
public class ThresholdSelectorTest extends AbstractClassifierTest {
    private static double[] DIST1 = {0.25d, 0.375d, 0.5d, 0.625d, 0.75d, 0.875d, 1.0d};
    protected transient Instances m_Instances;
    protected transient EvaluationUtils m_Evaluation;

    public ThresholdSelectorTest(String str) {
        super(str);
    }

    protected void setUp() throws Exception {
        super.setUp();
        this.m_Evaluation = new EvaluationUtils();
        this.m_Instances = new Instances(new BufferedReader(new InputStreamReader(ClassLoader.getSystemResourceAsStream("weka/classifiers/data/ClassifierTest.arff"))));
    }

    public Classifier getClassifier() {
        return getClassifier(DIST1);
    }

    protected void tearDown() {
        super.tearDown();
        this.m_Evaluation = null;
    }

    public Classifier getClassifier(double[] dArr) {
        return getClassifier((Classifier) new ThresholdSelectorDummyClassifier(dArr));
    }

    public Classifier getClassifier(Classifier classifier) {
        ThresholdSelector thresholdSelector = new ThresholdSelector();
        thresholdSelector.setClassifier(classifier);
        return thresholdSelector;
    }

    protected FastVector useClassifier() throws Exception {
        SelectedTag selectedTag;
        Classifier classifier = null;
        int numInstances = this.m_Instances.numInstances();
        int i = numInstances / 2;
        Instances instances = null;
        Instances instances2 = null;
        try {
            instances = new Instances(this.m_Instances, 0, i);
            instances2 = new Instances(this.m_Instances, i, numInstances - i);
            classifier = this.m_Classifier;
        } catch (Exception e) {
            e.printStackTrace();
            fail("Problem setting up to use classifier: " + e);
        }
        int i2 = 0;
        while (true) {
            try {
                return this.m_Evaluation.getTrainTestPredictions(classifier, instances, instances2);
            } catch (IllegalArgumentException e2) {
                if (e2.getMessage().indexOf("Not enough instances") == -1) {
                    throw e2;
                }
                System.err.println("\nInflating training data.");
                Instances instances3 = new Instances(instances);
                for (int i3 = 0; i3 < instances.numInstances(); i3++) {
                    instances3.add(instances.instance(i3));
                }
                instances = instances3;
            } catch (NoSupportForMissingValuesException e3) {
                System.err.println("\nReplacing missing values.");
                ReplaceMissingValues replaceMissingValues = new ReplaceMissingValues();
                replaceMissingValues.setInputFormat(instances);
                instances = Filter.useFilter(instances, replaceMissingValues);
                replaceMissingValues.batchFinished();
                instances2 = Filter.useFilter(instances2, replaceMissingValues);
            } catch (UnsupportedAttributeTypeException e4) {
                boolean z = false;
                String message = e4.getMessage();
                if (message.indexOf("string") != -1 && message.indexOf("attributes") != -1) {
                    System.err.println("\nDeleting string attributes.");
                    selectedTag = new SelectedTag(2, RemoveType.TAGS_ATTRIBUTETYPE);
                } else if (message.indexOf("only") == -1 || message.indexOf("nominal") == -1) {
                    if (message.indexOf("only") == -1 || message.indexOf("numeric") == -1) {
                        throw e4;
                    }
                    System.err.println("\nDeleting non-numeric attributes.");
                    selectedTag = new SelectedTag(0, RemoveType.TAGS_ATTRIBUTETYPE);
                    z = true;
                } else {
                    System.err.println("\nDeleting non-nominal attributes.");
                    selectedTag = new SelectedTag(1, RemoveType.TAGS_ATTRIBUTETYPE);
                    z = true;
                }
                RemoveType removeType = new RemoveType();
                removeType.setAttributeType(selectedTag);
                removeType.setInvertSelection(z);
                removeType.setInputFormat(instances);
                instances = Filter.useFilter(instances, removeType);
                removeType.batchFinished();
                instances2 = Filter.useFilter(instances2, removeType);
                i2++;
                if (i2 > 2) {
                    throw e4;
                }
            }
        }
        throw e4;
    }

    public void testRangeNone() throws Exception {
        this.m_Classifier.setDesignatedClass(new SelectedTag(0, ThresholdSelector.TAGS_OPTIMIZE));
        this.m_Classifier.setRangeCorrection(new SelectedTag(0, ThresholdSelector.TAGS_RANGE));
        this.m_Instances.setClassIndex(1);
        FastVector useClassifier = useClassifier();
        assertTrue(useClassifier.size() != 0);
        double d = 0.0d;
        double d2 = 0.0d;
        for (int i = 0; i < useClassifier.size(); i++) {
            double d3 = ((NominalPrediction) useClassifier.elementAt(i)).distribution()[0];
            if (i == 0 || d3 < d) {
                d = d3;
            }
            if (i == 0 || d3 > d2) {
                d2 = d3;
            }
        }
        assertTrue("Upper limit shouldn't increase", d2 <= 1.0d);
        assertTrue("Lower limit shouldn'd decrease", d >= 0.25d);
    }

    public void testDesignatedClass() throws Exception {
        for (int i = 0; i < ThresholdSelector.TAGS_OPTIMIZE.length; i++) {
            this.m_Classifier.setDesignatedClass(new SelectedTag(ThresholdSelector.TAGS_OPTIMIZE[i].getID(), ThresholdSelector.TAGS_OPTIMIZE));
            this.m_Instances.setClassIndex(1);
            assertTrue(useClassifier().size() != 0);
        }
    }

    public void testEvaluationMode() throws Exception {
        for (int i = 0; i < ThresholdSelector.TAGS_EVAL.length; i++) {
            this.m_Classifier.setEvaluationMode(new SelectedTag(ThresholdSelector.TAGS_EVAL[i].getID(), ThresholdSelector.TAGS_EVAL));
            this.m_Instances.setClassIndex(1);
            assertTrue(useClassifier().size() != 0);
        }
    }

    public void testNumXValFolds() throws Exception {
        try {
            this.m_Classifier.setNumXValFolds(0);
            fail("Expected IllegalArgumentException");
        } catch (IllegalArgumentException e) {
        }
        for (int i = 2; i < 20; i += 2) {
            this.m_Classifier.setNumXValFolds(i);
            this.m_Instances.setClassIndex(1);
            assertTrue(useClassifier().size() != 0);
        }
    }

    public static Test suite() {
        return new TestSuite(ThresholdSelectorTest.class);
    }

    public static void main(String[] strArr) {
        TestRunner.run(suite());
    }
}
