package tech.tablesaw.api.ml.classification;

import it.unimi.dsi.fastutil.ints.IntIterator;
import java.util.Collection;
import java.util.TreeSet;
import org.junit.Assert;
import org.junit.Test;
import smile.classification.KNN;
import tech.tablesaw.api.NumericColumn;
import tech.tablesaw.api.QueryHelper;
import tech.tablesaw.api.Table;
import tech.tablesaw.columns.Column;
import tech.tablesaw.util.DoubleArrays;

/* loaded from: input_file:tech/tablesaw/api/ml/classification/ConfusionMatrixTest.class */
/**
 * Smoke tests for {@link StandardConfusionMatrix}.
 *
 * <p>Each test splits {@code KNN_Example_1.csv} in half. It trains a classifier on the first half
 * using the {@code X}/{@code Y} columns, predicts every row of the second half, and tallies
 * (actual, predicted) pairs into a confusion matrix.
 */
public class ConfusionMatrixTest {

    /**
     * Sample data shared by both tests. The tests read columns 0 and 1 as float features and
     * column 2 as short class labels.
     */
    private static final String DATA_FILE = "../data/KNN_Example_1.csv";

    @Test
    public void testAsTable() throws Exception {
        Table[] split = Table.read().csv(DATA_FILE).sampleSplit(0.5d);
        Table train = split[0];
        Table test = split[1];

        // 2-nearest-neighbour model over the (X, Y) features, labelled by column 2.
        KNN<double[]> knn = KNN.learn(
                DoubleArrays.to2dArray(train.nCol("X"), train.nCol("Y")),
                train.shortColumn(2).toIntArray(),
                2);

        int[] predictions = new int[test.rowCount()];
        // Raw TreeSet kept deliberately: the generic parameter expected by the
        // StandardConfusionMatrix constructor is not visible from this file.
        StandardConfusionMatrix matrix =
                new StandardConfusionMatrix(new TreeSet((Collection) train.shortColumn(2).asSet()));

        IntIterator rows = test.iterator();
        while (rows.hasNext()) {
            // nextInt() reads the primitive row index directly instead of boxing via next().
            int row = rows.nextInt();
            double[] features = {test.floatColumn(0).getFloat(row), test.floatColumn(1).getFloat(row)};
            predictions[row] = knn.predict(features);
            matrix.increment(Integer.valueOf(test.shortColumn(2).get(row)), Integer.valueOf(predictions[row]));
        }
        Assert.assertNotNull(matrix);
    }

    @Test
    public void testWithBooleanColumn() throws Exception {
        Table data = Table.read().csv(DATA_FILE);
        // Adds column 3, "bt": true where Label == 1.
        data.addColumn(new Column[]{data.selectIntoColumn("bt", QueryHelper.column("Label").isEqualTo(1))});

        Table[] split = data.sampleSplit(0.5d);
        Table train = split[0];
        Table test = split[1];

        LogisticRegression model = LogisticRegression.learn(
                train.booleanColumn(3),
                new NumericColumn[]{train.nCol("X"), train.nCol("Y")});

        int[] predictions = new int[test.rowCount()];
        // Raw TreeSet kept deliberately: the generic parameter expected by the
        // StandardConfusionMatrix constructor is not visible from this file.
        StandardConfusionMatrix matrix =
                new StandardConfusionMatrix(new TreeSet((Collection) train.shortColumn(2).asSet()));

        IntIterator rows = test.iterator();
        while (rows.hasNext()) {
            // nextInt() reads the primitive row index directly instead of boxing via next().
            int row = rows.nextInt();
            double[] features = {test.floatColumn(0).getFloat(row), test.floatColumn(1).getFloat(row)};
            predictions[row] = model.predict(features);
            // NOTE(review): the model is trained on the boolean "bt" column, but the "actual" value
            // recorded here is the raw Label (column 2). These only agree if Label takes just the
            // values 0 and 1 — confirm against the data file.
            matrix.increment(Integer.valueOf(test.shortColumn(2).get(row)), Integer.valueOf(predictions[row]));
        }
        Assert.assertNotNull(matrix);
    }
}
