package ml.dmlc.xgboost4j.scala.spark.rapids;

import ai.rapids.cudf.DType;
import ai.rapids.cudf.Table;
import java.util.NoSuchElementException;
import ml.dmlc.xgboost4j.java.Rabit;
import ml.dmlc.xgboost4j.java.spark.rapids.GpuColumnBatch;
import ml.dmlc.xgboost4j.java.spark.rapids.GpuColumnVector;
import ml.dmlc.xgboost4j.scala.Booster;
import ml.dmlc.xgboost4j.scala.ColumnDMatrix;
import ml.dmlc.xgboost4j.scala.DMatrix;
import ml.dmlc.xgboost4j.scala.spark.XGBoostClassificationModel;
import ml.dmlc.xgboost4j.scala.spark.XGBoostClassificationModel$;
import ml.dmlc.xgboost4j.scala.spark.XGBoostRegressionModel;
import ml.dmlc.xgboost4j.scala.spark.XGBoostRegressionModel$;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.spark.SparkContext;
import org.apache.spark.TaskContext;
import org.apache.spark.TaskContext$;
import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.ml.param.Param;
import org.apache.spark.rdd.RDD;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.Row$;
import org.apache.spark.sql.catalyst.expressions.UnsafeProjection;
import org.apache.spark.sql.catalyst.expressions.UnsafeProjection$;
import org.apache.spark.sql.types.ArrayType;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.FloatType$;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.sql.vectorized.ColumnVector;
import org.apache.spark.sql.vectorized.ColumnarBatch;
import scala.Array$;
import scala.Function0;
import scala.Function1;
import scala.Function2;
import scala.Function3;
import scala.MatchError;
import scala.None$;
import scala.Option;
import scala.PartialFunction;
import scala.Predef$;
import scala.Predef$ArrowAssoc$;
import scala.Some;
import scala.Tuple2;
import scala.Tuple3;
import scala.collection.BufferedIterator;
import scala.collection.GenTraversableOnce;
import scala.collection.Iterable;
import scala.collection.Iterator;
import scala.collection.Iterator$;
import scala.collection.JavaConverters$;
import scala.collection.Seq;
import scala.collection.Seq$;
import scala.collection.SeqLike;
import scala.collection.Traversable;
import scala.collection.TraversableLike;
import scala.collection.TraversableOnce;
import scala.collection.generic.CanBuildFrom;
import scala.collection.immutable.IndexedSeq;
import scala.collection.immutable.IndexedSeq$;
import scala.collection.immutable.List;
import scala.collection.immutable.Map;
import scala.collection.immutable.Set;
import scala.collection.immutable.Stream;
import scala.collection.immutable.StringOps;
import scala.collection.immutable.Vector;
import scala.collection.mutable.ArrayOps;
import scala.collection.mutable.Buffer;
import scala.collection.mutable.StringBuilder;
import scala.math.Numeric;
import scala.math.Ordering;
import scala.reflect.ClassTag;
import scala.reflect.ClassTag$;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;
import scala.runtime.Nothing$;
import scala.runtime.ObjectRef;
import scala.runtime.RichInt$;

/* compiled from: GpuTransform.scala */
/* loaded from: input_file:ml/dmlc/xgboost4j/scala/spark/rapids/GpuTransform$.class */
public final class GpuTransform$ {
    public static GpuTransform$ MODULE$;
    private final Log ml$dmlc$xgboost4j$scala$spark$rapids$GpuTransform$$logger;

    // Scala `object` bootstrap: instantiate the singleton once when the class is loaded.
    // The instance is not stored here. Presumably the constructor (not visible in this view)
    // assigns MODULE$ = this and initializes the logger field. Confirm against the constructor.
    static {
        new GpuTransform$();
    }

    /**
     * Accessor for the commons-logging {@link Log} held by this singleton.
     *
     * <p>The mangled name is Scala's encoding of a private {@code logger} member, widened so that
     * closures and anonymous classes in this file can reach it.
     *
     * @return the logger stored in the backing field
     */
    public Log ml$dmlc$xgboost4j$scala$spark$rapids$GpuTransform$$logger() {
        final Log log = ml$dmlc$xgboost4j$scala$spark$rapids$GpuTransform$$logger;
        return log;
    }

