package ml.dmlc.xgboost4j.java.spark.rapids;

import ai.rapids.cudf.ColumnVector;
import ai.rapids.cudf.DType;
import ai.rapids.cudf.Table;
import java.util.List;
import org.apache.spark.sql.types.BooleanType;
import org.apache.spark.sql.types.ByteType;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DateType;
import org.apache.spark.sql.types.DoubleType;
import org.apache.spark.sql.types.FloatType;
import org.apache.spark.sql.types.IntegerType;
import org.apache.spark.sql.types.LongType;
import org.apache.spark.sql.types.ShortType;
import org.apache.spark.sql.types.StringType;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.sql.types.TimestampType;

/* loaded from: input_file:ml/dmlc/xgboost4j/java/spark/rapids/GpuColumnBatch.class */
/**
 * Pairs a cudf {@link Table} with the Spark {@link StructType} describing its columns and
 * offers helpers to read column values and to aggregate ranking group sizes/weights.
 *
 * <p>This class does not close the wrapped table; it only reads from it.
 */
public class GpuColumnBatch {
    private final Table table;
    private final StructType schema;

    /**
     * @param table      the cudf table holding the column data
     * @param structType the Spark schema; field {@code n} is used to interpret column {@code n}
     */
    public GpuColumnBatch(Table table, StructType structType) {
        this.table = table;
        this.schema = structType;
    }

    /** @return the Spark schema supplied at construction time */
    public StructType getSchema() {
        return this.schema;
    }

    /** @return the row count reported by the wrapped table */
    public long getNumRows() {
        return this.table.getRowCount();
    }

    /** @return the number of columns reported by the wrapped table */
    public int getNumColumns() {
        return this.table.getNumberOfColumns();
    }

    /**
     * @param index column position in the table
     * @return the column vector at {@code index}, exactly as returned by the table
     */
    public ColumnVector getColumnVector(int index) {
        return this.table.getColumn(index);
    }

    /**
     * @param index column position in the table
     * @return the native cudf column address of the column at {@code index}
     */
    public long getColumn(int index) {
        return this.table.getColumn(index).getNativeCudfColumnAddress();
    }

    /**
     * Fetches a column and calls {@code ensureOnHost()} on it before returning it, so that the
     * per-row getters can be used on the result.
     *
     * @param index column position in the table
     * @return the column vector at {@code index}
     */
    public ColumnVector getColumnVectorInitHost(int index) {
        ColumnVector column = this.table.getColumn(index);
        column.ensureOnHost();
        return column;
    }

    /**
     * Reads one row of a numeric column and widens it to {@code double}.
     *
     * <p>Supported Spark types: float, int, byte, short, double and long. Long values are
     * converted to {@code double}, so magnitudes above 2^53 may lose precision.
     *
     * @param rowIndex row to read
     * @param column   column to read from (the caller must already have made it host-readable)
     * @param field    schema field describing {@code column}
     * @return the value at {@code rowIndex} as a double
     * @throws IllegalArgumentException if the field's data type is not one of the supported ones
     */
    private double getNumericValueInColumn(int rowIndex, ColumnVector column, StructField field) {
        DataType dataType = field.dataType();
        if (dataType instanceof FloatType) {
            return column.getFloat(rowIndex);
        }
        if (dataType instanceof IntegerType) {
            return column.getInt(rowIndex);
        }
        if (dataType instanceof ByteType) {
            return column.getByte(rowIndex);
        }
        if (dataType instanceof ShortType) {
            return column.getShort(rowIndex);
        }
        if (dataType instanceof DoubleType) {
            return column.getDouble(rowIndex);
        }
        if (dataType instanceof LongType) {
            return column.getLong(rowIndex);
        }
        throw new IllegalArgumentException("Not a numeric type in column: " + field.name());
    }

    /**
     * Reads one numeric value from the column at {@code colIndex}, or returns {@code defVal}
     * when that column has no rows.
     */
    private double getNumericValueInColumn(int rowIndex, int colIndex, double defVal) {
        ColumnVector column = getColumnVector(colIndex);
        column.ensureOnHost();
        if (column.getRowCount() > 0) {
            return getNumericValueInColumn(rowIndex, column, getSchema().apply(colIndex));
        }
        return defVal;
    }

    /**
     * Reads a value from a numeric column and truncates it to {@code int}.
     *
     * @param dataIndex row to read
     * @param colIndex  column position in the table
     * @param defVal    value returned when the column has no rows
     * @return the value at ({@code dataIndex}, {@code colIndex}) cast to int, or {@code defVal}
     * @throws IllegalArgumentException if the column's schema type is not numeric
     */
    public int getIntInColumn(int dataIndex, int colIndex, int defVal) {
        return (int) getNumericValueInColumn(dataIndex, colIndex, defVal);
    }

