package ml.dmlc.xgboost4j.scala.spark.rapids;

import ai.rapids.cudf.Table;
import ml.dmlc.xgboost4j.java.spark.rapids.GpuColumnBatch;
import ml.dmlc.xgboost4j.scala.ColumnDMatrix;
import ml.dmlc.xgboost4j.scala.DMatrix;
import org.apache.spark.RapidsUtils$;
import org.apache.spark.SparkConf;
import org.apache.spark.TaskContext;
import org.apache.spark.TaskContext$;
import org.apache.spark.rdd.RDD;
import org.apache.spark.resource.ResourceInformation;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.types.StructType;
import scala.MatchError;
import scala.None$;
import scala.Option;
import scala.Predef$;
import scala.Some;
import scala.Tuple2;
import scala.collection.Iterator;
import scala.collection.Seq;
import scala.collection.Seq$;
import scala.collection.SeqLike;
import scala.collection.immutable.StringOps;
import scala.collection.mutable.ArrayOps;
import scala.reflect.ClassTag$;
import scala.runtime.BoxesRunTime;

/* compiled from: GpuUtils.scala */
/* loaded from: input_file:ml/dmlc/xgboost4j/scala/spark/rapids/GpuUtils$.class */
/**
 * Java view (decompiled from GpuUtils.scala) of the Scala singleton object {@code GpuUtils}.
 *
 * <p>Provides helpers for:
 * <ul>
 *   <li>detecting the RAPIDS SQL plugin;</li>
 *   <li>converting a Dataset into an RDD of cuDF tables via reflection;</li>
 *   <li>assembling GPU column batches into XGBoost {@link DMatrix} inputs;</li>
 *   <li>resolving the GPU address assigned to the current Spark task.</li>
 * </ul>
 *
 * <p>The singleton instance is {@link #MODULE$}. It is created by the static initializer.
 */
public final class GpuUtils$ {
    /** The singleton instance; assigned by the private constructor during class initialization. */
    public static GpuUtils$ MODULE$;

    static {
        new GpuUtils$();
    }

    /**
     * Reports whether the RAPIDS SQL plugin is configured.
     *
     * <p>The SparkConf is taken from the given Dataset's session when one is supplied, otherwise
     * from {@code RapidsUtils.getSparkContext()}. The result is {@code true} when that conf's
     * {@code spark.sql.extensions} list contains {@code com.nvidia.spark.rapids.SQLExecPlugin}.
     * It is {@code false} when no conf is available.
     *
     * @param option optional Dataset whose session conf should be inspected
     * @return whether the RAPIDS SQL extension is listed
     */
    public boolean isRapidsEnabled(Option<Dataset<?>> option) {
        return BoxesRunTime.unboxToBoolean(((Option) option.map(dataset -> {
            return new Some(dataset.sparkSession().sparkContext().getConf());
        }).getOrElse(() -> {
            return RapidsUtils$.MODULE$.getSparkContext().map(sparkContext -> {
                return sparkContext.getConf();
            });
        })).map(sparkConf -> {
            return BoxesRunTime.boxToBoolean($anonfun$isRapidsEnabled$4(sparkConf));
        }).getOrElse(() -> {
            return false;
        }));
    }

    /** Default value for the {@code option} parameter of {@link #isRapidsEnabled}: {@code None}. */
    public Option<Dataset<?>> isRapidsEnabled$default$1() {
        return None$.MODULE$;
    }

    /**
     * Converts a DataFrame into an RDD of cuDF {@link Table}s.
     *
     * <p>The conversion reflectively calls {@code com.nvidia.spark.rapids.ColumnarRdd.convert(Dataset)},
     * so this module has no compile-time dependency on the RAPIDS plugin.
     *
     * <p>Exception handling:
     * <ul>
     *   <li>A RuntimeException or Error thrown by {@code convert} itself is rethrown unwrapped.</li>
     *   <li>Reflective failures, such as the class or method being absent, are reported as
     *       {@link IllegalStateException} with the original exception as cause.</li>
     * </ul>
     *
     * @param dataset the DataFrame to convert
     * @return the columnar RDD produced by the RAPIDS plugin
     * @throws IllegalStateException if the reflective call cannot be made or fails with a checked exception
     */
    public RDD<Table> toColumnarRdd(Dataset<Row> dataset) {
        try {
            return (RDD) Class.forName("com.nvidia.spark.rapids.ColumnarRdd")
                    .getDeclaredMethod("convert", Dataset.class)
                    .invoke(null, dataset);
        } catch (java.lang.reflect.InvocationTargetException e) {
            // Surface the real failure from convert() instead of the reflective wrapper.
            Throwable cause = e.getCause();
            if (cause instanceof RuntimeException) {
                throw (RuntimeException) cause;
            }
            if (cause instanceof Error) {
                throw (Error) cause;
            }
            throw new IllegalStateException("com.nvidia.spark.rapids.ColumnarRdd.convert failed", cause);
        } catch (ReflectiveOperationException e) {
            throw new IllegalStateException(
                    "Could not invoke com.nvidia.spark.rapids.ColumnarRdd.convert; is the RAPIDS Accelerator on the classpath?",
                    e);
        }
    }