    /**
     * GPU transform entry point for a classification model.
     *
     * <p>Resolves the leaf- and contribution-prediction column names through
     * {@code MLUtils$.getColumnNames}. The result must contain exactly two entries; otherwise a
     * {@link MatchError} carrying the resolved names is thrown. The output columns are, in order:
     * raw prediction, probability, leaf, contrib. The work is then delegated to
     * {@code transformDistributed} with:
     * <ul>
     *   <li>a row builder that appends the model's prediction iterators, and</li>
     *   <li>a schema builder for those same columns.</li>
     * </ul>
     *
     * @param xGBoostClassificationModel the trained classification model
     * @param dataset the input data
     * @param option optional GPU sampler forwarded to {@code transformDistributed}
     * @return the transformed rows
     */
    public Dataset<Row> transformInternal(XGBoostClassificationModel xGBoostClassificationModel, Dataset<?> dataset, Option<GpuSampler> option) {
        final XGBoostClassificationModel model = xGBoostClassificationModel;
        Seq<String> leafAndContrib = MLUtils$.MODULE$.getColumnNames(model, Predef$.MODULE$.wrapRefArray(new Param[]{model.leafPredictionCol(), model.contribPredictionCol()}));

        // Equivalent of Scala's `val Seq(leafCol, contribCol) = ...`.
        Some matched = Seq$.MODULE$.unapplySeq(leafAndContrib);
        SeqLike pair = matched.isEmpty() ? null : (SeqLike) matched.get();
        if (pair == null || pair.lengthCompare(2) != 0) {
            throw new MatchError(leafAndContrib);
        }
        String leafCol = (String) pair.apply(0);
        String contribCol = (String) pair.apply(1);

        Seq outputCols = Seq$.MODULE$.apply(Predef$.MODULE$.wrapRefArray(new String[]{
                XGBoostClassificationModel$.MODULE$._rawPredictionCol(),
                XGBoostClassificationModel$.MODULE$._probabilityCol(),
                leafCol,
                contribCol}));

        return transformDistributed(
                dataset,
                model.nativeBooster(),
                model.getMissing(),
                model.getFeaturesCols(),
                model.getToRowCols(),
                (bBooster, dMatrix, rows) -> MODULE$.buildRddIterator(outputCols, Predef$.MODULE$.wrapRefArray(model.producePredictionItrs(bBooster, dMatrix)), rows),
                inputSchema -> MODULE$.buildRddSchema(outputCols, inputSchema),
                model.getBuildAllColumnsInTransform(),
                option);
    }

    /**
     * Public GPU transform entry point for a regression model.
     *
     * <p>When {@code z} is true, the dataset is first passed through
     * {@code MLUtils$.prepareColumnType} for the model's feature columns. The defaults are used for
     * arguments 3 to 5 and {@code false} for the last argument. Otherwise the dataset is used as is.
     * The private overload then does the actual work. The exact meaning of {@code z} is defined by
     * callers (not visible here).
     *
     * @param xGBoostRegressionModel the trained regression model
     * @param dataset the input data
     * @param z whether to run {@code prepareColumnType} on the feature columns first
     * @param option optional GPU sampler forwarded downstream
     * @return the transformed rows
     */
    public Dataset<Row> transformInternal(XGBoostRegressionModel xGBoostRegressionModel, Dataset<?> dataset, boolean z, Option<GpuSampler> option) {
        Dataset<?> input = z
                ? MLUtils$.MODULE$.prepareColumnType(dataset, xGBoostRegressionModel.getFeaturesCols(), MLUtils$.MODULE$.prepareColumnType$default$3(), MLUtils$.MODULE$.prepareColumnType$default$4(), MLUtils$.MODULE$.prepareColumnType$default$5(), false)
                : dataset;
        return transformInternal(xGBoostRegressionModel, input, option);
    }

    /**
     * GPU transform implementation for a regression model.
     *
     * <p>Resolves the leaf- and contribution-prediction column names through
     * {@code MLUtils$.getColumnNames}. The result must contain exactly two entries; otherwise a
     * {@link MatchError} carrying the resolved names is thrown. The output columns are, in order:
     * original prediction, leaf, contrib. The work is then delegated to
     * {@code transformDistributed} with a row builder and a schema builder for those columns.
     *
     * @param xGBoostRegressionModel the trained regression model
     * @param dataset the input data
     * @param option optional GPU sampler forwarded to {@code transformDistributed}
     * @return the transformed rows
     */
    private Dataset<Row> transformInternal(XGBoostRegressionModel xGBoostRegressionModel, Dataset<?> dataset, Option<GpuSampler> option) {
        final XGBoostRegressionModel model = xGBoostRegressionModel;
        Seq<String> leafAndContrib = MLUtils$.MODULE$.getColumnNames(model, Predef$.MODULE$.wrapRefArray(new Param[]{model.leafPredictionCol(), model.contribPredictionCol()}));

        // Equivalent of Scala's `val Seq(leafCol, contribCol) = ...`.
        Some matched = Seq$.MODULE$.unapplySeq(leafAndContrib);
        SeqLike pair = matched.isEmpty() ? null : (SeqLike) matched.get();
        if (pair == null || pair.lengthCompare(2) != 0) {
            throw new MatchError(leafAndContrib);
        }
        String leafCol = (String) pair.apply(0);
        String contribCol = (String) pair.apply(1);

        Seq outputCols = Seq$.MODULE$.apply(Predef$.MODULE$.wrapRefArray(new String[]{
                XGBoostRegressionModel$.MODULE$._originalPredictionCol(),
                leafCol,
                contribCol}));

        return transformDistributed(
                dataset,
                model.nativeBooster(),
                model.getMissing(),
                model.getFeaturesCols(),
                model.getToRowCols(),
                (bBooster, dMatrix, rows) -> MODULE$.buildRddIterator(outputCols, Predef$.MODULE$.wrapRefArray(model.producePredictionItrs(bBooster, dMatrix)), rows),
                inputSchema -> MODULE$.buildRddSchema(outputCols, inputSchema),
                model.getBuildAllColumnsInTransform(),
                option);
    }

