package ml.dmlc.xgboost4j.java.spark.rapids;

import ai.rapids.cudf.DType;
import ai.rapids.cudf.HostColumnVector;
import ai.rapids.cudf.Scalar;
import ai.rapids.cudf.Schema;
import ai.rapids.cudf.Table;
import java.util.List;
import org.apache.spark.sql.catalyst.expressions.Attribute;
import org.apache.spark.sql.execution.vectorized.WritableColumnVector;
import org.apache.spark.sql.types.BooleanType;
import org.apache.spark.sql.types.ByteType;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.DateType;
import org.apache.spark.sql.types.Decimal;
import org.apache.spark.sql.types.DoubleType;
import org.apache.spark.sql.types.FloatType;
import org.apache.spark.sql.types.IntegerType;
import org.apache.spark.sql.types.LongType;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.ShortType;
import org.apache.spark.sql.types.StringType;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.sql.types.TimestampType;
import org.apache.spark.sql.vectorized.ColumnVector;
import org.apache.spark.sql.vectorized.ColumnarArray;
import org.apache.spark.sql.vectorized.ColumnarBatch;
import org.apache.spark.sql.vectorized.ColumnarMap;
import org.apache.spark.unsafe.types.UTF8String;

/* loaded from: input_file:ml/dmlc/xgboost4j/java/spark/rapids/GpuColumnVector.class */
public class GpuColumnVector extends ColumnVector {
    /** The cuDF column wrapped by this vector; {@link #close()} forwards to its {@code close()}. */
    private final ai.rapids.cudf.ColumnVector cudfCv;
    /** Message of the exception thrown by every per-row accessor of this class. */
    private static final String BAD_ACCESS = "DATA ACCESS MUST BE ON A HOST VECTOR";
    // Decompiler-recovered flag guarding the assertion-style checks in this class.
    // The static initializer sets it to the negation of desiredAssertionStatus().
    static final /* synthetic */ boolean $assertionsDisabled;

    /* JADX INFO: Access modifiers changed from: package-private */
    /* renamed from: ml.dmlc.xgboost4j.java.spark.rapids.GpuColumnVector$1, reason: invalid class name */
    /* loaded from: input_file:ml/dmlc/xgboost4j/java/spark/rapids/GpuColumnVector$1.class */
    // Decompiler rendering of the switch map for the DType enum: an int array indexed by
    // DType.ordinal() whose entries are the case labels (1..10) consumed by getSparkType.
    // Constants that are not assigned here keep the array default of 0 and therefore fall
    // through to the "default" branch of that switch.
    public static /* synthetic */ class AnonymousClass1 {
        // One slot per DType constant known at class-initialisation time.
        static final /* synthetic */ int[] $SwitchMap$ai$rapids$cudf$DType = new int[DType.values().length];

        static {
            // Each assignment is wrapped separately so that a NoSuchFieldError raised for one
            // constant is ignored and the remaining entries are still populated.
            // NOTE(review): presumably this tolerates a DType enum that lacks a constant at
            // run time -- confirm before restructuring.
            try {
                $SwitchMap$ai$rapids$cudf$DType[DType.BOOL8.ordinal()] = 1;
            } catch (NoSuchFieldError e) {
            }
            try {
                $SwitchMap$ai$rapids$cudf$DType[DType.INT8.ordinal()] = 2;
            } catch (NoSuchFieldError e2) {
            }
            try {
                $SwitchMap$ai$rapids$cudf$DType[DType.INT16.ordinal()] = 3;
            } catch (NoSuchFieldError e3) {
            }
            try {
                $SwitchMap$ai$rapids$cudf$DType[DType.INT32.ordinal()] = 4;
            } catch (NoSuchFieldError e4) {
            }
            try {
                $SwitchMap$ai$rapids$cudf$DType[DType.INT64.ordinal()] = 5;
            } catch (NoSuchFieldError e5) {
            }
            try {
                $SwitchMap$ai$rapids$cudf$DType[DType.FLOAT32.ordinal()] = 6;
            } catch (NoSuchFieldError e6) {
            }
            try {
                $SwitchMap$ai$rapids$cudf$DType[DType.FLOAT64.ordinal()] = 7;
            } catch (NoSuchFieldError e7) {
            }
            try {
                $SwitchMap$ai$rapids$cudf$DType[DType.TIMESTAMP_DAYS.ordinal()] = 8;
            } catch (NoSuchFieldError e8) {
            }
            try {
                $SwitchMap$ai$rapids$cudf$DType[DType.TIMESTAMP_MICROSECONDS.ordinal()] = 9;
            } catch (NoSuchFieldError e9) {
            }
            try {
                $SwitchMap$ai$rapids$cudf$DType[DType.STRING.ordinal()] = 10;
            } catch (NoSuchFieldError e10) {
            }
        }
    }