    /**
     * Validates the requested columns against the DataFrame schema and records their indices.
     *
     * @param seq feature column names; duplicates are removed before use
     * @param str required column whose index is ColumnIndices' second argument (presumably the label)
     * @param str2 optional column (empty string = absent), validated by validateSchema; presumably
     *     the weight column — confirm against ColumnIndices
     * @param str3 optional column (empty string = absent), validated by validateSchema; presumably
     *     the base-margin column — confirm against ColumnIndices
     * @param str4 optional column (empty string = absent) that must be numeric. Its name is passed to
     *     ColumnDataBatch and its index to ColumnIndices. Presumably the group column — confirm.
     * @param dataset the DataFrame whose schema is inspected
     * @return a batch descriptor holding the dataset, the column indices and the optional str4 name
     */
    public ColumnDataBatch buildColumnDataBatch(Seq<String> seq, String str, String str2, String str3, String str4, Dataset<Row> dataset) {
        StructType schema = dataset.schema();
        // Duplicate feature names would otherwise yield duplicate column indices.
        Seq<String> featureNames = (Seq) seq.distinct();
        MLUtils$.MODULE$.validateSchema(schema, featureNames, str, str2, str3, MLUtils$.MODULE$.validateSchema$default$6());

        Option str4Name;
        Option str4Index;
        if (str4.isEmpty()) {
            str4Name = None$.MODULE$;
            str4Index = None$.MODULE$;
        } else {
            MLUtils$.MODULE$.checkNumericType(schema, str4, MLUtils$.MODULE$.checkNumericType$default$3());
            str4Name = new Some(str4);
            str4Index = new Some(BoxesRunTime.boxToInteger(schema.fieldIndex(str4)));
        }

        // Evaluated in the same order as before (str2 then str3) so the same lookup fails first.
        Option str2Index = optionalFieldIndex(schema, str2);
        Option str3Index = optionalFieldIndex(schema, str3);

        Seq featureIndices = (Seq) featureNames.map(name -> {
            return BoxesRunTime.boxToInteger(schema.fieldIndex(name));
        }, Seq$.MODULE$.canBuildFrom());
        return new ColumnDataBatch(dataset,
                new ColumnIndices(featureIndices, schema.fieldIndex(str), str2Index, str3Index, str4Index),
                str4Name);
    }

    /** Returns {@code Some(index of name)} in the schema, or {@code None} when {@code name} is empty. */
    private Option optionalFieldIndex(StructType schema, String name) {
        if (name.isEmpty()) {
            return None$.MODULE$;
        }
        return new Some(BoxesRunTime.boxToInteger(schema.fieldIndex(name)));
    }

    /**
     * Streams column batches into a {@link ColumnBatchToRow}, building a {@link ColumnDMatrix} from
     * the first batch.
     *
     * <p>Every batch is closed after it has been consumed, including when an exception is thrown
     * while processing it.
     *
     * <p>NOTE(review): {@code getArrayInterface} is computed for every batch, but only the first
     * batch's value is passed to {@code ColumnDMatrix}. Later batches reach only the
     * ColumnBatchToRow. Confirm this is intended.
     *
     * @param f value passed as ColumnDMatrix's second constructor argument (presumably the missing value)
     * @param iterator the column batches to consume
     * @param seq column indices passed to {@code getArrayInterface}
     * @param structType schema for the ColumnBatchToRow
     * @return (DMatrix, ColumnBatchToRow); the DMatrix is {@code null} when the iterator is empty
     */
    public Tuple2<DMatrix, ColumnBatchToRow> buildDMatrixAndColumnToRowIncrementally(float f, Iterator<GpuColumnBatch> iterator, Seq<Object> seq, StructType structType) {
        ColumnBatchToRow columnBatchToRow = new ColumnBatchToRow(structType);
        // Loop-invariant: the selected column indices are identical for every batch.
        int[] columnIndices = (int[]) seq.toArray(ClassTag$.MODULE$.Int());
        ColumnDMatrix columnDMatrix = null;
        while (iterator.hasNext()) {
            GpuColumnBatch gpuColumnBatch = (GpuColumnBatch) iterator.next();
            try {
                String arrayInterface = gpuColumnBatch.getArrayInterface(columnIndices);
                if (columnDMatrix == null) {
                    columnDMatrix = new ColumnDMatrix(arrayInterface, f, 1);
                }
                columnBatchToRow.appendColumnBatch(gpuColumnBatch);
            } finally {
                gpuColumnBatch.close();
            }
        }
        return new Tuple2<>(columnDMatrix, columnBatchToRow);
    }