    /* JADX INFO: Access modifiers changed from: private */
    /**
     * Builds the output schema by folding every name in {@code seq} into {@code structType}.
     *
     * <p>Each name is handed, together with a mutable schema holder and the prediction type
     * {@code ArrayType(FloatType, containsNull = false)}, to the synthetic helper
     * {@code $anonfun$buildRddSchema$1}. That helper is not visible here; presumably it appends a
     * field for each non-empty column name. Confirm against its definition.
     *
     * @param seq output column names
     * @param structType the starting schema
     * @return the schema held by the mutable holder after all names are processed
     */
    public StructType buildRddSchema(Seq<String> seq, StructType structType) {
        final ArrayType predictionType = new ArrayType(FloatType$.MODULE$, false);
        final ObjectRef schemaRef = ObjectRef.create(structType);
        seq.foreach(colName -> {
            $anonfun$buildRddSchema$1(schemaRef, predictionType, colName);
            return BoxedUnit.UNIT;
        });
        StructType result = (StructType) schemaRef.elem;
        return result;
    }

    /* JADX INFO: Access modifiers changed from: private */
    /**
     * Appends prediction rows to each input row.
     *
     * <p>{@code seq} and {@code seq2} are paired by position; their lengths must match, which is
     * enforced by {@code require}. The pairs are folded over {@code iterator}. A pair is applied
     * only when its column name is non-empty and its prediction iterator is non-empty. Applying a
     * pair means:
     * <ul>
     *   <li>the running row iterator is zipped with the prediction rows, and</li>
     *   <li>each zipped pair becomes one {@link Row} via
     *       {@code Row.fromSeq(left.toSeq ++ right.toSeq)}.</li>
     * </ul>
     * Any other pair leaves the running iterator unchanged.
     *
     * <p>Notes:
     * <ul>
     *   <li>{@code nonEmpty()} on a Scala {@code Iterator} calls {@code hasNext()}, so it may force
     *       the prediction iterator to compute its first element.</li>
     *   <li>Scala's {@code Iterator.zip} stops at the shorter side.</li>
     * </ul>
     *
     * @param seq output column names, one per prediction iterator
     * @param seq2 prediction-row iterators aligned with {@code seq}
     * @param iterator the input rows
     * @return a lazily composed iterator of input rows with the selected prediction rows appended
     */
    public Iterator<Row> buildRddIterator(Seq<String> seq, Seq<Iterator<Row>> seq2, Iterator<Row> iterator) {
        Predef$.MODULE$.require(seq.length() == seq2.length());
        // The decompiled original redeclared the lambda parameters (`tuple2`, `iterator2`) as
        // locals inside the lambda body. Java forbids that, so distinct names are used here.
        return (Iterator) ((TraversableOnce) seq.zip(seq2, Seq$.MODULE$.canBuildFrom())).foldLeft(iterator, (acc, element) -> {
            Iterator rows = (Iterator) acc;
            Tuple2 colAndRows = (Tuple2) element;
            if (colAndRows == null) {
                // Mirrors the failure of Scala's `case (itr, (col, rowItr))` pattern.
                throw new MatchError(new Tuple2(acc, element));
            }
            String colName = (String) colAndRows._1();
            Iterator predictionRows = (Iterator) colAndRows._2();
            // Short-circuits exactly like the original: the prediction iterator is only probed
            // when the column name is non-empty.
            if (!new StringOps(Predef$.MODULE$.augmentString(colName)).nonEmpty() || !predictionRows.nonEmpty()) {
                return rows;
            }
            return rows.zip(predictionRows).map(zipped -> {
                Tuple2 rowPair = (Tuple2) zipped;
                if (rowPair == null) {
                    throw new MatchError(rowPair);
                }
                return Row$.MODULE$.fromSeq((Seq) ((Row) rowPair._1()).toSeq().$plus$plus(((Row) rowPair._2()).toSeq(), Seq$.MODULE$.canBuildFrom()));
            });
        });
    }