    /* loaded from: input_file:ml/dmlc/xgboost4j/java/spark/rapids/GpuColumnVector$GpuColumnarBatchBuilder.class */
    /**
     * Holds one host-side cuDF builder per schema field and converts the filled builders into a
     * {@link ColumnarBatch} whose columns are {@link GpuColumnVector}s.
     *
     * <p>Instances own their builders until {@link #build(int)} consumes them; call
     * {@link #close()} (or use try-with-resources) to release any that were not consumed.
     */
    public static final class GpuColumnarBatchBuilder implements AutoCloseable {
        private final HostColumnVector.Builder[] builders;
        private final StructField[] fields;

        /**
         * Allocates a builder for every field of the schema.
         *
         * @param structType schema describing the columns to build
         * @param i number of rows each builder is sized for
         * @param columnarBatch optional source batch used only to size string buffers; may be
         *     {@code null}
         * @throws IllegalArgumentException if a field type has no cuDF equivalent
         */
        public GpuColumnarBatchBuilder(StructType structType, int i, ColumnarBatch columnarBatch) {
            this.fields = structType.fields();
            int columnCount = this.fields.length;
            this.builders = new HostColumnVector.Builder[columnCount];
            boolean success = false;
            try {
                for (int col = 0; col < columnCount; col++) {
                    DType rapidsType = GpuColumnVector.getRapidsType(this.fields[col]);
                    if (rapidsType == DType.STRING) {
                        this.builders[col] = HostColumnVector.builder(i, stringBufferSize(i, columnarBatch, col));
                    } else {
                        this.builders[col] = HostColumnVector.builder(rapidsType, i);
                    }
                }
                // Only flag success once every column has a builder, so a failure on any
                // column (not just the first) releases the builders allocated before it.
                success = true;
            } finally {
                if (!success) {
                    close();
                }
            }
        }

        /**
         * Picks the initial character-buffer size for a string column.
         *
         * <p>When the source column is a non-dictionary {@link WritableColumnVector}, the end of
         * its last row (offset + length) is used; otherwise the size falls back to 8 bytes per row.
         */
        private static int stringBufferSize(int rows, ColumnarBatch batch, int col) {
            // rows > 0 keeps the "last row" lookup from using index -1 on an empty batch.
            if (batch != null && rows > 0) {
                ColumnVector source = batch.column(col);
                if (source instanceof WritableColumnVector) {
                    WritableColumnVector writable = (WritableColumnVector) source;
                    if (!writable.hasDictionary()) {
                        int lastRow = rows - 1;
                        return writable.getArrayOffset(lastRow) + writable.getArrayLength(lastRow);
                    }
                }
            }
            return rows * 8;
        }

        /**
         * Returns the builder for the column at the given index.
         *
         * @param i column index
         * @return the builder, or {@code null} once {@link #build(int)} has consumed it
         */
        public HostColumnVector.Builder builder(int i) {
            return this.builders[i];
        }