    /**
     * Merges all column batches into one and builds both a {@link ColumnDMatrix} and a
     * {@link ColumnBatchToRow} from the merged batch.
     *
     * <p>The merged batch is closed before returning, including when an exception is thrown.
     *
     * <p>NOTE(review): the individual input batches are not closed here. Whether
     * {@code GpuColumnBatch.merge} takes ownership of them cannot be told from this file; confirm.
     *
     * @param f value passed as ColumnDMatrix's second constructor argument (presumably the missing value)
     * @param iterator the column batches to merge
     * @param seq column indices passed to {@code getArrayInterface}
     * @param structType schema for the ColumnBatchToRow
     * @return (DMatrix, ColumnBatchToRow), or (null, null) when the iterator is empty
     */
    public Tuple2<DMatrix, ColumnBatchToRow> buildDMatrixAndColumnToRow(float f, Iterator<GpuColumnBatch> iterator, Seq<Object> seq, StructType structType) {
        if (iterator.isEmpty()) {
            return new Tuple2<>(null, null);
        }
        GpuColumnBatch merged = GpuColumnBatch.merge((GpuColumnBatch[]) iterator.toArray(ClassTag$.MODULE$.apply(GpuColumnBatch.class)));
        try {
            ColumnBatchToRow columnBatchToRow = new ColumnBatchToRow(structType).appendColumnBatch(merged);
            ColumnDMatrix columnDMatrix = new ColumnDMatrix(merged.getArrayInterface((int[]) seq.toArray(ClassTag$.MODULE$.Int())), f, 1);
            return new Tuple2<>(columnDMatrix, columnBatchToRow);
        } finally {
            merged.close();
        }
    }

    /**
     * Resolves the GPU ordinal for the current Spark task.
     *
     * @param z when {@code true}, returns 0 without consulting the task's resources (presumably
     *     for a mode without Spark GPU scheduling — confirm against callers)
     * @return 0 if {@code z} is true; otherwise the first "gpu" resource address of the current
     *     task, parsed as an int
     * @throws IllegalStateException if {@code z} is false and no TaskContext is active (called
     *     outside a task)
     * @throws RuntimeException if the task has no "gpu" resource or it has no addresses
     * @throws NumberFormatException if the first GPU address is not an integer
     */
    public int getGpuId(boolean z) {
        if (z) {
            return 0;
        }
        TaskContext taskContext = TaskContext$.MODULE$.get();
        if (taskContext == null) {
            throw new IllegalStateException("GPU id can only be resolved inside a running Spark task");
        }
        ResourceInformation resourceInformation = (ResourceInformation) taskContext.resources().get("gpu").getOrElse(() -> {
            throw new RuntimeException("Spark could not allocate gpus for executor");
        });
        String[] addresses = resourceInformation.addresses();
        if (addresses.length < 1) {
            throw new RuntimeException("executor could not get specific address of gpu");
        }
        return Integer.parseInt(addresses[0]);
    }

    /**
     * Checks whether {@code spark.sql.extensions} lists the RAPIDS SQL plugin.
     *
     * <p>Entries are trimmed before comparison. Spark itself trims them when parsing this setting,
     * so {@code "a, com.nvidia.spark.rapids.SQLExecPlugin"} must also match.
     *
     * @param sparkConf the configuration to inspect
     * @return whether {@code com.nvidia.spark.rapids.SQLExecPlugin} is one of the listed extensions
     */
    public static final /* synthetic */ boolean $anonfun$isRapidsEnabled$4(SparkConf sparkConf) {
        for (String extension : sparkConf.get("spark.sql.extensions", "").split(",")) {
            if (extension.trim().equals("com.nvidia.spark.rapids.SQLExecPlugin")) {
                return true;
            }
        }
        return false;
    }

    private GpuUtils$() {
        MODULE$ = this;
    }
}