    /**
     * Scans the group column row by row, collapsing each run of equal consecutive group ids
     * into one size entry in {@code groupInfo} (and, when a weight column is given, one weight
     * entry in {@code weightInfo}).
     *
     * <p>The scan is seeded with {@code prevTailGid} and the last entry of {@code groupInfo},
     * so when the first rows of this batch carry the same id as {@code prevTailGid} their count
     * is added to that last entry instead of creating a new one. Only this first run can be
     * merged; any later run is always appended, even if its id happens to equal
     * {@code prevTailGid}.
     *
     * @param groupIdx    index of the group-id column
     * @param weightIdx   index of the weight column, or a negative value if there is none
     * @param prevTailGid group id the scan starts from (the id the previous call returned)
     * @param groupInfo   accumulated group sizes; updated in place
     * @param weightInfo  accumulated per-group weights; updated in place when weights are used
     * @return the group id of the last run seen ({@code prevTailGid} if the column has no rows)
     * @throws IllegalArgumentException if rows within one run carry different weights, or if a
     *                                  column involved is not numeric
     */
    public int groupAndAggregateOnColumnsHost(int groupIdx, int weightIdx, int prevTailGid,
            List<Integer> groupInfo, List<Float> weightInfo) {
        boolean hasWeight = weightIdx >= 0;
        ColumnVector weightColumn = null;
        StructField weightField = null;
        Float currentWeight = null;
        if (hasWeight) {
            weightColumn = getColumnVectorInitHost(weightIdx);
            // Looked up once here rather than once per row inside the loop.
            weightField = getSchema().apply(weightIdx);
            if (!weightInfo.isEmpty()) {
                currentWeight = weightInfo.get(weightInfo.size() - 1);
            } else if (weightColumn.getRowCount() > 0) {
                currentWeight = (float) getNumericValueInColumn(0, weightColumn, weightField);
            }
        }

        ColumnVector groupColumn = getColumnVectorInitHost(groupIdx);
        StructField groupField = getSchema().apply(groupIdx);
        long rowCount = groupColumn.getRowCount();

        int currentGid = prevTailGid;
        int currentSize = groupInfo.isEmpty() ? 0 : groupInfo.get(groupInfo.size() - 1);
        // True only while scanning the first run, i.e. the only run that may extend the
        // last entry already stored in groupInfo.
        boolean continuesPrevTail = true;

        for (int row = 0; row < rowCount; row++) {
            Float weight = hasWeight
                    ? (float) getNumericValueInColumn(row, weightColumn, weightField)
                    : 0.0f;
            int gid = (int) getNumericValueInColumn(row, groupColumn, groupField);
            if (gid == currentGid) {
                currentSize++;
                if (hasWeight && !weight.equals(currentWeight)) {
                    throw new IllegalArgumentException("The instances in the same group have to be assigned with the same weight. Unexpected weight: " + weight);
                }
            } else {
                // A new run starts: flush the finished one, then restart the counters.
                addOrUpdateInfos(continuesPrevTail, currentSize, currentWeight, hasWeight,
                        groupInfo, weightInfo);
                continuesPrevTail = false;
                if (hasWeight) {
                    currentWeight = weight;
                }
                currentGid = gid;
                currentSize = 1;
            }
        }
        // Flush the last (possibly still first) run.
        addOrUpdateInfos(continuesPrevTail, currentSize, currentWeight, hasWeight,
                groupInfo, weightInfo);
        return currentGid;
    }

    /**
     * Records one finished run of rows.
     *
     * <p>Empty runs are ignored. If the run continues the previous tail and {@code groupInfo}
     * already has an entry, that entry is overwritten with {@code groupSize} (which already
     * includes the previously stored count). Otherwise the size is appended, together with the
     * weight when weights are in use and one is known.
     */
    private static void addOrUpdateInfos(boolean continuesPrevTail, int groupSize, Float weight,
            boolean hasWeight, List<Integer> groupInfo, List<Float> weightInfo) {
        if (groupSize <= 0) {
            return;
        }
        if (continuesPrevTail && !groupInfo.isEmpty()) {
            groupInfo.set(groupInfo.size() - 1, Integer.valueOf(groupSize));
            return;
        }
        groupInfo.add(Integer.valueOf(groupSize));
        if (hasWeight && weight != null) {
            weightInfo.add(weight);
        }
    }

    /**
     * Maps a Spark SQL data type to the corresponding cudf {@link DType}.
     *
     * @param dataType the Spark data type
     * @return the matching cudf type
     * @throws IllegalArgumentException if there is no mapping for {@code dataType}
     */
    public static DType getRapidsType(DataType dataType) {
        DType rapidsType = toRapidsOrNull(dataType);
        if (rapidsType != null) {
            return rapidsType;
        }
        throw new IllegalArgumentException(dataType + " is not supported for GPU processing yet.");
    }

    /** Same mapping as {@link #getRapidsType(DataType)}, but yields {@code null} when unmapped. */
    private static DType toRapidsOrNull(DataType dataType) {
        if (dataType instanceof LongType) {
            return DType.INT64;
        } else if (dataType instanceof DoubleType) {
            return DType.FLOAT64;
        } else if (dataType instanceof ByteType) {
            return DType.INT8;
        } else if (dataType instanceof BooleanType) {
            return DType.BOOL8;
        } else if (dataType instanceof ShortType) {
            return DType.INT16;
        } else if (dataType instanceof IntegerType) {
            return DType.INT32;
        } else if (dataType instanceof FloatType) {
            return DType.FLOAT32;
        } else if (dataType instanceof DateType) {
            return DType.DATE32;
        } else if (dataType instanceof TimestampType) {
            return DType.TIMESTAMP;
        } else if (dataType instanceof StringType) {
            return DType.STRING;
        }
        return null;
    }
}