    private Dataset<Row> transformDistributed(Dataset<?> dataset, Booster booster, float f, Seq<String> seq, Seq<String> seq2, Function3<Broadcast<Booster>, DMatrix, Iterator<Row>, Iterator<Row>> function3, Function1<StructType, StructType> function1, boolean z, Option<GpuSampler> option) {
        ml$dmlc$xgboost4j$scala$spark$rapids$GpuTransform$$logger().info(new StringBuilder(43).append("Running GPU transfrom with buildAllColumns:").append(z).append(new StringBuilder(8).append(" column:").append(seq2).toString()).toString());
        StructType schema = dataset.schema();
        StructType buildRowSchema = ColumnBatchToRow$.MODULE$.buildRowSchema(schema, seq2, z);
        SparkContext sparkContext = dataset.sparkSession().sparkContext();
        Broadcast broadcast = sparkContext.broadcast(schema, ClassTag$.MODULE$.apply(StructType.class));
        Broadcast broadcast2 = sparkContext.broadcast(buildRowSchema, ClassTag$.MODULE$.apply(StructType.class));
        Broadcast broadcast3 = sparkContext.broadcast(booster, ClassTag$.MODULE$.apply(Booster.class));
        boolean isLocal = sparkContext.isLocal();
        Seq seq3 = (Seq) ((TraversableLike) seq.distinct()).map(str -> {
            return BoxesRunTime.boxToInteger(schema.fieldIndex(str));
        }, Seq$.MODULE$.canBuildFrom());
        RDD<Table> columnarRdd = GpuUtils$.MODULE$.toColumnarRdd(dataset);
        RDD mapPartitions = columnarRdd.mapPartitions(iterator -> {
            Iterator iterator;
            if (z && option.isEmpty()) {
                final UnsafeProjection create = UnsafeProjection$.MODULE$.create((StructType) broadcast.value());
                return new Iterator<Row>(iterator, broadcast, create, seq3, f, isLocal, broadcast3, function3) { // from class: ml.dmlc.xgboost4j.scala.spark.rapids.GpuTransform$$anon$1
                    private int batchCnt;
                    private RowConverter converter;
                    private transient ColumnarBatch cb;
                    private Iterator<Row> it;
                    private final Iterator iter$1;
                    private final Broadcast bOrigSchema$1;
                    private final UnsafeProjection toUnsafe$1;
                    private final Seq featureIds$1;
                    private final float missing$1;
                    private final boolean isLocal$1;
                    private final Broadcast bBooster$1;
                    private final Function3 predictFunc$1;

                    /* renamed from: seq, reason: merged with bridge method [inline-methods] */
                    public Iterator<Row> m92seq() {
                        return Iterator.seq$(this);
                    }

                    public boolean isEmpty() {
                        return Iterator.isEmpty$(this);
                    }

                    public boolean isTraversableAgain() {
                        return Iterator.isTraversableAgain$(this);
                    }

                    public boolean hasDefiniteSize() {
                        return Iterator.hasDefiniteSize$(this);
                    }

                    public Iterator<Row> take(int i) {
                        return Iterator.take$(this, i);
                    }

                    public Iterator<Row> drop(int i) {
                        return Iterator.drop$(this, i);
                    }

                    public Iterator<Row> slice(int i, int i2) {
                        return Iterator.slice$(this, i, i2);
                    }

                    public Iterator<Row> sliceIterator(int i, int i2) {
                        return Iterator.sliceIterator$(this, i, i2);
                    }

                    public <B> Iterator<B> map(Function1<Row, B> function12) {
                        return Iterator.map$(this, function12);
                    }

                    public <B> Iterator<B> $plus$plus(Function0<GenTraversableOnce<B>> function0) {
                        return Iterator.$plus$plus$(this, function0);
                    }

                    public <B> Iterator<B> flatMap(Function1<Row, GenTraversableOnce<B>> function12) {
                        return Iterator.flatMap$(this, function12);
                    }

                    public Iterator<Row> filter(Function1<Row, Object> function12) {
                        return Iterator.filter$(this, function12);
                    }

                    public <B> boolean corresponds(GenTraversableOnce<B> genTraversableOnce, Function2<Row, B, Object> function2) {
                        return Iterator.corresponds$(this, genTraversableOnce, function2);
                    }

                    public Iterator<Row> withFilter(Function1<Row, Object> function12) {
                        return Iterator.withFilter$(this, function12);
                    }

                    public Iterator<Row> filterNot(Function1<Row, Object> function12) {
                        return Iterator.filterNot$(this, function12);
                    }

                    public <B> Iterator<B> collect(PartialFunction<Row, B> partialFunction) {
                        return Iterator.collect$(this, partialFunction);
                    }

                    public <B> Iterator<B> scanLeft(B b, Function2<B, Row, B> function2) {
                        return Iterator.scanLeft$(this, b, function2);
                    }

                    public <B> Iterator<B> scanRight(B b, Function2<Row, B, B> function2) {
                        return Iterator.scanRight$(this, b, function2);
                    }

                    public Iterator<Row> takeWhile(Function1<Row, Object> function12) {
                        return Iterator.takeWhile$(this, function12);
                    }

                    public Tuple2<Iterator<Row>, Iterator<Row>> partition(Function1<Row, Object> function12) {
                        return Iterator.partition$(this, function12);
                    }

                    public Tuple2<Iterator<Row>, Iterator<Row>> span(Function1<Row, Object> function12) {
                        return Iterator.span$(this, function12);
                    }

                    public Iterator<Row> dropWhile(Function1<Row, Object> function12) {
                        return Iterator.dropWhile$(this, function12);
                    }

                    public <B> Iterator<Tuple2<Row, B>> zip(Iterator<B> iterator2) {
                        return Iterator.zip$(this, iterator2);
                    }

                    public <A1> Iterator<A1> padTo(int i, A1 a1) {
                        return Iterator.padTo$(this, i, a1);
                    }

                    public Iterator<Tuple2<Row, Object>> zipWithIndex() {
                        return Iterator.zipWithIndex$(this);
                    }

                    public <B, A1, B1> Iterator<Tuple2<A1, B1>> zipAll(Iterator<B> iterator2, A1 a1, B1 b1) {
                        return Iterator.zipAll$(this, iterator2, a1, b1);
                    }

                    public <U> void foreach(Function1<Row, U> function12) {
                        Iterator.foreach$(this, function12);
                    }

                    public boolean forall(Function1<Row, Object> function12) {
                        return Iterator.forall$(this, function12);
                    }

                    public boolean exists(Function1<Row, Object> function12) {
                        return Iterator.exists$(this, function12);
                    }

                    public boolean contains(Object obj) {
                        return Iterator.contains$(this, obj);
                    }

                    public Option<Row> find(Function1<Row, Object> function12) {
                        return Iterator.find$(this, function12);
                    }

                    public int indexWhere(Function1<Row, Object> function12) {
                        return Iterator.indexWhere$(this, function12);
                    }

                    public int indexWhere(Function1<Row, Object> function12, int i) {
                        return Iterator.indexWhere$(this, function12, i);
                    }

                    public <B> int indexOf(B b) {
                        return Iterator.indexOf$(this, b);
                    }

                    public <B> int indexOf(B b, int i) {
                        return Iterator.indexOf$(this, b, i);
                    }

                    public BufferedIterator<Row> buffered() {
                        return Iterator.buffered$(this);
                    }

                    public <B> Iterator<Row>.GroupedIterator<B> grouped(int i) {
                        return Iterator.grouped$(this, i);
                    }

                    public <B> Iterator<Row>.GroupedIterator<B> sliding(int i, int i2) {
                        return Iterator.sliding$(this, i, i2);
                    }

                    public <B> int sliding$default$2() {
                        return Iterator.sliding$default$2$(this);
                    }

                    public int length() {
                        return Iterator.length$(this);
                    }

                    public Tuple2<Iterator<Row>, Iterator<Row>> duplicate() {
                        return Iterator.duplicate$(this);
                    }

                    public <B> Iterator<B> patch(int i, Iterator<B> iterator2, int i2) {
                        return Iterator.patch$(this, i, iterator2, i2);
                    }

                    public <B> void copyToArray(Object obj, int i, int i2) {
                        Iterator.copyToArray$(this, obj, i, i2);
                    }

                    public boolean sameElements(Iterator<?> iterator2) {
                        return Iterator.sameElements$(this, iterator2);
                    }

                    /* renamed from: toTraversable, reason: merged with bridge method [inline-methods] */
                    public Traversable<Row> m91toTraversable() {
                        return Iterator.toTraversable$(this);
                    }

                    public Iterator<Row> toIterator() {
                        return Iterator.toIterator$(this);
                    }

                    public Stream<Row> toStream() {
                        return Iterator.toStream$(this);
                    }

                    public String toString() {
                        return Iterator.toString$(this);
                    }

                    public List<Row> reversed() {
                        return TraversableOnce.reversed$(this);
                    }

                    public int size() {
                        return TraversableOnce.size$(this);
                    }

                    public boolean nonEmpty() {
                        return TraversableOnce.nonEmpty$(this);
                    }

                    public int count(Function1<Row, Object> function12) {
                        return TraversableOnce.count$(this, function12);
                    }

                    public <B> Option<B> collectFirst(PartialFunction<Row, B> partialFunction) {
                        return TraversableOnce.collectFirst$(this, partialFunction);
                    }

                    public <B> B $div$colon(B b, Function2<B, Row, B> function2) {
                        return (B) TraversableOnce.$div$colon$(this, b, function2);
                    }

                    public <B> B $colon$bslash(B b, Function2<Row, B, B> function2) {
                        return (B) TraversableOnce.$colon$bslash$(this, b, function2);
                    }

                    public <B> B foldLeft(B b, Function2<B, Row, B> function2) {
                        return (B) TraversableOnce.foldLeft$(this, b, function2);
                    }

                    public <B> B foldRight(B b, Function2<Row, B, B> function2) {
                        return (B) TraversableOnce.foldRight$(this, b, function2);
                    }

                    public <B> B reduceLeft(Function2<B, Row, B> function2) {
                        return (B) TraversableOnce.reduceLeft$(this, function2);
                    }

                    public <B> B reduceRight(Function2<Row, B, B> function2) {
                        return (B) TraversableOnce.reduceRight$(this, function2);
                    }

                    public <B> Option<B> reduceLeftOption(Function2<B, Row, B> function2) {
                        return TraversableOnce.reduceLeftOption$(this, function2);
                    }

                    public <B> Option<B> reduceRightOption(Function2<Row, B, B> function2) {
                        return TraversableOnce.reduceRightOption$(this, function2);
                    }

                    public <A1> A1 reduce(Function2<A1, A1, A1> function2) {
                        return (A1) TraversableOnce.reduce$(this, function2);
                    }

                    public <A1> Option<A1> reduceOption(Function2<A1, A1, A1> function2) {
                        return TraversableOnce.reduceOption$(this, function2);
                    }

                    public <A1> A1 fold(A1 a1, Function2<A1, A1, A1> function2) {
                        return (A1) TraversableOnce.fold$(this, a1, function2);
                    }

                    public <B> B aggregate(Function0<B> function0, Function2<B, Row, B> function2, Function2<B, B, B> function22) {
                        return (B) TraversableOnce.aggregate$(this, function0, function2, function22);
                    }

                    public <B> B sum(Numeric<B> numeric) {
                        return (B) TraversableOnce.sum$(this, numeric);
                    }

                    public <B> B product(Numeric<B> numeric) {
                        return (B) TraversableOnce.product$(this, numeric);
                    }

                    public Object min(Ordering ordering) {
                        return TraversableOnce.min$(this, ordering);
                    }

                    public Object max(Ordering ordering) {
                        return TraversableOnce.max$(this, ordering);
                    }

                    public Object maxBy(Function1 function12, Ordering ordering) {
                        return TraversableOnce.maxBy$(this, function12, ordering);
                    }

                    public Object minBy(Function1 function12, Ordering ordering) {
                        return TraversableOnce.minBy$(this, function12, ordering);
                    }

                    public <B> void copyToBuffer(Buffer<B> buffer) {
                        TraversableOnce.copyToBuffer$(this, buffer);
                    }

                    public <B> void copyToArray(Object obj, int i) {
                        TraversableOnce.copyToArray$(this, obj, i);
                    }

                    public <B> void copyToArray(Object obj) {
                        TraversableOnce.copyToArray$(this, obj);
                    }

                    public <B> Object toArray(ClassTag<B> classTag) {
                        return TraversableOnce.toArray$(this, classTag);
                    }

                    public List<Row> toList() {
                        return TraversableOnce.toList$(this);
                    }

                    /* renamed from: toIterable, reason: merged with bridge method [inline-methods] */
                    public Iterable<Row> m90toIterable() {
                        return TraversableOnce.toIterable$(this);
                    }

                    /* renamed from: toSeq, reason: merged with bridge method [inline-methods] */
                    public Seq<Row> m89toSeq() {
                        return TraversableOnce.toSeq$(this);
                    }

                    public IndexedSeq<Row> toIndexedSeq() {
                        return TraversableOnce.toIndexedSeq$(this);
                    }

                    public <B> Buffer<B> toBuffer() {
                        return TraversableOnce.toBuffer$(this);
                    }

                    /* renamed from: toSet, reason: merged with bridge method [inline-methods] */
                    public <B> Set<B> m88toSet() {
                        return TraversableOnce.toSet$(this);
                    }

                    public Vector<Row> toVector() {
                        return TraversableOnce.toVector$(this);
                    }

                    public <Col> Col to(CanBuildFrom<Nothing$, Row, Col> canBuildFrom) {
                        return (Col) TraversableOnce.to$(this, canBuildFrom);
                    }

                    /* renamed from: toMap, reason: merged with bridge method [inline-methods] */
                    public <T, U> Map<T, U> m87toMap(Predef$.less.colon.less<Row, Tuple2<T, U>> lessVar) {
                        return TraversableOnce.toMap$(this, lessVar);
                    }

                    public String mkString(String str2, String str3, String str4) {
                        return TraversableOnce.mkString$(this, str2, str3, str4);
                    }

                    public String mkString(String str2) {
                        return TraversableOnce.mkString$(this, str2);
                    }

                    public String mkString() {
                        return TraversableOnce.mkString$(this);
                    }

                    public StringBuilder addString(StringBuilder stringBuilder, String str2, String str3, String str4) {
                        return TraversableOnce.addString$(this, stringBuilder, str2, str3, str4);
                    }

                    public StringBuilder addString(StringBuilder stringBuilder, String str2) {
                        return TraversableOnce.addString$(this, stringBuilder, str2);
                    }

                    public StringBuilder addString(StringBuilder stringBuilder) {
                        return TraversableOnce.addString$(this, stringBuilder);
                    }

                    public int sizeHintIfCheap() {
                        return GenTraversableOnce.sizeHintIfCheap$(this);
                    }

                    private int batchCnt() {
                        return this.batchCnt;
                    }

                    private void batchCnt_$eq(int i) {
                        this.batchCnt = i;
                    }

                    private RowConverter converter() {
                        return this.converter;
                    }

                    private void converter_$eq(RowConverter rowConverter) {
                        this.converter = rowConverter;
                    }

                    private ColumnarBatch cb() {
                        return this.cb;
                    }

                    private void cb_$eq(ColumnarBatch columnarBatch) {
                        this.cb = columnarBatch;
                    }

                    private Iterator<Row> it() {
                        return this.it;
                    }

                    private void it_$eq(Iterator<Row> iterator2) {
                        this.it = iterator2;
                    }

                    private void closeCurrentBatch() {
                        if (cb() != null) {
                            cb().close();
                            cb_$eq(null);
                        }
                    }

                    private void loadNextBatch() {
                        Iterator<Row> iterator2;
                        closeCurrentBatch();
                        if (it() != null) {
                            it_$eq(null);
                        }
                        if (this.iter$1.hasNext()) {
                            Table table = (Table) this.iter$1.next();
                            if (batchCnt() == 0) {
                                Rabit.init((java.util.Map) JavaConverters$.MODULE$.mapAsJavaMapConverter(Predef$.MODULE$.Map().apply(Predef$.MODULE$.wrapRefArray(new Tuple2[]{Predef$ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc("DMLC_TASK_ID"), BoxesRunTime.boxToInteger(TaskContext$.MODULE$.getPartitionId()).toString())}))).asJava());
                                converter_$eq(new RowConverter((StructType) this.bOrigSchema$1.value(), (Seq) RichInt$.MODULE$.until$extension0(Predef$.MODULE$.intWrapper(0), table.getNumberOfColumns()).map(obj -> {
                                    return $anonfun$loadNextBatch$1(table, BoxesRunTime.unboxToInt(obj));
                                }, IndexedSeq$.MODULE$.canBuildFrom())));
                            }
                            ColumnarBatch from = GpuColumnVector.from(table, (DataType[]) new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps(((StructType) this.bOrigSchema$1.value()).fields())).map(structField -> {
                                return structField.dataType();
                            }, Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(DataType.class))));
                            try {
                                cb_$eq(new ColumnarBatch((ColumnVector[]) new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps(GpuColumnVector.extractColumns(from))).map(gpuColumnVector -> {
                                    return gpuColumnVector.copyToHost();
                                }, Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(ColumnVector.class))), from.numRows()));
                                Iterator map = ((Iterator) JavaConverters$.MODULE$.asScalaIteratorConverter(cb().rowIterator()).asScala()).map(this.toUnsafe$1).map(internalRow -> {
                                    return this.converter().toExternalRow(internalRow);
                                });
                                ColumnDMatrix columnDMatrix = new ColumnDMatrix(new GpuColumnBatch(table, (StructType) this.bOrigSchema$1.value()).getArrayInterface((int[]) this.featureIds$1.toArray(ClassTag$.MODULE$.Int())), this.missing$1, 1);
                                if (columnDMatrix == null) {
                                    iterator2 = Iterator$.MODULE$.empty();
                                } else {
                                    try {
                                        int gpuId = GpuUtils$.MODULE$.getGpuId(this.isLocal$1);
                                        ((Booster) this.bBooster$1.value()).setParam("gpu_id", BoxesRunTime.boxToInteger(gpuId).toString());
                                        ((Booster) this.bBooster$1.value()).setParam("predictor", "gpu_predictor");
                                        GpuTransform$.MODULE$.ml$dmlc$xgboost4j$scala$spark$rapids$GpuTransform$$logger().info(new StringBuilder(45).append("XGBoost transform GPU pipeline using device: ").append(gpuId).toString());
                                        iterator2 = (Iterator) this.predictFunc$1.apply(this.bBooster$1, columnDMatrix, map);
                                    } finally {
                                        columnDMatrix.delete();
                                    }
                                }
                                it_$eq(iterator2);
                            } finally {
                                batchCnt_$eq(batchCnt() + 1);
                                from.close();
                                table.close();
                            }
                        }
                    }

