package weka.classifiers.misc;

import junit.framework.Test;
import junit.framework.TestSuite;
import junit.textui.TestRunner;
import weka.classifiers.AbstractClassifierTest;
import weka.classifiers.Classifier;
import weka.classifiers.functions.LinearRegression;
import weka.classifiers.trees.J48;
import weka.core.Instances;
import weka.core.TestInstances;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.Reorder;
import weka.filters.unsupervised.attribute.SwapValues;

/* loaded from: input_file:weka/classifiers/misc/InputMappedClassifierTest.class */
/**
 * Tests {@link InputMappedClassifier}.
 *
 * <p>Besides the generic checks inherited from {@link AbstractClassifierTest}, each test
 * trains an {@code InputMappedClassifier} on generated data and then classifies both the
 * training data and a structurally altered copy of it (attributes reordered and/or nominal
 * labels swapped). The predictions for both data sets must be identical.
 */
public class InputMappedClassifierTest extends AbstractClassifierTest {

    /** Number of numeric attributes in the data generated by {@link #performTest}. */
    private static final int NUM_NUMERIC = 3;

    /** Number of instances in every generated data set. */
    private static final int NUM_INSTANCES = 100;

    /**
     * Creates the test case.
     *
     * @param name the name of the test
     */
    public InputMappedClassifierTest(String name) {
        super(name);
    }

    /**
     * Creates the classifier used by the inherited generic tests.
     *
     * @return an {@code InputMappedClassifier} wrapping {@link J48}, with the mapping report
     *     suppressed
     */
    @Override // weka.classifiers.AbstractClassifierTest
    public Classifier getClassifier() {
        InputMappedClassifier mapped = new InputMappedClassifier();
        mapped.setClassifier(new J48());
        mapped.setSuppressMappingReport(true);
        return mapped;
    }

    /**
     * Returns a copy of the data with its attributes in reverse order, i.e. the
     * {@link Reorder} filter is applied with the index list {@code "last,n-1,...,1"}.
     *
     * @param instances the data to reorder
     * @return the filtered data
     * @throws Exception if configuring or applying the filter fails
     */
    protected Instances reorderAtts(Instances instances) throws Exception {
        StringBuilder indices = new StringBuilder("last");
        for (int index = instances.numAttributes() - 1; index > 0; index--) {
            indices.append(',').append(index);
        }
        Reorder reorder = new Reorder();
        reorder.setAttributeIndices(indices.toString());
        reorder.setInputFormat(instances);
        return Filter.useFilter(instances, reorder);
    }

    /**
     * Swaps the first and the last label of one attribute using the {@link SwapValues}
     * filter.
     *
     * @param i the index of the attribute, passed to the filter as its decimal string
     *     (the callers in this class use 1-based indices)
     * @param instances the data to filter
     * @return the filtered data
     * @throws Exception if configuring or applying the filter fails
     */
    protected Instances swapValues(int i, Instances instances) throws Exception {
        SwapValues swap = new SwapValues();
        swap.setAttributeIndex(Integer.toString(i));
        swap.setFirstValueIndex("first");
        swap.setSecondValueIndex("last");
        swap.setInputFormat(instances);
        return Filter.useFilter(instances, swap);
    }

    /**
     * Generates a data set of {@value #NUM_INSTANCES} instances made up of nominal and
     * numeric attributes only, with the class as the last attribute.
     *
     * @param nominalClass {@code true} for a nominal class, {@code false} for a numeric one
     * @param numClasses the number of class labels; only used for a nominal class
     * @param numNominal the number of nominal attributes
     * @param numNumeric the number of numeric attributes
     * @return the generated data
     * @throws Exception if the generator fails
     */
    protected Instances generateData(boolean nominalClass, int numClasses, int numNominal,
            int numNumeric) throws Exception {
        TestInstances generator = new TestInstances();
        if (nominalClass) {
            // 1 is the attribute type code of weka.core.Attribute.NOMINAL.
            generator.setClassType(1);
            generator.setNumClasses(numClasses);
        } else {
            // 0 is the attribute type code of weka.core.Attribute.NUMERIC.
            generator.setClassType(0);
        }
        generator.setNumNominal(numNominal);
        generator.setNumNumeric(numNumeric);
        generator.setNumDate(0);
        generator.setNumString(0);
        generator.setNumRelational(0);
        generator.setNumInstances(NUM_INSTANCES);
        generator.setClassIndex(TestInstances.CLASS_IS_LAST);
        return generator.generate();
    }