        /**
         * Builds every column, moves it to the device and wraps the result in a batch.
         *
         * @param i number of rows to report on the resulting batch
         * @return a batch whose columns are {@link GpuColumnVector}s
         */
        public ColumnarBatch build(int i) {
            ColumnVector[] columns = new ColumnVector[this.builders.length];
            boolean success = false;
            try {
                for (int col = 0; col < this.builders.length; col++) {
                    columns[col] = new GpuColumnVector(this.fields[col].dataType(), this.builders[col].buildAndPutOnDevice());
                    // Clear the slot so close() skips builders that were already consumed.
                    this.builders[col] = null;
                }
                ColumnarBatch batch = new ColumnarBatch(columns, i);
                success = true;
                return batch;
            } finally {
                if (!success) {
                    // Release the device columns created before the failure.
                    for (ColumnVector column : columns) {
                        if (column != null) {
                            column.close();
                        }
                    }
                }
            }
        }

        /** Closes every builder that has not been consumed by {@link #build(int)}. */
        @Override // java.lang.AutoCloseable
        public void close() {
            for (HostColumnVector.Builder builder : this.builders) {
                if (builder != null) {
                    builder.close();
                }
            }
        }
    }

    /**
     * Maps a Spark SQL type to the matching cuDF type.
     *
     * @param dataType Spark type to translate
     * @return the cuDF type, or {@code null} when none of the known Spark types matches
     */
    private static final DType toRapidsOrNull(DataType dataType) {
        DType mapped = null;
        if (dataType instanceof BooleanType) {
            mapped = DType.BOOL8;
        } else if (dataType instanceof ByteType) {
            mapped = DType.INT8;
        } else if (dataType instanceof ShortType) {
            mapped = DType.INT16;
        } else if (dataType instanceof IntegerType) {
            mapped = DType.INT32;
        } else if (dataType instanceof LongType) {
            mapped = DType.INT64;
        } else if (dataType instanceof FloatType) {
            mapped = DType.FLOAT32;
        } else if (dataType instanceof DoubleType) {
            mapped = DType.FLOAT64;
        } else if (dataType instanceof DateType) {
            mapped = DType.TIMESTAMP_DAYS;
        } else if (dataType instanceof TimestampType) {
            mapped = DType.TIMESTAMP_MICROSECONDS;
        } else if (dataType instanceof StringType) {
            mapped = DType.STRING;
        }
        return mapped;
    }

    /**
     * Tells whether a Spark SQL type has a cuDF equivalent.
     *
     * @param dataType Spark type to check
     * @return {@code true} when the type can be mapped to a cuDF type
     */
    public static final boolean isSupportedType(DataType dataType) {
        DType mapped = toRapidsOrNull(dataType);
        return mapped != null;
    }

    /**
     * Resolves the cuDF type for the data type of a schema field.
     *
     * @param structField field whose data type is translated
     * @return the matching cuDF type
     * @throws IllegalArgumentException if the field's type has no cuDF equivalent
     */
    public static final DType getRapidsType(StructField structField) {
        DataType fieldType = structField.dataType();
        return getRapidsType(fieldType);
    }

    /**
     * Resolves the cuDF type for a Spark SQL type.
     *
     * @param dataType Spark type to translate
     * @return the matching cuDF type, never {@code null}
     * @throws IllegalArgumentException if the type has no cuDF equivalent
     */
    public static final DType getRapidsType(DataType dataType) {
        DType mapped = toRapidsOrNull(dataType);
        if (mapped != null) {
            return mapped;
        }
        throw new IllegalArgumentException(dataType + " is not supported for GPU processing yet.");
    }