                    /**
                     * Reports whether another prediction row is available.
                     * If the current batch iterator is missing or drained, exactly one
                     * further batch is loaded via loadNextBatch() before answering.
                     * NOTE(review): because only one batch is loaded per call, a batch that
                     * yields no rows makes this return false even if more input might
                     * remain — confirm loadNextBatch() cannot leave an empty iterator then.
                     */
                    public boolean hasNext() {
                        if (it() != null && it().hasNext()) {
                            return true;
                        }
                        loadNextBatch();
                        final Iterator refreshed = it();
                        return refreshed != null && refreshed.hasNext();
                    }

                    /* renamed from: next, reason: merged with bridge method [inline-methods] */
                    /**
                     * Returns the next prediction row. If the current iterator is absent or
                     * exhausted, a new batch is loaded first.
                     *
                     * @return the next {@link Row} produced by the current batch iterator
                     * @throws NoSuchElementException if no iterator exists after loading
                     */
                    public Row m93next() {
                        Iterator current = it();
                        if (current == null || !current.hasNext()) {
                            loadNextBatch();
                            current = it();
                        }
                        if (current == null) {
                            throw new NoSuchElementException();
                        }
                        return (Row) current.next();
                    }

                    /**
                     * Task-completion callback registered by the iterator's initializer.
                     * If at least one batch was loaded (batchCnt() > 0), Rabit is shut down.
                     * This presumably pairs with a Rabit.init done while loading batches,
                     * which is not visible here. The current batch is always released
                     * afterwards.
                     *
                     * The release happens in a finally block. Otherwise a failure inside
                     * Rabit.shutdown() would skip closeCurrentBatch() and leak the batch.
                     *
                     * @param gpuTransform$$anon$1 the iterator whose resources are released
                     * @param taskContext          the completing Spark task (unused)
                     */
                    public static final /* synthetic */ void $anonfun$new$1(GpuTransform$$anon$1 gpuTransform$$anon$1, TaskContext taskContext) {
                        try {
                            if (gpuTransform$$anon$1.batchCnt() > 0) {
                                Rabit.shutdown();
                            }
                        } finally {
                            gpuTransform$$anon$1.closeCurrentBatch();
                        }
                    }