    /**
     * Trains on generated data and checks that the predictions for the training data equal
     * the predictions for an altered copy of that data.
     *
     * @param nominalClass whether the class attribute is nominal (otherwise numeric)
     * @param numClasses the number of class labels; only used for a nominal class
     * @param numNominal the number of nominal attributes to generate
     * @param reorderAttributes whether to reverse the attribute order of the test copy
     * @param swapNominalValues whether to swap the first and last label of attribute 1 of
     *     the test copy (expected to be a nominal attribute of the generated data)
     * @param swapClassValues whether to swap the first and last class label of the test
     *     copy; ignored for a numeric class
     */
    protected void performTest(boolean nominalClass, int numClasses, int numNominal,
            boolean reorderAttributes, boolean swapNominalValues, boolean swapClassValues) {
        // The "return" statements after fail() are never reached at runtime (fail() always
        // throws); they only let the compiler see that the locals are definitely assigned.
        Instances train;
        try {
            train = generateData(nominalClass, numClasses, numNominal, NUM_NUMERIC);
        } catch (Exception e) {
            fail("Generating training data failed: " + e);
            return;
        }

        Instances test = new Instances(train);
        if (swapNominalValues) {
            try {
                test = swapValues(1, test);
            } catch (Exception e) {
                fail("Reordering nominal labels failed: " + e);
                return;
            }
        }
        if (swapClassValues && nominalClass) {
            try {
                // Derive the 1-based class position from the data instead of hard-coding it,
                // so the swap hits the class attribute for any number of attributes. This
                // must happen before the attributes are reordered below.
                test = swapValues(test.classIndex() + 1, test);
            } catch (Exception e) {
                fail("Reordering class labels failed: " + e);
                return;
            }
        }
        if (reorderAttributes) {
            try {
                test = reorderAtts(test);
            } catch (Exception e) {
                fail("Reordering test data failed: " + e);
                return;
            }
        }

        InputMappedClassifier mapped;
        try {
            mapped = trainClassifier(train, nominalClass);
        } catch (Exception e) {
            fail("Training classifier failed: " + e);
            return;
        }

        double[] trainPredictions;
        try {
            trainPredictions = testClassifier(train, mapped);
        } catch (Exception e) {
            fail("Testing classifier on training data failed: " + e);
            return;
        }

        double[] testPredictions;
        try {
            testPredictions = testClassifier(test, mapped);
        } catch (Exception e) {
            fail("Testing classifier on test data failed: " + e);
            return;
        }

        if (trainPredictions.length != testPredictions.length) {
            fail("Comparing results failed: number of predictions differs ("
                + trainPredictions.length + " vs. " + testPredictions.length + ")");
        }
        for (int idx = 0; idx < trainPredictions.length; idx++) {
            if (trainPredictions[idx] != testPredictions[idx]) {
                fail("Comparing results failed: Result #" + (idx + 1) + " differs! ("
                    + trainPredictions[idx] + " vs. " + testPredictions[idx] + ")");
            }
        }
    }

    /** Nominal class, test data identical to the training data. */
    public void testNominaClass() {
        performTest(true, 4, 3, false, false, false);
    }

    /** Nominal class, attributes of the test data reordered. */
    public void testNominaClassReorderedAtts() {
        performTest(true, 4, 3, true, false, false);
    }

    /** Nominal class, labels of a nominal attribute swapped in the test data. */
    public void testNominalClassSwapNominalValues() {
        performTest(true, 4, 3, false, true, false);
    }

    /** Nominal class, nominal labels swapped and attributes reordered in the test data. */
    public void testNominalClassSwapNominalValuesReorderAtts() {
        performTest(true, 4, 3, true, true, false);
    }

    /** Nominal class, class labels swapped in the test data. */
    public void testNominalClassSwapClassValues() {
        performTest(true, 4, 3, false, false, true);
    }

    /** Nominal class, nominal labels and class labels swapped in the test data. */
    public void testNominalClassSwapNominalValuesSwapClassValues() {
        performTest(true, 4, 3, false, true, true);
    }

    /**
     * Nominal class; nominal labels and class labels swapped and attributes reordered in
     * the test data.
     */
    public void testNominalClassSwapNominalValuesSwapClassValuesReorderAtts() {
        performTest(true, 4, 3, true, true, true);
    }

    /** Numeric class, test data identical to the training data. */
    public void testNumericClass() {
        performTest(false, 4, 3, false, false, false);
    }

    /** Numeric class, attributes of the test data reordered. */
    public void testNumericClassReorderedAtts() {
        performTest(false, 4, 3, true, false, false);
    }

    /** Numeric class, labels of a nominal attribute swapped in the test data. */
    public void testNumericClassSwapNominalValues() {
        performTest(false, 4, 3, false, true, false);
    }

    /** Numeric class, nominal labels swapped and attributes reordered in the test data. */
    public void testNumericClassSwapNominalValuesReorderAtts() {
        performTest(false, 4, 3, true, true, false);
    }

    /**
     * Builds an {@code InputMappedClassifier} on the given data; any exception raised while
     * building fails the test.
     *
     * @param instances the training data
     * @param nominalClass {@code true} to wrap {@link J48}, {@code false} to wrap
     *     {@link LinearRegression}
     * @return the trained classifier
     */
    protected InputMappedClassifier trainClassifier(Instances instances, boolean nominalClass) {
        InputMappedClassifier mapped = new InputMappedClassifier();
        mapped.setClassifier(nominalClass ? new J48() : new LinearRegression());
        mapped.setSuppressMappingReport(true);
        try {
            mapped.buildClassifier(instances);
        } catch (Exception e) {
            fail("Training InputMappedClassifier failed: " + e);
            return null; // not reached at runtime: fail() always throws
        }
        return mapped;
    }

    /**
     * Classifies every instance of the given data; any exception raised while classifying
     * fails the test.
     *
     * @param instances the data to classify
     * @param inputMappedClassifier the trained classifier
     * @return one prediction per instance, in data order
     */
    protected double[] testClassifier(Instances instances,
            InputMappedClassifier inputMappedClassifier) {
        int count = instances.numInstances();
        double[] predictions = new double[count];
        for (int idx = 0; idx < count; idx++) {
            try {
                predictions[idx] = inputMappedClassifier.classifyInstance(instances.instance(idx));
            } catch (Exception e) {
                fail("Testing InputMappedClassifier failed: " + e);
                return null; // not reached at runtime: fail() always throws
            }
        }
        return predictions;
    }

    /**
     * Returns a suite containing all tests of this class.
     *
     * @return the test suite
     */
    public static Test suite() {
        return new TestSuite(InputMappedClassifierTest.class);
    }

    /**
     * Runs the suite with the JUnit text runner.
     *
     * @param args ignored
     */
    public static void main(String[] args) {
        TestRunner.run(suite());
    }
}