    /**
     * Maps a cuDF type to the matching Spark SQL type.
     *
     * <p>Switches directly on the enum constant instead of going through an ordinal-indexed
     * lookup table, so this method no longer depends on any helper class.
     *
     * @param dType cuDF type to translate
     * @return the matching Spark type
     * @throws IllegalArgumentException if the cuDF type has no Spark equivalent
     */
    protected static final DataType getSparkType(DType dType) {
        switch (dType) {
            case BOOL8:
                return DataTypes.BooleanType;
            case INT8:
                return DataTypes.ByteType;
            case INT16:
                return DataTypes.ShortType;
            case INT32:
                return DataTypes.IntegerType;
            case INT64:
                return DataTypes.LongType;
            case FLOAT32:
                return DataTypes.FloatType;
            case FLOAT64:
                return DataTypes.DoubleType;
            case TIMESTAMP_DAYS:
                return DataTypes.DateType;
            case TIMESTAMP_MICROSECONDS:
                return DataTypes.TimestampType;
            case STRING:
                return DataTypes.StringType;
            default:
                throw new IllegalArgumentException(dType + " is not supported by spark yet.");
        }
    }

    /**
     * Creates a zero-row batch with one GPU column per field of the schema.
     *
     * <p>The intermediate builder is closed on every path, so builders that were not consumed
     * (for example because building a column failed) are released.
     *
     * @param structType schema of the batch
     * @return an empty batch matching the schema
     * @throws IllegalArgumentException if a field type has no cuDF equivalent
     */
    public static final ColumnarBatch emptyBatch(StructType structType) {
        try (GpuColumnarBatchBuilder batchBuilder = new GpuColumnarBatchBuilder(structType, 0, null)) {
            return batchBuilder.build(0);
        }
    }

    /**
     * Creates a zero-row batch whose columns follow the given attributes.
     *
     * <p>The schema is assembled in a single array rather than by repeated
     * {@code StructType.add} calls, and each field carries empty metadata instead of
     * {@code null}.
     *
     * @param list attributes supplying name, data type and nullability of each column
     * @return an empty batch matching the attributes
     * @throws IllegalArgumentException if an attribute type has no cuDF equivalent
     */
    public static final ColumnarBatch emptyBatch(List<Attribute> list) {
        StructField[] fields = new StructField[list.size()];
        int next = 0;
        for (Attribute attribute : list) {
            fields[next++] = new StructField(attribute.name(), attribute.dataType(), attribute.nullable(), Metadata.empty());
        }
        return emptyBatch(new StructType(fields));
    }

    /**
     * Converts a Spark schema into a cuDF schema with the same column names.
     *
     * @param structType Spark schema to convert
     * @return the cuDF schema
     * @throws IllegalArgumentException if a field type has no cuDF equivalent
     */
    public static final Schema from(StructType structType) {
        Schema.Builder schemaBuilder = Schema.builder();
        for (StructField field : structType.fields()) {
            DType columnType = getRapidsType(field.dataType());
            schemaBuilder.column(columnType, field.name());
        }
        return schemaBuilder.build();
    }

    /**
     * Builds a cuDF table from the cuDF columns backing a batch.
     *
     * @param columnarBatch batch whose columns are all {@link GpuColumnVector}s
     * @return a table over the batch's underlying columns
     */
    public static final Table from(ColumnarBatch columnarBatch) {
        ai.rapids.cudf.ColumnVector[] bases = extractBases(columnarBatch);
        return new Table(bases);
    }

    /**
     * Wraps every column of a cuDF table in a batch.
     *
     * @param table source table
     * @return a batch covering all columns of the table
     */
    public static final ColumnarBatch from(Table table) {
        int columnCount = table.getNumberOfColumns();
        return from(table, 0, columnCount);
    }