                    /**
                     * Looks up the cuDF data type of one column of a table.
                     *
                     * @param table source table
                     * @param i     column index
                     * @return the {@link DType} of column {@code i}
                     */
                    public static final /* synthetic */ DType $anonfun$loadNextBatch$1(Table table, int i) {
                        final DType columnType = table.getColumn(i).getType();
                        return columnType;
                    }

                    // Instance initializer of the per-partition anonymous Iterator.
                    // Statement order matters here; this is decompiled Scala constructor code.
                    {
                        // Copy values captured from the enclosing scope into fields.
                        this.iter$1 = iterator;
                        this.bOrigSchema$1 = broadcast;
                        this.toUnsafe$1 = create;
                        this.featureIds$1 = seq3;
                        this.missing$1 = f;
                        this.isLocal$1 = isLocal;
                        this.bBooster$1 = broadcast3;
                        this.predictFunc$1 = function3;
                        // Scala trait initializers for the Iterator hierarchy.
                        GenTraversableOnce.$init$(this);
                        TraversableOnce.$init$(this);
                        Iterator.$init$(this);
                        // No batch loaded yet: counter at zero and no host batch or row iterator.
                        this.batchCnt = 0;
                        this.converter = null;
                        this.cb = null;
                        this.it = null;
                        // When the Spark task completes, run $anonfun$new$1. It calls
                        // Rabit.shutdown() if any batch was loaded and then closeCurrentBatch().
                        TaskContext$.MODULE$.get().addTaskCompletionListener(taskContext -> {
                            $anonfun$new$1(this, taskContext);
                            return BoxedUnit.UNIT;
                        });
                    }
                };
            }
            Rabit.init((java.util.Map) JavaConverters$.MODULE$.mapAsJavaMapConverter(Predef$.MODULE$.Map().apply(Predef$.MODULE$.wrapRefArray(new Tuple2[]{Predef$ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc("DMLC_TASK_ID"), BoxesRunTime.boxToInteger(TaskContext$.MODULE$.getPartitionId()).toString()), Predef$ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc("DMLC_WORKER_STOP_PROCESS_ON_ERROR"), "false")}))).asJava());
            try {
                Iterator map = iterator.map(table -> {
                    return new GpuColumnBatch(table, (StructType) broadcast.value(), (GpuSampler) option.getOrElse(() -> {
                        return null;
                    }));
                });
                Tuple2 time = MLUtils$.MODULE$.time(() -> {
                    return GpuUtils$.MODULE$.buildDMatrixAndColumnToRow(f, map, seq3, (StructType) broadcast2.value());
                });
                if (time != null) {
                    Tuple2 tuple2 = (Tuple2) time._1();
                    float unboxToFloat = BoxesRunTime.unboxToFloat(time._2());
                    if (tuple2 != null) {
                        Tuple3 tuple3 = new Tuple3((DMatrix) tuple2._1(), (ColumnBatchToRow) tuple2._2(), BoxesRunTime.boxToFloat(unboxToFloat));
                        DMatrix dMatrix = (DMatrix) tuple3._1();
                        ColumnBatchToRow columnBatchToRow = (ColumnBatchToRow) tuple3._2();
                        float unboxToFloat2 = BoxesRunTime.unboxToFloat(tuple3._3());
                        if (dMatrix == null) {
                            MODULE$.ml$dmlc$xgboost4j$scala$spark$rapids$GpuTransform$$logger().info("No data when building DMatrix");
                            iterator = Iterator$.MODULE$.empty();
                        } else {
                            MODULE$.ml$dmlc$xgboost4j$scala$spark$rapids$GpuTransform$$logger().debug(new StringBuilder(45).append("Benchmark [Transform: Build Dmatrix and Row] ").append(unboxToFloat2).toString());
                            try {
                                int gpuId = GpuUtils$.MODULE$.getGpuId(isLocal);
                                ((Booster) broadcast3.value()).setParam("gpu_id", BoxesRunTime.boxToInteger(gpuId).toString());
                                ((Booster) broadcast3.value()).setParam("predictor", "gpu_predictor");
                                MODULE$.ml$dmlc$xgboost4j$scala$spark$rapids$GpuTransform$$logger().info(new StringBuilder(45).append("XGBoost transform GPU pipeline using device: ").append(gpuId).toString());
                                iterator = (Iterator) function3.apply(broadcast3, dMatrix, columnBatchToRow.toIterator());
                            } finally {
                                dMatrix.delete();
                            }
                        }
                        return iterator;
                    }
                }
                throw new MatchError(time);
            } finally {
                Rabit.shutdown();
            }
        }, columnarRdd.mapPartitions$default$2(), ClassTag$.MODULE$.apply(Row.class));
        broadcast.unpersist(false);
        broadcast2.unpersist(false);
        broadcast3.unpersist(false);
        return dataset.sparkSession().createDataFrame(mapPartitions, (StructType) function1.apply(buildRowSchema));
    }

    /**
     * Default value for the eighth parameter of {@code transformDistributed}
     * (compiler-generated default-argument accessor).
     *
     * @return always {@code true}
     */
    private boolean transformDistributed$default$8() {
        final boolean defaultValue = true;
        return defaultValue;
    }

    /**
     * Default value for the ninth parameter of {@code transformDistributed}:
     * no {@link GpuSampler}.
     *
     * @return the Scala {@code None} singleton
     */
    private Option<GpuSampler> transformDistributed$default$9() {
        final Option<GpuSampler> noSampler = None$.MODULE$;
        return noSampler;
    }

    /**
     * Lambda body used by buildRddSchema. It appends a non-nullable field named
     * {@code str} of type {@code arrayType} to the StructType held in
     * {@code objectRef}. Empty names are skipped.
     *
     * @param objectRef mutable holder of the StructType being built
     * @param arrayType type of the field to add
     * @param str       field name; nothing is added when it is empty
     */
    public static final /* synthetic */ void $anonfun$buildRddSchema$1(ObjectRef objectRef, ArrayType arrayType, String str) {
        if (str.isEmpty()) {
            return;
        }
        final StructType current = (StructType) objectRef.elem;
        objectRef.elem = current.add(str, arrayType, false);
    }

    private GpuTransform$() {
        MODULE$ = this;
        this.ml$dmlc$xgboost4j$scala$spark$rapids$GpuTransform$$logger = LogFactory.getLog("GpuTransform");
    }
}
