package tech.tablesaw.examples;

import it.unimi.dsi.fastutil.ints.IntIterator;
import tech.tablesaw.api.ColumnType;
import tech.tablesaw.api.FloatColumn;
import tech.tablesaw.api.NumericColumn;
import tech.tablesaw.api.Table;
import tech.tablesaw.api.ml.classification.ConfusionMatrix;
import tech.tablesaw.api.ml.classification.LogisticRegression;
import tech.tablesaw.columns.Column;
import tech.tablesaw.io.csv.CsvReader;

/* loaded from: input_file:tech/tablesaw/examples/SfCrimeTest.class */
/**
 * Example: trains a logistic-regression classifier on the San Francisco crime data set, reports
 * its accuracy on a held-out split of the training data, and writes {@code newSubmission.csv}.
 *
 * <p>The data directory must contain {@code train.csv}, {@code test.csv} and
 * {@code sampleSubmission.csv}. It defaults to {@link #DEFAULT_DATA_DIR}; pass another directory
 * as the first command-line argument to run elsewhere.
 */
public class SfCrimeTest {

    /** Data directory used when none is supplied on the command line. */
    private static final String DEFAULT_DATA_DIR = "/Users/larrywhite/IdeaProjects/testdata/bigdata";

    /**
     * Number of crime categories. The submission file has an integer Id column followed by one
     * float column per category, and the model's posterior array has one slot per category.
     */
    private static final int CATEGORY_COUNT = 39;

    /** Numeric feature columns fed to the model, in the order the model is trained with. */
    private static final String[] PREDICTOR_NAMES = {
        "X", "Y", "MinuteOfDay", "DayOfYear", "DayOfWeek", "Year", "Precinct"
    };

    /**
     * Runs the example end to end.
     *
     * @param args optional; {@code args[0]} overrides the data directory
     * @throws Exception if any CSV file cannot be read or written
     */
    public static void main(String[] args) throws Exception {
        String dataDir = args.length > 0 ? args[0] : DEFAULT_DATA_DIR;

        Table train = Table.createFromCsv(dataDir + "/train.csv");
        out(train.shape());
        // Printed before feature engineering so the raw CSV structure is shown.
        out(train.structure().print());
        addFeatures(train);
        out(train.categoryColumn("Category").summary().sortDescendingOn("Count").print());

        // split[0] is used for training, split[1] for evaluating the trained model.
        Table[] split = train.sampleSplit(0.1d);
        Table trainingSample = split[0];
        Table holdout = split[1];
        out(CsvReader.printColumnTypes(dataDir + "/sampleSubmission.csv", true, ','));

        LogisticRegression model = LogisticRegression.learn(
                trainingSample.categoryColumn("Category"), 0.1d, 0.001d, 700, predictors(trainingSample));
        out("Model trained");

        ConfusionMatrix matrix = model.predictMatrix(holdout.categoryColumn("Category"), predictors(holdout));
        out(matrix.accuracy());
        out(matrix.toTable().print());

        writeSubmission(model, dataDir);
    }

    /**
     * Runs the model over every row of {@code test.csv} and exports {@code newSubmission.csv}
     * (relative to the working directory), using {@code sampleSubmission.csv} as the template.
     */
    private static void writeSubmission(LogisticRegression model, String dataDir) throws Exception {
        Table submission = Table.createFromCsv(submissionColumnTypes(), dataDir + "/sampleSubmission.csv");
        FloatColumn larcenyTheft = submission.floatColumn("LARCENY/THEFT");
        FloatColumn warrants = submission.floatColumn("WARRANTS");

        Table test = testData(dataDir);
        // Predictors must come from the table whose row numbers we iterate (previously the
        // held-out training split was used here, mismatching row indices). Hoisted: loop-invariant.
        NumericColumn[] testPredictors = predictors(test);
        IntIterator rows = test.iterator();
        while (rows.hasNext()) {
            int row = rows.nextInt();
            double[] posterior = new double[CATEGORY_COUNT];
            model.predictFromModel(row, posterior, testPredictors);
            // TODO(review): the posterior is currently discarded and the submission gets fixed
            // placeholder values. Presumably each category column should receive its posterior
            // probability; confirm the model's category ordering before wiring that up.
            larcenyTheft.set(row, 1.0f);
            warrants.set(row, 0.0f);
        }
        submission.exportToCsv("newSubmission.csv");
    }

    /** Column types of the submission file: an INTEGER Id followed by one FLOAT per category. */
    private static ColumnType[] submissionColumnTypes() {
        ColumnType[] types = new ColumnType[CATEGORY_COUNT + 1];
        types[0] = ColumnType.INTEGER;
        for (int i = 1; i < types.length; i++) {
            types[i] = ColumnType.FLOAT;
        }
        return types;
    }

    /** Loads {@code test.csv} and derives the same feature columns as the training data. */
    private static Table testData(String dataDir) throws Exception {
        Table test = Table.createFromCsv(dataDir + "/test.csv");
        addFeatures(test);
        return test;
    }

    /**
     * Adds the derived numeric feature columns to {@code table} in place. The textual DayOfWeek
     * column is replaced by a numeric one derived from the Dates column.
     */
    private static void addFeatures(Table table) {
        table.removeColumns("DayOfWeek");
        // NOTE(review): codes come from the category column's own dictionary, so train and test
        // may assign different integers to the same district — confirm before trusting Precinct.
        addNamed(table, table.categoryColumn("PdDistrict").toIntColumn(), "Precinct");
        addNamed(table, table.dateTimeColumn("Dates").year(), "Year");
        addNamed(table, table.dateTimeColumn("Dates").minuteOfDay(), "MinuteOfDay");
        addNamed(table, table.dateTimeColumn("Dates").dayOfYear(), "DayOfYear");
        addNamed(table, table.dateTimeColumn("Dates").dayOfWeekValue(), "DayOfWeek");
    }

    /** Renames {@code column} to {@code name} and appends it to {@code table}. */
    private static void addNamed(Table table, Column column, String name) {
        column.setName(name);
        table.addColumn(column);
    }

    /** Returns the model's predictor columns from {@code table}, in {@link #PREDICTOR_NAMES} order. */
    private static NumericColumn[] predictors(Table table) {
        NumericColumn[] columns = new NumericColumn[PREDICTOR_NAMES.length];
        for (int i = 0; i < columns.length; i++) {
            columns[i] = table.nCol(PREDICTOR_NAMES[i]);
        }
        return columns;
    }

    /** Prints {@code obj} to standard output. */
    private static void out(Object obj) {
        System.out.println(String.valueOf(obj));
    }
}