    /**
     * Wraps a range of columns of a cuDF table in a batch.
     *
     * <p>Each selected column has its reference count incremented before it is wrapped. If
     * anything fails before the batch is returned (an unsupported column type, a row count
     * that does not fit in an {@code int}, or batch construction), every reference taken so
     * far is released again.
     *
     * @param table source table
     * @param i index of the first column to include
     * @param i2 index one past the last column to include
     * @return a batch over the selected columns
     * @throws IllegalStateException if the table has more rows than fit in an {@code int}
     * @throws IllegalArgumentException if a column type has no Spark equivalent
     */
    public static final ColumnarBatch from(Table table, int i, int i2) {
        if (!$assertionsDisabled && table == null) {
            throw new AssertionError("Table cannot be null");
        }
        ColumnVector[] columns = new ColumnVector[i2 - i];
        boolean success = false;
        try {
            int next = 0;
            for (int col = i; col < i2; col++) {
                ai.rapids.cudf.ColumnVector base = table.getColumn(col).incRefCount();
                try {
                    columns[next] = from(base);
                } catch (RuntimeException | Error e) {
                    // The wrapper was never created, so drop the reference taken just above.
                    base.close();
                    throw e;
                }
                next++;
            }
            long rowCount = table.getRowCount();
            if (rowCount != ((int) rowCount)) {
                throw new IllegalStateException("Cannot support a batch larger that MAX INT rows");
            }
            ColumnarBatch batch = new ColumnarBatch(columns, (int) rowCount);
            success = true;
            return batch;
        } finally {
            if (!success) {
                for (ColumnVector column : columns) {
                    if (column != null) {
                        column.close();
                    }
                }
            }
        }
    }

    /**
     * Wraps a cuDF column, deriving the Spark type from the column's cuDF type.
     *
     * @param columnVector cuDF column to wrap; its reference count is not changed here
     * @return the wrapping vector
     * @throws IllegalArgumentException if the column type has no Spark equivalent
     */
    public static final GpuColumnVector from(ai.rapids.cudf.ColumnVector columnVector) {
        DataType sparkType = getSparkType(columnVector.getType());
        return new GpuColumnVector(sparkType, columnVector);
    }

    /**
     * Creates a column from a scalar via {@code ColumnVector.fromScalar} and wraps it.
     *
     * <p>If wrapping fails (the scalar's type has no Spark equivalent), the freshly created
     * column is closed before the exception propagates.
     *
     * @param scalar value passed to {@code fromScalar}
     * @param i row count passed to {@code fromScalar}
     * @return the wrapping vector
     * @throws IllegalArgumentException if the resulting column type has no Spark equivalent
     */
    public static final GpuColumnVector from(Scalar scalar, int i) {
        ai.rapids.cudf.ColumnVector expanded = ai.rapids.cudf.ColumnVector.fromScalar(scalar, i);
        try {
            return from(expanded);
        } catch (RuntimeException | Error e) {
            expanded.close();
            throw e;
        }
    }

    /**
     * Collects the cuDF columns backing each column of a batch.
     *
     * @param columnarBatch batch whose columns are all {@link GpuColumnVector}s
     * @return the underlying cuDF columns, in batch order
     * @throws ClassCastException if a column is not a {@link GpuColumnVector}
     */
    public static final ai.rapids.cudf.ColumnVector[] extractBases(ColumnarBatch columnarBatch) {
        int columnCount = columnarBatch.numCols();
        ai.rapids.cudf.ColumnVector[] bases = new ai.rapids.cudf.ColumnVector[columnCount];
        for (int col = 0; col < columnCount; col++) {
            GpuColumnVector gpuColumn = (GpuColumnVector) columnarBatch.column(col);
            bases[col] = gpuColumn.getBase();
        }
        return bases;
    }

    /**
     * Returns the columns of a batch typed as {@link GpuColumnVector}.
     *
     * @param columnarBatch batch whose columns are all {@link GpuColumnVector}s
     * @return the batch's columns, in batch order
     * @throws ClassCastException if a column is not a {@link GpuColumnVector}
     */
    public static final GpuColumnVector[] extractColumns(ColumnarBatch columnarBatch) {
        int columnCount = columnarBatch.numCols();
        GpuColumnVector[] gpuColumns = new GpuColumnVector[columnCount];
        for (int col = 0; col < columnCount; col++) {
            gpuColumns[col] = (GpuColumnVector) columnarBatch.column(col);
        }
        return gpuColumns;
    }

    /**
     * Copies an INT32 cuDF column to the host and returns its values as an array.
     *
     * <p>The temporary host copy is closed on every path via try-with-resources.
     *
     * @param columnVector column to read; checked to be INT32 when assertions are enabled
     * @return one array element per row of the column
     */
    public static final int[] toIntArray(ai.rapids.cudf.ColumnVector columnVector) {
        if (!$assertionsDisabled && columnVector.getType() != DType.INT32) {
            throw new AssertionError();
        }
        int rowCount = (int) columnVector.getRowCount();
        int[] values = new int[rowCount];
        try (HostColumnVector hostCopy = columnVector.copyToHost()) {
            for (int row = 0; row < rowCount; row++) {
                values[row] = hostCopy.getInt(row);
            }
        }
        return values;
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /**
     * Wraps a cuDF column under the given Spark type.
     *
     * @param dataType Spark type passed to the superclass constructor
     * @param columnVector cuDF column to wrap; stored as-is, its reference count is not changed
     */
    public GpuColumnVector(DataType dataType, ai.rapids.cudf.ColumnVector columnVector) {
        super(dataType);
        cudfCv = columnVector;
    }

    /**
     * Increments the reference count of the wrapped cuDF column.
     *
     * @return this vector, to allow chaining
     */
    public final GpuColumnVector incRefCount() {
        cudfCv.incRefCount();
        return this;
    }

    /** Closes the wrapped cuDF column. */
    public final void close() {
        cudfCv.close();
    }

    /**
     * Reports whether the wrapped cuDF column contains nulls.
     *
     * @return the result of the wrapped column's {@code hasNulls()}
     */
    public final boolean hasNull() {
        return cudfCv.hasNulls();
    }

    /**
     * Returns the null count of the wrapped cuDF column.
     *
     * @return the wrapped column's null count, narrowed to {@code int}
     */
    public final int numNulls() {
        long nullCount = cudfCv.getNullCount();
        return (int) nullCount;
    }

    /**
     * Not available on this vector; per-row access must go through a host vector.
     *
     * @throws IllegalStateException always
     */
    public final boolean isNullAt(int i) {
        throw new IllegalStateException(BAD_ACCESS);
    }

    /**
     * Not available on this vector; per-row access must go through a host vector.
     *
     * @throws IllegalStateException always
     */
    public final boolean getBoolean(int i) {
        throw new IllegalStateException(BAD_ACCESS);
    }

    /**
     * Not available on this vector; per-row access must go through a host vector.
     *
     * @throws IllegalStateException always
     */
    public final byte getByte(int i) {
        throw new IllegalStateException(BAD_ACCESS);
    }

    /**
     * Not available on this vector; per-row access must go through a host vector.
     *
     * @throws IllegalStateException always
     */
    public final short getShort(int i) {
        throw new IllegalStateException(BAD_ACCESS);
    }

    /**
     * Not available on this vector; per-row access must go through a host vector.
     *
     * @throws IllegalStateException always
     */
    public final int getInt(int i) {
        throw new IllegalStateException(BAD_ACCESS);
    }

    /**
     * Not available on this vector; per-row access must go through a host vector.
     *
     * @throws IllegalStateException always
     */
    public final long getLong(int i) {
        throw new IllegalStateException(BAD_ACCESS);
    }

    /**
     * Not available on this vector; per-row access must go through a host vector.
     *
     * @throws IllegalStateException always
     */
    public final float getFloat(int i) {
        throw new IllegalStateException(BAD_ACCESS);
    }

    /**
     * Not available on this vector; per-row access must go through a host vector.
     *
     * @throws IllegalStateException always
     */
    public final double getDouble(int i) {
        throw new IllegalStateException(BAD_ACCESS);
    }

    /**
     * Not available on this vector; per-row access must go through a host vector.
     *
     * @throws IllegalStateException always
     */
    public final ColumnarArray getArray(int i) {
        throw new IllegalStateException(BAD_ACCESS);
    }

    /**
     * Not available on this vector; per-row access must go through a host vector.
     *
     * @throws IllegalStateException always
     */
    public final ColumnarMap getMap(int i) {
        throw new IllegalStateException(BAD_ACCESS);
    }

    /**
     * Not available on this vector; per-row access must go through a host vector.
     *
     * @throws IllegalStateException always
     */
    public final Decimal getDecimal(int i, int i2, int i3) {
        throw new IllegalStateException(BAD_ACCESS);
    }

    /**
     * Not available on this vector; per-row access must go through a host vector.
     *
     * @throws IllegalStateException always
     */
    public final UTF8String getUTF8String(int i) {
        throw new IllegalStateException(BAD_ACCESS);
    }

    /**
     * Not available on this vector; per-row access must go through a host vector.
     *
     * @throws IllegalStateException always
     */
    public final byte[] getBinary(int i) {
        throw new IllegalStateException(BAD_ACCESS);
    }

    /**
     * Not available on this vector; child access must go through a host vector.
     *
     * @throws IllegalStateException always
     */
    public final ColumnVector getChild(int i) {
        throw new IllegalStateException(BAD_ACCESS);
    }

    /**
     * Sums the device memory size reported by the cuDF column behind each column of a batch.
     *
     * @param columnarBatch batch whose columns are all {@link GpuColumnVector}s
     * @return the summed device memory size
     * @throws ClassCastException if a column is not a {@link GpuColumnVector}
     */
    public static final long getTotalDeviceMemoryUsed(ColumnarBatch columnarBatch) {
        long total = 0L;
        int col = 0;
        while (col < columnarBatch.numCols()) {
            GpuColumnVector gpuColumn = (GpuColumnVector) columnarBatch.column(col);
            total += gpuColumn.getBase().getDeviceMemorySize();
            col++;
        }
        return total;
    }

    /**
     * Sums the device memory size reported by the cuDF column behind each given vector.
     *
     * @param gpuColumnVectorArr vectors to measure
     * @return the summed device memory size
     */
    public static final long getTotalDeviceMemoryUsed(GpuColumnVector[] gpuColumnVectorArr) {
        long total = 0L;
        for (int idx = 0; idx < gpuColumnVectorArr.length; idx++) {
            ai.rapids.cudf.ColumnVector base = gpuColumnVectorArr[idx].getBase();
            total += base.getDeviceMemorySize();
        }
        return total;
    }

    /**
     * Sums the device memory size reported by every column of a cuDF table.
     *
     * @param table table to measure
     * @return the summed device memory size
     */
    public static final long getTotalDeviceMemoryUsed(Table table) {
        final int columnCount = table.getNumberOfColumns();
        long total = 0L;
        int col = 0;
        while (col < columnCount) {
            total += table.getColumn(col).getDeviceMemorySize();
            col++;
        }
        return total;
    }

    /**
     * Returns the wrapped cuDF column.
     *
     * @return the same column instance this vector was constructed with
     */
    public final ai.rapids.cudf.ColumnVector getBase() {
        return cudfCv;
    }

    /**
     * Returns the row count of the wrapped cuDF column.
     *
     * @return the wrapped column's row count
     */
    public final long getRowCount() {
        return cudfCv.getRowCount();
    }

    /**
     * Copies the wrapped cuDF column to the host and wraps the copy under this vector's type.
     *
     * @return a host vector holding the result of the wrapped column's {@code copyToHost()}
     */
    public final RapidsHostColumnVector copyToHost() {
        DataType sparkType = this.type;
        HostColumnVector hostCopy = cudfCv.copyToHost();
        return new RapidsHostColumnVector(sparkType, hostCopy);
    }

    /**
     * Returns the string form of the wrapped cuDF column.
     *
     * @return the wrapped column's {@code toString()}
     */
    public final String toString() {
        ai.rapids.cudf.ColumnVector base = getBase();
        return base.toString();
    }

    // Decompiler rendering of the assertion bootstrap: the flag is true when assertions are
    // NOT enabled for this class, which makes the "!$assertionsDisabled && ..." checks no-ops.
    static {
        $assertionsDisabled = !GpuColumnVector.class.desiredAssertionStatus();
    }
}
