package ml.dmlc.xgboost4j.scala.spark.rapids;

import ai.rapids.cudf.Table;
import java.lang.Thread;
import ml.dmlc.xgboost4j.java.Rabit;
import ml.dmlc.xgboost4j.java.XGBoostError;
import ml.dmlc.xgboost4j.java.spark.rapids.GpuColumnBatch;
import ml.dmlc.xgboost4j.scala.Booster;
import ml.dmlc.xgboost4j.scala.ColumnDMatrix;
import ml.dmlc.xgboost4j.scala.DMatrix;
import ml.dmlc.xgboost4j.scala.EvalTrait;
import ml.dmlc.xgboost4j.scala.ExternalCheckpointManager;
import ml.dmlc.xgboost4j.scala.ExternalCheckpointParams;
import ml.dmlc.xgboost4j.scala.ObjectiveTrait;
import ml.dmlc.xgboost4j.scala.spark.Watches;
import ml.dmlc.xgboost4j.scala.spark.XGBoost$;
import ml.dmlc.xgboost4j.scala.spark.XGBoostClassificationModel;
import ml.dmlc.xgboost4j.scala.spark.XGBoostClassifier;
import ml.dmlc.xgboost4j.scala.spark.XGBoostExecutionParams;
import ml.dmlc.xgboost4j.scala.spark.XGBoostExecutionParamsFactory;
import ml.dmlc.xgboost4j.scala.spark.XGBoostRegressionModel;
import ml.dmlc.xgboost4j.scala.spark.XGBoostRegressor;
import ml.dmlc.xgboost4j.scala.spark.XGBoostTrainingSummary$;
import ml.dmlc.xgboost4j.scala.spark.params.BoosterParams$;
import ml.dmlc.xgboost4j.scala.spark.rapids.GpuXGBoost;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.fs.FileSystem;
import org.apache.spark.SparkContext;
import org.apache.spark.TaskContext$;
import org.apache.spark.ml.param.Param;
import org.apache.spark.rdd.RDD;
import org.apache.spark.sql.Column;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.catalyst.encoders.RowEncoder$;
import org.apache.spark.sql.functions$;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.util.LongAccumulator;
import scala.Array$;
import scala.Double$;
import scala.Function0;
import scala.Function1;
import scala.Function2;
import scala.MatchError;
import scala.None$;
import scala.Option;
import scala.PartialFunction;
import scala.Predef$;
import scala.Predef$ArrowAssoc$;
import scala.Some;
import scala.Tuple2;
import scala.Tuple3;
import scala.Tuple4;
import scala.collection.BufferedIterator;
import scala.collection.GenTraversableOnce;
import scala.collection.Iterable;
import scala.collection.Iterable$;
import scala.collection.Iterator;
import scala.collection.JavaConverters$;
import scala.collection.Seq;
import scala.collection.Seq$;
import scala.collection.SeqLike;
import scala.collection.Traversable;
import scala.collection.TraversableOnce;
import scala.collection.generic.CanBuildFrom;
import scala.collection.immutable.IndexedSeq;
import scala.collection.immutable.List;
import scala.collection.immutable.Map;
import scala.collection.immutable.Map$;
import scala.collection.immutable.Set;
import scala.collection.immutable.Stream;
import scala.collection.immutable.StringOps;
import scala.collection.immutable.Vector;
import scala.collection.mutable.ArrayBuffer;
import scala.collection.mutable.ArrayBuffer$;
import scala.collection.mutable.ArrayOps;
import scala.collection.mutable.Buffer;
import scala.collection.mutable.StringBuilder;
import scala.math.Numeric;
import scala.math.Ordering;
import scala.package$;
import scala.reflect.ClassTag;
import scala.reflect.ClassTag$;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;
import scala.runtime.DoubleRef;
import scala.runtime.Nothing$;
import scala.runtime.RichFloat$;
import scala.runtime.ScalaRunTime$;

/* compiled from: GpuXGBoost.scala */
/* loaded from: input_file:ml/dmlc/xgboost4j/scala/spark/rapids/GpuXGBoost$.class */
public final class GpuXGBoost$ {
    /** Holder for the single instance of this object; other methods reach instance members through it. */
    public static GpuXGBoost$ MODULE$;
    /** Name exposed by {@link #trainName()}; used in this class as the map key of the training batch. */
    private final String trainName;
    /** Logger exposed to the methods of this class through {@code logger()}. */
    private final Log logger;

    static {
        // Eagerly creates the instance when the class is loaded. The constructor is not
        // visible in this part of the file; presumably it assigns MODULE$ and the two
        // final fields above (TODO confirm).
        new GpuXGBoost$();
    }

    /**
     * Returns the name under which the training batch is keyed in the input-batch map.
     *
     * @return the training-set key
     */
    public String trainName() {
        return trainName;
    }

    /**
     * Accessor for the logger used by the methods of this object.
     *
     * @return the logger instance
     */
    private Log logger() {
        return logger;
    }

    /**
     * Fits a classifier through the GPU path.
     *
     * <p>Resolves the label, weight and base-margin column names from the estimator, converts the
     * input dataset's column types, trains, and returns the model with the estimator set as its
     * parent and the estimator's param map copied in.
     *
     * @param xGBoostClassifier estimator supplying the column names and parameters
     * @param dataset           input data to fit on
     * @param option            optional sampler forwarded to training
     * @return the trained classification model
     * @throws MatchError if the resolved column names are not a sequence of exactly three entries
     */
    public XGBoostClassificationModel fitOnGpu(XGBoostClassifier xGBoostClassifier, Dataset<?> dataset, Option<GpuSampler> option) {
        Param[] requestedCols = {xGBoostClassifier.labelCol(), xGBoostClassifier.weightCol(), xGBoostClassifier.baseMarginCol()};
        Seq<String> columnNames = MLUtils$.MODULE$.getColumnNames(xGBoostClassifier, Predef$.MODULE$.wrapRefArray(requestedCols));

        // Equivalent of the Scala pattern `val Seq(label, weight, baseMargin) = columnNames`.
        Some extracted = Seq$.MODULE$.unapplySeq(columnNames);
        if (extracted.isEmpty()) {
            throw new MatchError(columnNames);
        }
        Object matched = extracted.get();
        if (matched == null || ((SeqLike) matched).lengthCompare(3) != 0) {
            throw new MatchError(columnNames);
        }
        SeqLike names = (SeqLike) matched;
        String labelName = (String) names.apply(0);
        String weightName = (String) names.apply(1);
        String baseMarginName = (String) names.apply(2);

        Dataset<Row> prepared = MLUtils$.MODULE$.prepareColumnType(dataset, xGBoostClassifier.getFeaturesCols(), labelName, weightName, baseMarginName, MLUtils$.MODULE$.prepareColumnType$default$6());
        return trainOnGpu(xGBoostClassifier, prepared, option).setParent(xGBoostClassifier).m21copy(xGBoostClassifier.extractParamMap());
    }

    /**
     * Trains a classification model on the GPU from an already type-prepared dataset.
     *
     * <p>Requires the {@code objective} parameter to be set. When no evaluation metric is set
     * (or it is empty) one is chosen from the objective: {@code "merror"} for objectives starting
     * with {@code "multi"}, otherwise {@code "error"}. When a non-null custom objective is set, the
     * objective type is set to {@code "classification"}.
     *
     * @param xGBoostClassifier estimator supplying columns, parameters and evaluation sets
     * @param dataset           training data
     * @param option            optional sampler forwarded to distributed training
     * @return the trained model, with class count read from the "Class Number" accumulator and the
     *         training summary attached
     * @throws MatchError if the resolved column names are not exactly three entries, if an
     *                    evaluation-set entry is null, or if distributed training returns null
     */
    public XGBoostClassificationModel trainOnGpu(XGBoostClassifier xGBoostClassifier, Dataset<Row> dataset, Option<GpuSampler> option) {
        Predef$.MODULE$.require(xGBoostClassifier.isDefined(xGBoostClassifier.objective()), () -> "Parameter 'objective' must be set.");

        // Fall back to an objective-dependent metric when the user did not provide one.
        boolean hasEvalMetric = xGBoostClassifier.isDefined(xGBoostClassifier.evalMetric()) && !xGBoostClassifier.getEvalMetric().isEmpty();
        if (!hasEvalMetric) {
            xGBoostClassifier.setEvalMetric(xGBoostClassifier.getObjective().startsWith("multi") ? "merror" : "error");
        }

        if (xGBoostClassifier.isDefined(xGBoostClassifier.customObj()) && xGBoostClassifier.getOrDefault(xGBoostClassifier.customObj()) != null) {
            xGBoostClassifier.setObjectiveType("classification");
        }

        Param[] requestedCols = {xGBoostClassifier.labelCol(), xGBoostClassifier.weightCol(), xGBoostClassifier.baseMarginCol()};
        Seq<String> columnNames = MLUtils$.MODULE$.getColumnNames(xGBoostClassifier, Predef$.MODULE$.wrapRefArray(requestedCols));

        // Equivalent of the Scala pattern `val Seq(label, weight, baseMargin) = columnNames`.
        Some extracted = Seq$.MODULE$.unapplySeq(columnNames);
        if (extracted.isEmpty()) {
            throw new MatchError(columnNames);
        }
        Object matched = extracted.get();
        if (matched == null || ((SeqLike) matched).lengthCompare(3) != 0) {
            throw new MatchError(columnNames);
        }
        SeqLike names = (SeqLike) matched;
        String labelName = (String) names.apply(0);
        String weightName = (String) names.apply(1);
        String baseMarginName = (String) names.apply(2);

        // Classification has no group column, hence the empty string.
        ColumnDataBatch trainBatch = GpuUtils$.MODULE$.buildColumnDataBatch(xGBoostClassifier.getFeaturesCols(), labelName, weightName, baseMarginName, "", dataset);
        Map<String, ColumnDataBatch> evalBatches = (Map) xGBoostClassifier.getEvalSets(xGBoostClassifier.getUserParams()).map(evalEntry -> {
            if (evalEntry == null) {
                throw new MatchError(evalEntry);
            }
            return new Tuple2(
                    (String) evalEntry._1(),
                    GpuUtils$.MODULE$.buildColumnDataBatch(
                            xGBoostClassifier.getFeaturesCols(), labelName, weightName, baseMarginName, "",
                            MLUtils$.MODULE$.prepareColumnType((Dataset) evalEntry._2(), xGBoostClassifier.getFeaturesCols(), labelName, weightName, baseMarginName, MLUtils$.MODULE$.prepareColumnType$default$6())));
        }, Map$.MODULE$.canBuildFrom());

        LongAccumulator classCounter = dataset.sparkSession().sparkContext().longAccumulator("Class Number");
        Tuple2<Booster, Map<String, float[]>> trained = trainDistributedOnGpu(trainBatch, xGBoostClassifier.MLlib2XGBoostParams(), evalBatches, classCounter, option);
        if (trained == null) {
            throw new MatchError(trained);
        }
        Booster booster = (Booster) trained._1();
        Map<String, float[]> metrics = (Map) trained._2();

        int numClasses = (int) Predef$.MODULE$.Long2long(classCounter.value());
        logger().debug("Accumulator returns the number of class: " + numClasses);

        XGBoostClassificationModel model = new XGBoostClassificationModel(xGBoostClassifier.uid(), numClasses, booster);
        return model.setSummary(XGBoostTrainingSummary$.MODULE$.apply(metrics));
    }

    /**
     * Fits a regressor through the GPU path.
     *
     * <p>Resolves the label, weight and base-margin column names from the estimator, converts the
     * input dataset's column types, trains, and returns the model with the estimator set as its
     * parent and the estimator's param map copied in.
     *
     * @param xGBoostRegressor estimator supplying the column names and parameters
     * @param dataset          input data to fit on
     * @param option           optional sampler forwarded to training
     * @return the trained regression model
     * @throws MatchError if the resolved column names are not a sequence of exactly three entries
     */
    public XGBoostRegressionModel fitOnGpu(XGBoostRegressor xGBoostRegressor, Dataset<?> dataset, Option<GpuSampler> option) {
        Param[] requestedCols = {xGBoostRegressor.labelCol(), xGBoostRegressor.weightCol(), xGBoostRegressor.baseMarginCol()};
        Seq<String> columnNames = MLUtils$.MODULE$.getColumnNames(xGBoostRegressor, Predef$.MODULE$.wrapRefArray(requestedCols));

        // Equivalent of the Scala pattern `val Seq(label, weight, baseMargin) = columnNames`.
        Some extracted = Seq$.MODULE$.unapplySeq(columnNames);
        if (extracted.isEmpty()) {
            throw new MatchError(columnNames);
        }
        Object matched = extracted.get();
        if (matched == null || ((SeqLike) matched).lengthCompare(3) != 0) {
            throw new MatchError(columnNames);
        }
        SeqLike names = (SeqLike) matched;
        String labelName = (String) names.apply(0);
        String weightName = (String) names.apply(1);
        String baseMarginName = (String) names.apply(2);

        Dataset<Row> prepared = MLUtils$.MODULE$.prepareColumnType(dataset, xGBoostRegressor.getFeaturesCols(), labelName, weightName, baseMarginName, MLUtils$.MODULE$.prepareColumnType$default$6());
        return trainOnGpu(xGBoostRegressor, prepared, option).setParent(xGBoostRegressor).m41copy(xGBoostRegressor.extractParamMap());
    }

    /**
     * Trains a regression (or ranking) model on the GPU from an already type-prepared dataset.
     *
     * <p>Requires the {@code objective} parameter to be set. When no evaluation metric is set
     * (or it is empty) one is chosen from the objective: {@code "map"} for objectives starting
     * with {@code "rank"}, otherwise {@code "rmse"}. When a non-null custom objective is set, the
     * objective type is set to {@code "regression"}.
     *
     * @param xGBoostRegressor estimator supplying columns, parameters and evaluation sets
     * @param dataset          training data
     * @param option           optional sampler forwarded to distributed training
     * @return the trained model with the training summary attached
     * @throws MatchError if the resolved column names are not exactly four entries, if an
     *                    evaluation-set entry is null, or if distributed training returns null
     */
    public XGBoostRegressionModel trainOnGpu(XGBoostRegressor xGBoostRegressor, Dataset<Row> dataset, Option<GpuSampler> option) {
        Predef$.MODULE$.require(xGBoostRegressor.isDefined(xGBoostRegressor.objective()), () -> "Parameter 'objective' must be set.");

        // Fall back to an objective-dependent metric when the user did not provide one.
        boolean hasEvalMetric = xGBoostRegressor.isDefined(xGBoostRegressor.evalMetric()) && !xGBoostRegressor.getEvalMetric().isEmpty();
        if (!hasEvalMetric) {
            xGBoostRegressor.setEvalMetric(xGBoostRegressor.getObjective().startsWith("rank") ? "map" : "rmse");
        }

        if (xGBoostRegressor.isDefined(xGBoostRegressor.customObj()) && xGBoostRegressor.getOrDefault(xGBoostRegressor.customObj()) != null) {
            xGBoostRegressor.setObjectiveType("regression");
        }

        Param[] requestedCols = {xGBoostRegressor.labelCol(), xGBoostRegressor.weightCol(), xGBoostRegressor.baseMarginCol(), xGBoostRegressor.groupCol()};
        Seq<String> columnNames = MLUtils$.MODULE$.getColumnNames(xGBoostRegressor, Predef$.MODULE$.wrapRefArray(requestedCols));

        // Equivalent of the Scala pattern `val Seq(label, weight, baseMargin, group) = columnNames`.
        Some extracted = Seq$.MODULE$.unapplySeq(columnNames);
        if (extracted.isEmpty()) {
            throw new MatchError(columnNames);
        }
        Object matched = extracted.get();
        if (matched == null || ((SeqLike) matched).lengthCompare(4) != 0) {
            throw new MatchError(columnNames);
        }
        SeqLike names = (SeqLike) matched;
        String labelName = (String) names.apply(0);
        String weightName = (String) names.apply(1);
        String baseMarginName = (String) names.apply(2);
        String groupName = (String) names.apply(3);

        ColumnDataBatch trainBatch = GpuUtils$.MODULE$.buildColumnDataBatch(xGBoostRegressor.getFeaturesCols(), labelName, weightName, baseMarginName, groupName, dataset);
        Map<String, Object> xgbParams = xGBoostRegressor.MLlib2XGBoostParams();
        Map<String, ColumnDataBatch> evalBatches = (Map) xGBoostRegressor.getEvalSets(xGBoostRegressor.getUserParams()).map(evalEntry -> {
            if (evalEntry == null) {
                throw new MatchError(evalEntry);
            }
            return new Tuple2(
                    (String) evalEntry._1(),
                    GpuUtils$.MODULE$.buildColumnDataBatch(
                            xGBoostRegressor.getFeaturesCols(), labelName, weightName, baseMarginName, groupName,
                            MLUtils$.MODULE$.prepareColumnType((Dataset) evalEntry._2(), xGBoostRegressor.getFeaturesCols(), labelName, weightName, baseMarginName, MLUtils$.MODULE$.prepareColumnType$default$6())));
        }, Map$.MODULE$.canBuildFrom());

        // No class-count accumulator is passed for regression (null).
        Tuple2<Booster, Map<String, float[]>> trained = trainDistributedOnGpu(trainBatch, xgbParams, evalBatches, null, option);
        if (trained == null) {
            throw new MatchError(trained);
        }
        Booster booster = (Booster) trained._1();
        Map<String, float[]> metrics = (Map) trained._2();

        XGBoostRegressionModel model = new XGBoostRegressionModel(xGBoostRegressor.uid(), booster);
        model.setSummary(XGBoostTrainingSummary$.MODULE$.apply(metrics));
        return model;
    }

    /**
     * Runs one distributed GPU training job and returns the resulting booster with its metrics.
     *
     * <p>Steps, in order: log the parameters, derive the runtime parameters, repartition the
     * input batches, optionally load a checkpointed booster, start the tracker, launch the
     * training RDD on a background thread, wait for the tracker, post-process the tracker result,
     * run the checkpoint follow-up, and finally stop the tracker.
     *
     * @param columnDataBatch training batch; its DataFrame also supplies the SparkContext
     * @param map             XGBoost parameters (logged and used to build the runtime parameters)
     * @param map2            named evaluation batches; emptiness selects the no-eval training path
     * @param longAccumulator accumulator forwarded to the per-partition training code (callers in
     *                        this file pass {@code null} for regression)
     * @param option          optional sampler forwarded to the per-partition training code
     * @return the trained booster paired with the metrics map
     * @throws XGBoostError declared by the signature; any Throwable is logged and rethrown after
     *                      the SparkContext has been stopped
     */
    private Tuple2<Booster, Map<String, float[]>> trainDistributedOnGpu(ColumnDataBatch columnDataBatch, Map<String, Object> map, Map<String, ColumnDataBatch> map2, LongAccumulator longAccumulator, Option<GpuSampler> option) throws XGBoostError {
        logger().info(new StringBuilder(37).append("Running GPU XGBoost with parameters:\n").append(map.mkString("\n")).toString());
        SparkContext sparkContext = columnDataBatch.rawDF().sparkSession().sparkContext();
        XGBoostExecutionParams buildXGBRuntimeParams = new XGBoostExecutionParamsFactory(map, sparkContext).buildXGBRuntimeParams();
        Map<String, ColumnDataBatch> prepareInputData = prepareInputData(columnDataBatch, map2, buildXGBRuntimeParams.numWorkers(), buildXGBRuntimeParams.cacheTrainingSet());
        // When a checkpoint is configured: clean up checkpoint versions relative to numRounds and
        // load the stored booster; otherwise the starting booster is null (orNull).
        Booster booster = (Booster) buildXGBRuntimeParams.checkpointParam().map(externalCheckpointParams -> {
            ExternalCheckpointManager externalCheckpointManager = new ExternalCheckpointManager(externalCheckpointParams.checkpointPath(), FileSystem.get(sparkContext.hadoopConfiguration()));
            externalCheckpointManager.cleanUpHigherVersions(buildXGBRuntimeParams.numRounds());
            return externalCheckpointManager.loadCheckpointAsScalaBooster();
        }).orNull(Predef$.MODULE$.$conforms());
        try {
            // NOTE(review): the decompiler typed this as Thread.UncaughtExceptionHandler, but
            // getWorkerEnvs(), waitFor() and stop() are called on it below, so the real type is
            // presumably a tracker interface that also implements that handler — confirm.
            Thread.UncaughtExceptionHandler startTracker = XGBoost$.MODULE$.startTracker(buildXGBRuntimeParams.numWorkers(), buildXGBRuntimeParams.trackerConf());
            try {
                GpuParallelismTracker gpuParallelismTracker = new GpuParallelismTracker(sparkContext, buildXGBRuntimeParams.timeoutRequestWorkers(), buildXGBRuntimeParams.numWorkers());
                // map2.isEmpty() tells the internal trainer whether there are no evaluation sets.
                final RDD<Tuple2<Booster, Map<String, float[]>>> trainOnGpuInternal = trainOnGpuInternal(prepareInputData, buildXGBRuntimeParams, startTracker.getWorkerEnvs(), booster, map2.isEmpty(), longAccumulator, option);
                // Background thread whose only job is to force evaluation of the training RDD.
                Thread thread = new Thread(trainOnGpuInternal) { // from class: ml.dmlc.xgboost4j.scala.spark.rapids.GpuXGBoost$$anon$1
                    private final RDD boostersAndMetrics$1;

                    @Override // java.lang.Thread, java.lang.Runnable
                    public void run() {
                        // The per-partition function does no work of its own; the bare lambda
                        // below looks like a decompiler artifact (NOTE(review): not a valid
                        // Java statement as written).
                        this.boostersAndMetrics$1.foreachPartition(iterator -> {
                            () -> {
                                return iterator;
                            };
                            return BoxedUnit.UNIT;
                        });
                    }

                    {
                        this.boostersAndMetrics$1 = trainOnGpuInternal;
                    }
                };
                // Failures on the training thread are routed to the tracker object.
                thread.setUncaughtExceptionHandler(startTracker);
                thread.start();
                // Block on the tracker (timeout argument 0) inside the GPU parallelism tracker.
                int unboxToInt = BoxesRunTime.unboxToInt(gpuParallelismTracker.executeOnGpu(() -> {
                    return startTracker.waitFor(0L);
                }));
                logger().info(new StringBuilder(41).append("GPU XGBoost Rabit returns with exit code ").append(unboxToInt).toString());
                Tuple2<Booster, Map<String, float[]>> postTrackerReturnProcessing = XGBoost$.MODULE$.postTrackerReturnProcessing(unboxToInt, trainOnGpuInternal, thread);
                if (postTrackerReturnProcessing == null) {
                    throw new MatchError(postTrackerReturnProcessing);
                }
                // The repeated repacking below only copies the same two values; tuple22 can
                // never be null here because it was just constructed.
                Tuple2 tuple2 = new Tuple2((Booster) postTrackerReturnProcessing._1(), (Map) postTrackerReturnProcessing._2());
                Tuple2 tuple22 = new Tuple2((Booster) tuple2._1(), (Map) tuple2._2());
                if (tuple22 == null) {
                    throw new MatchError(tuple22);
                }
                Tuple2 tuple23 = new Tuple2((Booster) tuple22._1(), (Map) tuple22._2());
                Booster booster2 = (Booster) tuple23._1();
                Map map3 = (Map) tuple23._2();
                // Checkpoint follow-up; the helper is defined outside the visible part of the file.
                buildXGBRuntimeParams.checkpointParam().foreach(externalCheckpointParams2 -> {
                    $anonfun$trainDistributedOnGpu$3(buildXGBRuntimeParams, sparkContext, externalCheckpointParams2);
                    return BoxedUnit.UNIT;
                });
                return new Tuple2<>(booster2, map3);
            } finally {
                // The tracker is stopped on both success and failure.
                startTracker.stop();
            }
        } catch (Throwable th) {
            // Any failure is logged, stops the whole SparkContext, and is then rethrown.
            logger().error("The job was aborted due to ", th);
            sparkContext.stop();
            throw th;
        }
    }

    /**
     * Compiler-generated accessor for the default value of the fifth parameter (the sampler) of
     * {@code trainDistributedOnGpu}.
     *
     * @return {@code None}, i.e. no sampler
     */
    private Option<GpuSampler> trainDistributedOnGpu$default$5() {
        return None$.MODULE$;
    }

    /**
     * Builds the cached RDD whose partitions each train a booster and yield it with its metrics.
     *
     * <p>Two paths are taken depending on {@code z}:
     * <ul>
     *   <li>{@code z == true}: only the batch keyed by {@link #trainName()} is used; its DataFrame
     *       is converted to a columnar RDD and watches are built with {@code buildWatches}.</li>
     *   <li>{@code z == false}: all batches are co-partitioned with {@code coPartitionForGpu} and
     *       watches are built with {@code buildWatchesWithEval}.</li>
     * </ul>
     *
     * @param map                   named input batches; must contain the {@link #trainName()} key
     * @param xGBoostExecutionParams runtime parameters, passed through {@code overrideParamsToUseGPU}
     * @param map2                  environment entries handed to {@code buildDistributedBooster}
     * @param booster               starting booster (the only caller in this file may pass null)
     * @param z                     flag supplied by the caller as "evaluation sets are empty"
     * @param longAccumulator       accumulator forwarded to the watches builders
     * @param option                optional sampler; only referenced on the {@code z == true} path
     * @return the cached RDD of (booster, metrics) pairs
     * @throws XGBoostError declared by the signature
     */
    private RDD<Tuple2<Booster, Map<String, float[]>>> trainOnGpuInternal(Map<String, ColumnDataBatch> map, XGBoostExecutionParams xGBoostExecutionParams, java.util.Map<String, String> map2, Booster booster, boolean z, LongAccumulator longAccumulator, Option<GpuSampler> option) throws XGBoostError {
        XGBoostExecutionParams overrideParamsToUseGPU = overrideParamsToUseGPU(xGBoostExecutionParams);
        SparkContext sparkContext = ((ColumnDataBatch) map.apply(trainName())).rawDF().sparkSession().sparkContext();
        // Captured by the partition closures and passed to appendGpuIdToParams on each task.
        boolean isLocal = sparkContext.isLocal();
        if (z) {
            ColumnIndices colIndices = ((ColumnDataBatch) map.apply(trainName())).colIndices();
            RDD<Table> columnarRdd = GpuUtils$.MODULE$.toColumnarRdd(((ColumnDataBatch) map.apply(trainName())).rawDF());
            return columnarRdd.mapPartitions(iterator -> {
                // Wrap each table lazily; the sampler is unwrapped to null when absent.
                Iterator map3 = iterator.map(table -> {
                    return new GpuColumnBatch(table, null, (GpuSampler) option.getOrElse(() -> {
                        return null;
                    }));
                });
                XGBoostExecutionParams appendGpuIdToParams = MODULE$.appendGpuIdToParams(overrideParamsToUseGPU, isLocal);
                // "max_bin" falls back to 16 when it is not present in the parameter map.
                int unboxToInt = BoxesRunTime.unboxToInt(appendGpuIdToParams.toMap().getOrElse("max_bin", () -> {
                    return 16;
                }));
                // The watches are built lazily inside buildDistributedBooster via this Function0.
                return MODULE$.buildDistributedBooster(() -> {
                    return MODULE$.buildWatches(XGBoost$.MODULE$.getCacheDirName(appendGpuIdToParams.useExternalMemory()), appendGpuIdToParams.missing(), colIndices, map3, longAccumulator, unboxToInt);
                }, appendGpuIdToParams, map2, appendGpuIdToParams.obj(), appendGpuIdToParams.eval(), booster);
            }, columnarRdd.mapPartitions$default$2(), ClassTag$.MODULE$.apply(Tuple2.class)).cache();
        }
        // Evaluation path: keep only the column indices per batch name for the closure.
        Map map3 = (Map) map.map(tuple2 -> {
            return new Tuple2(tuple2._1(), ((ColumnDataBatch) tuple2._2()).colIndices());
        }, Map$.MODULE$.canBuildFrom());
        RDD<Tuple2<String, Iterator<GpuColumnBatch>>> coPartitionForGpu = coPartitionForGpu(map, sparkContext, overrideParamsToUseGPU.numWorkers());
        // NOTE(review): `option` (the sampler) is not used on this path — confirm that is intended.
        return coPartitionForGpu.mapPartitions(iterator2 -> {
            XGBoostExecutionParams appendGpuIdToParams = MODULE$.appendGpuIdToParams(overrideParamsToUseGPU, isLocal);
            // "max_bin" falls back to 16 when it is not present in the parameter map.
            int unboxToInt = BoxesRunTime.unboxToInt(appendGpuIdToParams.toMap().getOrElse("max_bin", () -> {
                return 16;
            }));
            return MODULE$.buildDistributedBooster(() -> {
                return MODULE$.buildWatchesWithEval(XGBoost$.MODULE$.getCacheDirName(appendGpuIdToParams.useExternalMemory()), appendGpuIdToParams.missing(), map3, iterator2, longAccumulator, unboxToInt);
            }, appendGpuIdToParams, map2, appendGpuIdToParams.obj(), appendGpuIdToParams.eval(), booster);
        }, coPartitionForGpu.mapPartitions$default$2(), ClassTag$.MODULE$.apply(Tuple2.class)).cache();
    }

    /**
     * Compiler-generated accessor for the default value of the seventh parameter (the sampler) of
     * {@code trainOnGpuInternal}.
     *
     * @return {@code None}, i.e. no sampler
     */
    private Option<GpuSampler> trainOnGpuInternal$default$7() {
        return None$.MODULE$;
    }

    /**
     * Trains a booster inside the current Spark task and returns it with its per-watch metrics.
     *
     * <p>The task's partition ID and attempt number are written into {@code map} (together with
     * {@code DMLC_WORKER_STOP_PROCESS_ON_ERROR=false}) before Rabit is initialised with it. The
     * watches are then built from {@code function0}, the training matrix is checked for rows, and
     * training runs either with or without checkpoint saving.
     *
     * <p>Cleanup is performed exactly once in a {@code finally} block: Rabit is shut down and the
     * watches, if they were built, are deleted. This also covers the failure path, where the
     * previous code tested a variable that was always {@code null} and therefore never released
     * the watches.
     *
     * @param function0              lazily builds the watches for this partition
     * @param xGBoostExecutionParams runtime parameters (rounds, checkpointing, early stopping)
     * @param map                    Rabit environment; mutated with the task-specific entries
     * @param objectiveTrait         custom objective forwarded to training
     * @param evalTrait              custom evaluation forwarded to training
     * @param booster                starting booster forwarded to training
     * @return a single-element iterator holding the trained booster and the metrics keyed by watch name
     */
    public Iterator<Tuple2<Booster, Map<String, float[]>>> buildDistributedBooster(Function0<Watches> function0, XGBoostExecutionParams xGBoostExecutionParams, java.util.Map<String, String> map, ObjectiveTrait objectiveTrait, EvalTrait evalTrait, Booster booster) {
        String taskId = BoxesRunTime.boxToInteger(TaskContext$.MODULE$.getPartitionId()).toString();
        String attempt = BoxesRunTime.boxToInteger(TaskContext$.MODULE$.get().attemptNumber()).toString();
        map.put("DMLC_TASK_ID", taskId);
        map.put("DMLC_NUM_ATTEMPT", attempt);
        map.put("DMLC_WORKER_STOP_PROCESS_ON_ERROR", "false");
        int numRounds = xGBoostExecutionParams.numRounds();
        // Checkpoints are only written by the task for partition 0, and only when configured.
        boolean saveCheckpoints = xGBoostExecutionParams.checkpointParam().isDefined() && new StringOps(Predef$.MODULE$.augmentString(taskId)).toInt() == 0;
        Watches watches = null;
        try {
            Rabit.init(map);
            watches = (Watches) function0.apply();
            if (((DMatrix) watches.toMap().apply("train")).rowNum() == 0) {
                throw new XGBoostError("detected an empty partition in the training data, partition ID: " + TaskContext$.MODULE$.getPartitionId());
            }
            checkNumClass(watches, xGBoostExecutionParams.toMap());
            int numEarlyStoppingRounds = xGBoostExecutionParams.earlyStoppingParams().numEarlyStoppingRounds();
            // One metrics row per watch; the rows are filled in by the training call.
            float[][] metrics = (float[][]) Array$.MODULE$.tabulate(watches.size(), obj3 -> {
                return $anonfun$buildDistributedBooster$1(numRounds, BoxesRunTime.unboxToInt(obj3));
            }, ClassTag$.MODULE$.apply(ScalaRunTime$.MODULE$.arrayClass(Float.TYPE)));

            Booster trained;
            if (saveCheckpoints) {
                trained = ml.dmlc.xgboost4j.scala.XGBoost$.MODULE$.trainAndSaveCheckpoint((DMatrix) watches.toMap().apply("train"), xGBoostExecutionParams.toMap(), numRounds, watches.toMap(), metrics, objectiveTrait, evalTrait, numEarlyStoppingRounds, booster, xGBoostExecutionParams.checkpointParam());
            } else {
                trained = ml.dmlc.xgboost4j.scala.XGBoost$.MODULE$.train((DMatrix) watches.toMap().apply("train"), xGBoostExecutionParams.toMap(), numRounds, watches.toMap(), metrics, objectiveTrait, evalTrait, numEarlyStoppingRounds, booster);
            }

            // Pair every watch name with its metrics row.
            Map metricsByWatch = ((TraversableOnce) watches.toMap().keys().zip(Predef$.MODULE$.wrapRefArray(metrics), Iterable$.MODULE$.canBuildFrom())).toMap(Predef$.MODULE$.$conforms());
            Iterator<Tuple2<Booster, Map<String, float[]>>> result = package$.MODULE$.Iterator().apply(Predef$.MODULE$.wrapRefArray(new Tuple2[]{Predef$ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc(trained), metricsByWatch)}));
            return result;
        } catch (XGBoostError e) {
            logger().error("XGBooster worker " + taskId + " has failed " + attempt + " times due to ", e);
            throw e;
        } finally {
            // Release native resources on every exit path; the watches are deleted even if the
            // Rabit shutdown itself fails.
            try {
                Rabit.shutdown();
            } finally {
                if (watches != null) {
                    watches.delete();
                }
            }
        }
    }

    /**
     * Zips the columnar RDDs of all named batches into one RDD with {@code i} partitions.
     *
     * <p>The fold starts from an RDD of {@code i} null placeholders spread over {@code i}
     * partitions. For each named batch, the accumulated RDD is zipped partition-wise with that
     * batch's columnar RDD: the entries collected so far are kept, a new
     * {@code (name -> iterator of GpuColumnBatch)} entry is appended, and the result is passed
     * through a filter predicate defined elsewhere in this class (presumably it drops the null
     * placeholders — confirm).
     *
     * <p>The fold lambda uses parameter names distinct from its locals; the previous form
     * redeclared its own parameters inside the body, which Java does not allow.
     *
     * @param map          named batches to co-partition
     * @param sparkContext context used to create the placeholder RDD
     * @param i            number of placeholder elements and partitions
     * @return RDD whose partitions hold one (name, batch iterator) entry per input batch
     * @throws MatchError if a map entry is null
     */
    private RDD<Tuple2<String, Iterator<GpuColumnBatch>>> coPartitionForGpu(Map<String, ColumnDataBatch> map, SparkContext sparkContext, int i) {
        return (RDD) map.foldLeft(sparkContext.parallelize(Predef$.MODULE$.wrapRefArray((Object[]) Array$.MODULE$.fill(i, () -> {
            return null;
        }, ClassTag$.MODULE$.apply(Tuple2.class))), i, ClassTag$.MODULE$.apply(Tuple2.class)), (accumulated, entry) -> {
            if (entry == null) {
                // Mirrors the failed Scala pattern match on (rdd, (name, batch)).
                throw new MatchError(new Tuple2(accumulated, entry));
            }
            RDD zippedSoFar = (RDD) accumulated;
            Tuple2 namedBatch = (Tuple2) entry;
            String batchName = (String) namedBatch._1();
            return zippedSoFar.zipPartitions(GpuUtils$.MODULE$.toColumnarRdd(((ColumnDataBatch) namedBatch._2()).rawDF()), (iterator, iterator2) -> {
                return new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps((Object[]) new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps((Object[]) new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps((Object[]) iterator.toArray(ClassTag$.MODULE$.apply(Tuple2.class)))).$colon$plus(Predef$ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc(batchName), iterator2.map(table -> {
                    return new GpuColumnBatch(table, null);
                })), ClassTag$.MODULE$.apply(Tuple2.class)))).filter(candidate -> {
                    return BoxesRunTime.boxToBoolean($anonfun$coPartitionForGpu$5(candidate));
                }))).toIterator();
            }, ClassTag$.MODULE$.apply(Table.class), ClassTag$.MODULE$.apply(Tuple2.class));
        });
    }

    /**
     * Combines the training batch with the evaluation batches and repartitions each of them.
     *
     * <p>The training batch is keyed by {@link #trainName()}. Every batch's DataFrame is
     * repartitioned into {@code i} partitions: through {@code repartitionForGroup} when the batch
     * has a group column, otherwise with a plain {@code repartition(i)}. Column indices and the
     * group column name are carried over unchanged.
     *
     * @param columnDataBatch training batch
     * @param map             named evaluation batches
     * @param i               target number of partitions
     * @param z               whether dataset caching was requested; only triggers a warning here
     * @return map from batch name to the repartitioned batch
     * @throws MatchError if an entry of the combined map is null
     */
    private Map<String, ColumnDataBatch> prepareInputData(ColumnDataBatch columnDataBatch, Map<String, ColumnDataBatch> map, int i, boolean z) {
        if (z) {
            logger().warn("Dataset cache is not support for Gpu pipeline!");
        }
        Tuple2[] trainEntry = {Predef$ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc(trainName()), columnDataBatch)};
        return (Map) Predef$.MODULE$.Map().apply(Predef$.MODULE$.wrapRefArray(trainEntry)).$plus$plus(map).map(entry -> {
            if (entry == null) {
                throw new MatchError(entry);
            }
            String batchName = (String) entry._1();
            ColumnDataBatch batch = (ColumnDataBatch) entry._2();
            // Group-aware repartitioning when a group column exists, plain repartitioning otherwise.
            Dataset repartitioned = (Dataset) batch.groupColName()
                    .map(groupCol -> MODULE$.repartitionForGroup(groupCol, batch.rawDF(), i))
                    .getOrElse(() -> batch.rawDF().repartition(i));
            ColumnDataBatch rebuilt = new ColumnDataBatch(repartitioned, batch.colIndices(), batch.groupColName());
            return Predef$ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc(batchName), rebuilt);
        }, Map$.MODULE$.canBuildFrom());
    }

    /* JADX INFO: Access modifiers changed from: private */
    public Dataset<Row> repartitionForGroup(String str, Dataset<Row> dataset, int i) {
        logger().info("Start groupBy for ltr");
        StructType schema = dataset.schema();
        Dataset agg = dataset.groupBy(str, Predef$.MODULE$.wrapRefArray(new String[0])).agg(functions$.MODULE$.collect_list(functions$.MODULE$.struct(Predef$.MODULE$.wrapRefArray((Object[]) new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps(schema.fieldNames())).map(str2 -> {
            return functions$.MODULE$.col(str2);
        }, Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(Column.class)))))).as("list"), Predef$.MODULE$.wrapRefArray(new Column[0]));
        return agg.repartition(i).mapPartitions(iterator -> {
            return new Iterator<Row>(iterator) { // from class: ml.dmlc.xgboost4j.scala.spark.rapids.GpuXGBoost$$anon$2
                private Iterator<Object> iterInRow;
                private final Iterator iter$1;

                /* renamed from: seq, reason: merged with bridge method [inline-methods] */
                public Iterator<Row> m105seq() {
                    return Iterator.seq$(this);
                }

                public boolean isEmpty() {
                    return Iterator.isEmpty$(this);
                }

                public boolean isTraversableAgain() {
                    return Iterator.isTraversableAgain$(this);
                }

                public boolean hasDefiniteSize() {
                    return Iterator.hasDefiniteSize$(this);
                }

                public Iterator<Row> take(int i2) {
                    return Iterator.take$(this, i2);
                }

                public Iterator<Row> drop(int i2) {
                    return Iterator.drop$(this, i2);
                }

                public Iterator<Row> slice(int i2, int i3) {
                    return Iterator.slice$(this, i2, i3);
                }

                public Iterator<Row> sliceIterator(int i2, int i3) {
                    return Iterator.sliceIterator$(this, i2, i3);
                }

                public <B> Iterator<B> map(Function1<Row, B> function1) {
                    return Iterator.map$(this, function1);
                }

                public <B> Iterator<B> $plus$plus(Function0<GenTraversableOnce<B>> function0) {
                    return Iterator.$plus$plus$(this, function0);
                }

                public <B> Iterator<B> flatMap(Function1<Row, GenTraversableOnce<B>> function1) {
                    return Iterator.flatMap$(this, function1);
                }

                public Iterator<Row> filter(Function1<Row, Object> function1) {
                    return Iterator.filter$(this, function1);
                }

                public <B> boolean corresponds(GenTraversableOnce<B> genTraversableOnce, Function2<Row, B, Object> function2) {
                    return Iterator.corresponds$(this, genTraversableOnce, function2);
                }

                public Iterator<Row> withFilter(Function1<Row, Object> function1) {
                    return Iterator.withFilter$(this, function1);
                }

                public Iterator<Row> filterNot(Function1<Row, Object> function1) {
                    return Iterator.filterNot$(this, function1);
                }

                public <B> Iterator<B> collect(PartialFunction<Row, B> partialFunction) {
                    return Iterator.collect$(this, partialFunction);
                }

                public <B> Iterator<B> scanLeft(B b, Function2<B, Row, B> function2) {
                    return Iterator.scanLeft$(this, b, function2);
                }

                public <B> Iterator<B> scanRight(B b, Function2<Row, B, B> function2) {
                    return Iterator.scanRight$(this, b, function2);
                }

                public Iterator<Row> takeWhile(Function1<Row, Object> function1) {
                    return Iterator.takeWhile$(this, function1);
                }

                public Tuple2<Iterator<Row>, Iterator<Row>> partition(Function1<Row, Object> function1) {
                    return Iterator.partition$(this, function1);
                }

                public Tuple2<Iterator<Row>, Iterator<Row>> span(Function1<Row, Object> function1) {
                    return Iterator.span$(this, function1);
                }

                public Iterator<Row> dropWhile(Function1<Row, Object> function1) {
                    return Iterator.dropWhile$(this, function1);
                }

                public <B> Iterator<Tuple2<Row, B>> zip(Iterator<B> iterator) {
                    return Iterator.zip$(this, iterator);
                }

                public <A1> Iterator<A1> padTo(int i2, A1 a1) {
                    return Iterator.padTo$(this, i2, a1);
                }

                public Iterator<Tuple2<Row, Object>> zipWithIndex() {
                    return Iterator.zipWithIndex$(this);
                }

                public <B, A1, B1> Iterator<Tuple2<A1, B1>> zipAll(Iterator<B> iterator, A1 a1, B1 b1) {
                    return Iterator.zipAll$(this, iterator, a1, b1);
                }

                public <U> void foreach(Function1<Row, U> function1) {
                    Iterator.foreach$(this, function1);
                }

                /** Forwards to the static {@code Iterator.forall$}, passing this iterator and the predicate. */
                public boolean forall(Function1<Row, Object> function1) {
                    return Iterator.forall$(this, function1);
                }

                /** Forwards to the static {@code Iterator.exists$}, passing this iterator and the predicate. */
                public boolean exists(Function1<Row, Object> function1) {
                    return Iterator.exists$(this, function1);
                }

                /** Forwards to the static {@code Iterator.contains$}, passing this iterator and the candidate object. */
                public boolean contains(Object obj) {
                    return Iterator.contains$(this, obj);
                }

                /** Forwards to the static {@code Iterator.find$}, passing this iterator and the predicate. */
                public Option<Row> find(Function1<Row, Object> function1) {
                    return Iterator.find$(this, function1);
                }

                /** Forwards to the two-argument static {@code Iterator.indexWhere$}, passing this iterator and the predicate. */
                public int indexWhere(Function1<Row, Object> function1) {
                    return Iterator.indexWhere$(this, function1);
                }

                /** Forwards to the three-argument static {@code Iterator.indexWhere$} with this iterator, the predicate and the int argument. */
                public int indexWhere(Function1<Row, Object> function1, int i2) {
                    return Iterator.indexWhere$(this, function1, i2);
                }

                /** Forwards to the two-argument static {@code Iterator.indexOf$}, passing this iterator and the element. */
                public <B> int indexOf(B b) {
                    return Iterator.indexOf$(this, b);
                }

                /** Forwards to the three-argument static {@code Iterator.indexOf$} with this iterator, the element and the int argument. */
                public <B> int indexOf(B b, int i2) {
                    return Iterator.indexOf$(this, b, i2);
                }

                /** Forwards to the static {@code Iterator.buffered$}, passing this iterator as the receiver. */
                public BufferedIterator<Row> buffered() {
                    return Iterator.buffered$(this);
                }

                /** Forwards to the static {@code Iterator.grouped$}, passing this iterator and the int argument. */
                public <B> Iterator<Row>.GroupedIterator<B> grouped(int i2) {
                    return Iterator.grouped$(this, i2);
                }

                /** Forwards to the static {@code Iterator.sliding$}, passing this iterator and both int arguments. */
                public <B> Iterator<Row>.GroupedIterator<B> sliding(int i2, int i3) {
                    return Iterator.sliding$(this, i2, i3);
                }

                /** Forwards to the static {@code Iterator.sliding$default$2$}; by its name this supplies the default for the second argument of {@code sliding}. */
                public <B> int sliding$default$2() {
                    return Iterator.sliding$default$2$(this);
                }

                /** Forwards to the static {@code Iterator.length$}, passing this iterator as the receiver. */
                public int length() {
                    return Iterator.length$(this);
                }

                /** Forwards to the static {@code Iterator.duplicate$}, passing this iterator as the receiver. */
                public Tuple2<Iterator<Row>, Iterator<Row>> duplicate() {
                    return Iterator.duplicate$(this);
                }

                /** Forwards to the static {@code Iterator.patch$} with this iterator, the first int, the patch iterator and the second int. */
                public <B> Iterator<B> patch(int i2, Iterator<B> iterator, int i3) {
                    return Iterator.patch$(this, i2, iterator, i3);
                }

                /** Forwards to the static {@code Iterator.copyToArray$} with this iterator, the target (typed {@code Object} here) and both int arguments. */
                public <B> void copyToArray(Object obj, int i2, int i3) {
                    Iterator.copyToArray$(this, obj, i2, i3);
                }

                /** Forwards to the static {@code Iterator.sameElements$}, passing this iterator and the other iterator. */
                public boolean sameElements(Iterator<?> iterator) {
                    return Iterator.sameElements$(this, iterator);
                }

                /* renamed from: toTraversable, reason: merged with bridge method [inline-methods] */
                /** Forwards to the static {@code Iterator.toTraversable$}; the {@code m104} prefix was added by the decompiler (see the rename note above). */
                public Traversable<Row> m104toTraversable() {
                    return Iterator.toTraversable$(this);
                }

                /** Forwards to the static {@code Iterator.toIterator$}, passing this iterator as the receiver. */
                public Iterator<Row> toIterator() {
                    return Iterator.toIterator$(this);
                }

                /** Forwards to the static {@code Iterator.toStream$}, passing this iterator as the receiver. */
                public Stream<Row> toStream() {
                    return Iterator.toStream$(this);
                }

                /** Forwards to the static {@code Iterator.toString$} for the string form of this iterator. */
                public String toString() {
                    return Iterator.toString$(this);
                }

                /** Forwards to the static {@code TraversableOnce.reversed$}, passing this iterator as the receiver. */
                public List<Row> reversed() {
                    return TraversableOnce.reversed$(this);
                }

                /** Forwards to the static {@code TraversableOnce.size$}, passing this iterator as the receiver. */
                public int size() {
                    return TraversableOnce.size$(this);
                }

                /** Forwards to the static {@code TraversableOnce.nonEmpty$}, passing this iterator as the receiver. */
                public boolean nonEmpty() {
                    return TraversableOnce.nonEmpty$(this);
                }

                /** Forwards to the static {@code TraversableOnce.count$}, passing this iterator and the predicate. */
                public int count(Function1<Row, Object> function1) {
                    return TraversableOnce.count$(this, function1);
                }

                /** Forwards to the static {@code TraversableOnce.collectFirst$}, passing this iterator and the partial function. */
                public <B> Option<B> collectFirst(PartialFunction<Row, B> partialFunction) {
                    return TraversableOnce.collectFirst$(this, partialFunction);
                }

                /** Forwards to the static {@code TraversableOnce.$div$colon$} and casts the result to {@code B}. */
                public <B> B $div$colon(B b, Function2<B, Row, B> function2) {
                    return (B) TraversableOnce.$div$colon$(this, b, function2);
                }

                /** Forwards to the static {@code TraversableOnce.$colon$bslash$} and casts the result to {@code B}. */
                public <B> B $colon$bslash(B b, Function2<Row, B, B> function2) {
                    return (B) TraversableOnce.$colon$bslash$(this, b, function2);
                }

                /** Forwards to the static {@code TraversableOnce.foldLeft$} and casts the result to {@code B}. */
                public <B> B foldLeft(B b, Function2<B, Row, B> function2) {
                    return (B) TraversableOnce.foldLeft$(this, b, function2);
                }

                /** Forwards to the static {@code TraversableOnce.foldRight$} and casts the result to {@code B}. */
                public <B> B foldRight(B b, Function2<Row, B, B> function2) {
                    return (B) TraversableOnce.foldRight$(this, b, function2);
                }

                /** Forwards to the static {@code TraversableOnce.reduceLeft$} and casts the result to {@code B}. */
                public <B> B reduceLeft(Function2<B, Row, B> function2) {
                    return (B) TraversableOnce.reduceLeft$(this, function2);
                }

                /** Forwards to the static {@code TraversableOnce.reduceRight$} and casts the result to {@code B}. */
                public <B> B reduceRight(Function2<Row, B, B> function2) {
                    return (B) TraversableOnce.reduceRight$(this, function2);
                }

                /** Forwards to the static {@code TraversableOnce.reduceLeftOption$}, passing this iterator and the combining function. */
                public <B> Option<B> reduceLeftOption(Function2<B, Row, B> function2) {
                    return TraversableOnce.reduceLeftOption$(this, function2);
                }

                /** Forwards to the static {@code TraversableOnce.reduceRightOption$}, passing this iterator and the combining function. */
                public <B> Option<B> reduceRightOption(Function2<Row, B, B> function2) {
                    return TraversableOnce.reduceRightOption$(this, function2);
                }

                /** Forwards to the static {@code TraversableOnce.reduce$} and casts the result to {@code A1}. */
                public <A1> A1 reduce(Function2<A1, A1, A1> function2) {
                    return (A1) TraversableOnce.reduce$(this, function2);
                }

                /** Forwards to the static {@code TraversableOnce.reduceOption$}, passing this iterator and the combining function. */
                public <A1> Option<A1> reduceOption(Function2<A1, A1, A1> function2) {
                    return TraversableOnce.reduceOption$(this, function2);
                }

                /** Forwards to the static {@code TraversableOnce.fold$} and casts the result to {@code A1}. */
                public <A1> A1 fold(A1 a1, Function2<A1, A1, A1> function2) {
                    return (A1) TraversableOnce.fold$(this, a1, function2);
                }

                /** Forwards to the static {@code TraversableOnce.aggregate$} with the supplier and both functions, casting the result to {@code B}. */
                public <B> B aggregate(Function0<B> function0, Function2<B, Row, B> function2, Function2<B, B, B> function22) {
                    return (B) TraversableOnce.aggregate$(this, function0, function2, function22);
                }

                /** Forwards to the static {@code TraversableOnce.sum$} with the given {@code Numeric}, casting the result to {@code B}. */
                public <B> B sum(Numeric<B> numeric) {
                    return (B) TraversableOnce.sum$(this, numeric);
                }

                /** Forwards to the static {@code TraversableOnce.product$} with the given {@code Numeric}, casting the result to {@code B}. */
                public <B> B product(Numeric<B> numeric) {
                    return (B) TraversableOnce.product$(this, numeric);
                }

                /** Forwards to the static {@code TraversableOnce.min$} with the given (raw-typed) ordering. */
                public Object min(Ordering ordering) {
                    return TraversableOnce.min$(this, ordering);
                }

                /** Forwards to the static {@code TraversableOnce.max$} with the given (raw-typed) ordering. */
                public Object max(Ordering ordering) {
                    return TraversableOnce.max$(this, ordering);
                }

                /** Forwards to the static {@code TraversableOnce.maxBy$} with the given (raw-typed) function and ordering. */
                public Object maxBy(Function1 function1, Ordering ordering) {
                    return TraversableOnce.maxBy$(this, function1, ordering);
                }

                /** Forwards to the static {@code TraversableOnce.minBy$} with the given (raw-typed) function and ordering. */
                public Object minBy(Function1 function1, Ordering ordering) {
                    return TraversableOnce.minBy$(this, function1, ordering);
                }

                /** Forwards to the static {@code TraversableOnce.copyToBuffer$}, passing this iterator and the buffer. */
                public <B> void copyToBuffer(Buffer<B> buffer) {
                    TraversableOnce.copyToBuffer$(this, buffer);
                }

                /** Forwards to the three-argument static {@code TraversableOnce.copyToArray$} with this iterator, the target and the int argument. */
                public <B> void copyToArray(Object obj, int i2) {
                    TraversableOnce.copyToArray$(this, obj, i2);
                }

                /** Forwards to the two-argument static {@code TraversableOnce.copyToArray$}, passing this iterator and the target. */
                public <B> void copyToArray(Object obj) {
                    TraversableOnce.copyToArray$(this, obj);
                }

                /** Forwards to the static {@code TraversableOnce.toArray$} with the given class tag; the result is typed {@code Object} here. */
                public <B> Object toArray(ClassTag<B> classTag) {
                    return TraversableOnce.toArray$(this, classTag);
                }

                /** Forwards to the static {@code TraversableOnce.toList$}, passing this iterator as the receiver. */
                public List<Row> toList() {
                    return TraversableOnce.toList$(this);
                }

                /* renamed from: toIterable, reason: merged with bridge method [inline-methods] */
                /** Forwards to the static {@code TraversableOnce.toIterable$}; the {@code m103} prefix was added by the decompiler (see the rename note above). */
                public Iterable<Row> m103toIterable() {
                    return TraversableOnce.toIterable$(this);
                }

                /* renamed from: toSeq, reason: merged with bridge method [inline-methods] */
                /** Forwards to the static {@code TraversableOnce.toSeq$}; the {@code m102} prefix was added by the decompiler (see the rename note above). */
                public Seq<Row> m102toSeq() {
                    return TraversableOnce.toSeq$(this);
                }

                /** Forwards to the static {@code TraversableOnce.toIndexedSeq$}, passing this iterator as the receiver. */
                public IndexedSeq<Row> toIndexedSeq() {
                    return TraversableOnce.toIndexedSeq$(this);
                }

                /** Forwards to the static {@code TraversableOnce.toBuffer$}, passing this iterator as the receiver. */
                public <B> Buffer<B> toBuffer() {
                    return TraversableOnce.toBuffer$(this);
                }

                /* renamed from: toSet, reason: merged with bridge method [inline-methods] */
                /** Forwards to the static {@code TraversableOnce.toSet$}; the {@code m101} prefix was added by the decompiler (see the rename note above). */
                public <B> Set<B> m101toSet() {
                    return TraversableOnce.toSet$(this);
                }

                /** Forwards to the static {@code TraversableOnce.toVector$}, passing this iterator as the receiver. */
                public Vector<Row> toVector() {
                    return TraversableOnce.toVector$(this);
                }

                /** Forwards to the static {@code TraversableOnce.to$} with the given {@code CanBuildFrom}, casting the result to {@code Col}. */
                public <Col> Col to(CanBuildFrom<Nothing$, Row, Col> canBuildFrom) {
                    return (Col) TraversableOnce.to$(this, canBuildFrom);
                }

                /* renamed from: toMap, reason: merged with bridge method [inline-methods] */
                /** Forwards to the static {@code TraversableOnce.toMap$} with the supplied evidence argument; the {@code m100} prefix was added by the decompiler. */
                public <T, U> Map<T, U> m100toMap(Predef$.less.colon.less<Row, Tuple2<T, U>> lessVar) {
                    return TraversableOnce.toMap$(this, lessVar);
                }

                /** Forwards to the four-argument static {@code TraversableOnce.mkString$} with this iterator and the three strings. */
                public String mkString(String str3, String str4, String str5) {
                    return TraversableOnce.mkString$(this, str3, str4, str5);
                }

                /** Forwards to the two-argument static {@code TraversableOnce.mkString$} with this iterator and the single string. */
                public String mkString(String str3) {
                    return TraversableOnce.mkString$(this, str3);
                }

                /** Forwards to the one-argument static {@code TraversableOnce.mkString$}, passing only this iterator. */
                public String mkString() {
                    return TraversableOnce.mkString$(this);
                }

                /** Forwards to the static {@code TraversableOnce.addString$} with this iterator, the builder and the three strings. */
                public StringBuilder addString(StringBuilder stringBuilder, String str3, String str4, String str5) {
                    return TraversableOnce.addString$(this, stringBuilder, str3, str4, str5);
                }

                /** Forwards to the static {@code TraversableOnce.addString$} with this iterator, the builder and the single string. */
                public StringBuilder addString(StringBuilder stringBuilder, String str3) {
                    return TraversableOnce.addString$(this, stringBuilder, str3);
                }

                /** Forwards to the static {@code TraversableOnce.addString$} with this iterator and the builder only. */
                public StringBuilder addString(StringBuilder stringBuilder) {
                    return TraversableOnce.addString$(this, stringBuilder);
                }

                /** Forwards to the static {@code GenTraversableOnce.sizeHintIfCheap$}, passing this iterator as the receiver. */
                public int sizeHintIfCheap() {
                    return GenTraversableOnce.sizeHintIfCheap$(this);
                }

                /** Returns the iterator over the sequence taken from the outer row currently being consumed (see {@code hasNext}). */
                private Iterator<Object> iterInRow() {
                    return this.iterInRow;
                }

                /** Replaces the current inner iterator; called from {@code hasNext} when it moves on to another outer row. */
                private void iterInRow_$eq(Iterator<Object> iterator) {
                    this.iterInRow = iterator;
                }

                /**
                 * Reports whether another nested element is available, advancing over outer rows as needed.
                 *
                 * <p>Column 1 of each outer row is read as a sequence via {@code getSeq(1)} and iterated in
                 * turn. Outer rows whose sequence is empty are skipped: the outer iterator keeps being
                 * advanced until a non-empty sequence is found or the outer iterator is exhausted.
                 *
                 * @return {@code true} if the current inner iterator has an element to hand out
                 */
                public boolean hasNext() {
                    // A loop rather than a single check: with only one check, an outer row holding an
                    // empty sequence would make this iterator report exhaustion while outer rows remain.
                    while (this.iter$1.hasNext() && !iterInRow().hasNext()) {
                        iterInRow_$eq(((Row) this.iter$1.next()).getSeq(1).iterator());
                    }
                    return iterInRow().hasNext();
                }

                /* renamed from: next, reason: merged with bridge method [inline-methods] */
                /**
                 * Returns the next element of the current inner iterator, cast to {@code Row}.
                 * This method does not advance to another outer row itself; that only happens in {@code hasNext}.
                 */
                public Row m106next() {
                    return (Row) iterInRow().next();
                }

                // Instance initializer: stores the outer iterator, runs the trait initializers in
                // GenTraversableOnce -> TraversableOnce -> Iterator order, then starts with an empty
                // inner iterator so the first hasNext() call pulls from the outer iterator.
                {
                    this.iter$1 = iterator;
                    GenTraversableOnce.$init$(this);
                    TraversableOnce.$init$(this);
                    Iterator.$init$(this);
                    this.iterInRow = package$.MODULE$.Iterator().empty();
                }
            };
        }, RowEncoder$.MODULE$.apply(schema));
    }

    /**
     * Forces the training parameters onto a GPU tree method.
     *
     * <p>If {@code tree_method} is absent or equal to {@code "auto"}, it is set to {@code "gpu_hist"}.
     * Any other value must start with {@code "gpu_"}; otherwise the {@code Predef.require} check fails
     * with a message listing the supported GPU tree methods. A {@code tree_method} entry that maps to
     * {@code null} now fails that same check instead of raising a {@code NullPointerException}.
     *
     * @param xGBoostExecutionParams execution parameters; updated in place through {@code setRawParamMap}
     * @return the same {@code xGBoostExecutionParams} instance that was passed in
     */
    private XGBoostExecutionParams overrideParamsToUseGPU(XGBoostExecutionParams xGBoostExecutionParams) {
        Map<String, Object> map = xGBoostExecutionParams.toMap();
        // A missing entry is handled exactly like "auto": both end up as gpu_hist.
        final String treeMethod = map.contains("tree_method") ? (String) map.apply("tree_method") : "auto";
        if ("auto".equals(treeMethod)) {
            map = map.$plus(Predef$ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc("tree_method"), "gpu_hist"));
        } else {
            // Null-safe check so a null value is reported through the require message below.
            Predef$.MODULE$.require(treeMethod != null && treeMethod.startsWith("gpu_"), () -> {
                return new StringBuilder(68).append("Now for training on GPU, xgboost-spark only supports tree_method as ").append(new StringBuilder(2).append("[").append(BoosterParams$.MODULE$.supportedTreeMethods().filter(str2 -> {
                    return BoxesRunTime.boxToBoolean(str2.startsWith("gpu_"));
                }).mkString(", ")).append("]").toString()).append(new StringBuilder(14).append(", but found '").append(treeMethod).append("'").toString()).toString();
            });
        }
        xGBoostExecutionParams.setRawParamMap(map);
        return xGBoostExecutionParams;
    }

    /**
     * Looks up the GPU device id through {@code GpuUtils.getGpuId}, logs it at info level and
     * stores its string form under {@code gpu_id} in the raw parameter map.
     *
     * @param xGBoostExecutionParams execution parameters; updated in place through {@code setRawParamMap}
     * @param z flag passed straight through to {@code GpuUtils.getGpuId}
     * @return the same {@code xGBoostExecutionParams} instance that was passed in
     */
    private XGBoostExecutionParams appendGpuIdToParams(XGBoostExecutionParams xGBoostExecutionParams, boolean z) {
        final int deviceId = GpuUtils$.MODULE$.getGpuId(z);
        logger().info("XGboost GPU training using device: " + deviceId);
        // The id is stored as a String value, not as a boxed integer.
        final String deviceIdText = Integer.toString(deviceId);
        Map<String, Object> withGpuId = xGBoostExecutionParams.toMap().$plus(Predef$ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc("gpu_id"), deviceIdText));
        xGBoostExecutionParams.setRawParamMap(withGpuId);
        return xGBoostExecutionParams;
    }

    /**
     * Builds a {@code ColumnDMatrix} from the first batch of the iterator and, optionally, tracks
     * the largest label value seen across all batches.
     *
     * <p>Only the first batch contributes data: its feature, label, optional weight and optional
     * base-margin array interfaces are handed to the matrix. Every later batch only logs a warning,
     * because incremental building is not supported. Each batch is closed once processed, including
     * when processing it throws.
     *
     * @param iterator batches to consume; fully drained by this method
     * @param columnIndices column positions for features, label, and the optional weight and margin
     * @param f value passed as the second constructor argument of {@code ColumnDMatrix}
     * @param z when {@code true}, the maximum of the label column is tracked across batches
     * @param i unused by this method
     * @return a pair of the matrix ({@code null} if the iterator was empty) and the boxed maximum
     *         label, which stays at {@code Double.MinValue} when {@code z} is {@code false} or
     *         there were no batches
     */
    private Tuple2<DMatrix, Object> buildDMatrixIncrementally(Iterator<GpuColumnBatch> iterator, ColumnIndices columnIndices, float f, boolean z, int i) {
        ColumnDMatrix dMatrix = null;
        double maxLabel = Double$.MODULE$.MinValue();
        while (iterator.hasNext()) {
            GpuColumnBatch batch = (GpuColumnBatch) iterator.next();
            try {
                String featuresInterface = batch.getArrayInterface((int[]) columnIndices.featureIds().toArray(ClassTag$.MODULE$.Int()));
                String labelInterface = batch.getArrayInterface(new int[]{columnIndices.labelId()});
                Option weightInterface = columnIndices.weightId().map(obj -> {
                    return $anonfun$buildDMatrixIncrementally$2(batch, BoxesRunTime.unboxToInt(obj));
                });
                Option marginInterface = columnIndices.marginId().map(obj2 -> {
                    return $anonfun$buildDMatrixIncrementally$3(batch, BoxesRunTime.unboxToInt(obj2));
                });
                if (dMatrix == null) {
                    dMatrix = new ColumnDMatrix(featuresInterface, f, 1);
                    dMatrix.setLabel(labelInterface);
                    // Plain Option checks instead of lambdas: dMatrix is reassigned in this loop,
                    // so it cannot be captured by a lambda.
                    if (weightInterface.isDefined()) {
                        dMatrix.setWeight((String) weightInterface.get());
                    }
                    if (marginInterface.isDefined()) {
                        dMatrix.setBaseMargin((String) marginInterface.get());
                    }
                } else {
                    logger().warn("Incremental building DMatrix is not supported now !");
                }
                if (z) {
                    double maxInColumn = batch.getMaxInColumn(columnIndices.labelId());
                    // Explicit comparison rather than Math.max, so a NaN column maximum never
                    // replaces the running maximum.
                    if (maxInColumn > maxLabel) {
                        maxLabel = maxInColumn;
                    }
                }
            } finally {
                // Close the batch even when building fails part-way through.
                batch.close();
            }
        }
        logger().debug(new StringBuilder(11).append("Num class: ").append(maxLabel).toString());
        return new Tuple2<>(dMatrix, BoxesRunTime.boxToDouble(maxLabel));
    }

    /* JADX INFO: Access modifiers changed from: private */
    /**
     * Wraps the batches in a {@code GpuXGBoost.RapidsIterator}, builds a {@code ColumnDMatrix}
     * from it and pairs the matrix with the iterator's {@code maxLabels()} value.
     *
     * @param iterator batches to feed into the matrix
     * @param columnIndices column positions handed to the {@code RapidsIterator}
     * @param f value passed as the second constructor argument of {@code ColumnDMatrix}
     * @param z flag handed to the {@code RapidsIterator}
     * @param i value passed as the third constructor argument of {@code ColumnDMatrix}
     * @return a pair of the matrix and the boxed {@code maxLabels()} value, read after the matrix is constructed
     */
    public Tuple2<DMatrix, Object> buildDMatrix(Iterator<GpuColumnBatch> iterator, ColumnIndices columnIndices, float f, boolean z, int i) {
        GpuXGBoost.RapidsIterator batches = new GpuXGBoost.RapidsIterator(iterator, z, columnIndices);
        // Construct the matrix first; maxLabels() is only read afterwards, as before.
        ColumnDMatrix matrix = new ColumnDMatrix(batches, f, i, 1);
        Object maxLabel = BoxesRunTime.boxToDouble(batches.maxLabels());
        return new Tuple2<>(matrix, maxLabel);
    }

    /* JADX INFO: Access modifiers changed from: private */
    /**
     * Builds the training matrix under {@code MLUtils.time}, logs the float reported alongside
     * the result and wraps the outcome in a {@code GpuWatches}.
     *
     * <p>If the built matrix is {@code null}, the watches hold empty matrix and name arrays;
     * otherwise they hold the single matrix under the name {@code "train"}.
     *
     * @param option passed through to {@code GpuWatches}
     * @param f passed through to {@code buildDMatrix}
     * @param columnIndices column positions for the training data
     * @param iterator training batches
     * @param longAccumulator passed through to {@code GpuWatches}; a non-null value turns on the
     *        boolean flag of {@code buildDMatrix}
     * @param i passed through to {@code buildDMatrix}
     * @return the assembled {@code GpuWatches}
     */
    public Watches buildWatches(Option<String> option, float f, ColumnIndices columnIndices, Iterator<GpuColumnBatch> iterator, LongAccumulator longAccumulator, int i) {
        Tuple2 timed = MLUtils$.MODULE$.time(() -> {
            return MODULE$.buildDMatrix(iterator, columnIndices, f, longAccumulator != null, i);
        });
        if (timed == null) {
            throw new MatchError(timed);
        }
        Tuple2 built = (Tuple2) timed._1();
        float elapsed = BoxesRunTime.unboxToFloat(timed._2());
        if (built == null) {
            throw new MatchError(timed);
        }
        DMatrix trainMatrix = (DMatrix) built._1();
        double numClass = built._2$mcD$sp();
        logger().debug("Benchmark[Train: Build DMatrix incrementally] " + elapsed);

        DMatrix[] matrices;
        String[] names;
        if (trainMatrix == null) {
            matrices = (DMatrix[]) Array$.MODULE$.empty(ClassTag$.MODULE$.apply(DMatrix.class));
            names = (String[]) Array$.MODULE$.empty(ClassTag$.MODULE$.apply(String.class));
        } else {
            matrices = new DMatrix[]{trainMatrix};
            names = new String[]{"train"};
        }
        return new GpuWatches(matrices, names, option, longAccumulator, numClass);
    }

    /* JADX INFO: Access modifiers changed from: private */
    /**
     * Builds one matrix per named dataset and wraps the non-null ones in a {@code GpuWatches}.
     *
     * <p>Each {@code (name, batches)} pair is built through {@code buildDMatrix} under
     * {@code MLUtils.time}, using the column indices registered for that name in {@code map}.
     * Only the dataset named {@code "train"} has its second result value recorded, and only when
     * {@code longAccumulator} is non-null; that value is passed to {@code GpuWatches} as its last
     * argument (0.0 if never recorded). Pairs whose matrix is {@code null} are filtered out.
     *
     * @param option passed through to {@code GpuWatches}
     * @param f passed through to {@code buildDMatrix}
     * @param map column indices keyed by dataset name
     * @param iterator pairs of dataset name and its batches
     * @param longAccumulator passed through to {@code GpuWatches}; a non-null value turns on the
     *        boolean flag of {@code buildDMatrix} for the {@code "train"} dataset
     * @param i passed through to {@code buildDMatrix}
     * @return the assembled {@code GpuWatches}
     */
    public Watches buildWatchesWithEval(Option<String> option, float f, Map<String, ColumnIndices> map, Iterator<Tuple2<String, Iterator<GpuColumnBatch>>> iterator, LongAccumulator longAccumulator, int i) {
        DoubleRef numClass = DoubleRef.create(0.0d);
        Tuple2[] namedMatrices = (Tuple2[]) iterator.map(entry -> {
            if (entry == null) {
                throw new MatchError(entry);
            }
            String name = (String) entry._1();
            Iterator batches = (Iterator) entry._2();
            // Null-safe equality: a null name is simply not the training set.
            boolean inferNumClass = longAccumulator != null && "train".equals(name);
            Tuple2 timed = MLUtils$.MODULE$.time(() -> {
                return MODULE$.buildDMatrix(batches, (ColumnIndices) map.apply(name), f, inferNumClass, i);
            });
            if (timed != null) {
                Tuple2 built = (Tuple2) timed._1();
                float elapsed = BoxesRunTime.unboxToFloat(timed._2());
                if (built != null) {
                    DMatrix dMatrix = (DMatrix) built._1();
                    double maxLabel = built._2$mcD$sp();
                    MODULE$.logger().debug(new StringBuilder(32).append("Benchmark[Train build ").append(name).append(" DMatrix] ").append(elapsed).toString());
                    if (inferNumClass) {
                        numClass.elem = maxLabel;
                    }
                    return new Tuple2(name, dMatrix);
                }
            }
            throw new MatchError(timed);
        }).filter(pair -> {
            return BoxesRunTime.boxToBoolean($anonfun$buildWatchesWithEval$3(pair));
        }).toArray(ClassTag$.MODULE$.apply(Tuple2.class));
        // toArray above has already drained the iterator, so numClass.elem is final by this point.
        DMatrix[] matrices = (DMatrix[]) new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps(namedMatrices)).map(pair -> {
            return (DMatrix) pair._2();
        }, Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(DMatrix.class)));
        String[] names = (String[]) new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps(namedMatrices)).map(pair -> {
            return (String) pair._1();
        }, Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(String.class)));
        return new GpuWatches(matrices, names, option, longAccumulator, numClass.elem);
    }

    /**
     * Derives the class count from the maximum label across workers and checks it against the
     * parameters.
     *
     * <p>Does nothing unless {@code watches} is a {@code GpuWatches} whose {@code accNumClass()}
     * is non-null. Otherwise the local {@code numClass()} value is MAX-reduced through
     * {@code Rabit.allReduce}, required to be a valid int, and incremented by one to give the
     * class count. If {@code map} contains {@code num_class}, it must equal that count. Only the
     * worker with Rabit rank 0 adds the count to the accumulator.
     *
     * @param watches watches to inspect
     * @param map parameters that may contain {@code num_class}
     */
    public void checkNumClass(Watches watches, Map<String, Object> map) {
        if (!(watches instanceof GpuWatches)) {
            return;
        }
        GpuWatches gpuWatches = (GpuWatches) watches;
        if (gpuWatches.accNumClass() == null) {
            return;
        }

        float[] reduced = Rabit.allReduce(new float[]{(float) gpuWatches.numClass()}, Rabit.OpType.MAX);
        Predef$.MODULE$.require(new ArrayOps.ofFloat(Predef$.MODULE$.floatArrayOps(reduced)).nonEmpty(), () -> {
            return "Failed to infer class number.";
        });

        float maxLabel = BoxesRunTime.unboxToFloat(new ArrayOps.ofFloat(Predef$.MODULE$.floatArrayOps(reduced)).head());
        Predef$.MODULE$.require(RichFloat$.MODULE$.isValidInt$extension(Predef$.MODULE$.floatWrapper(maxLabel)), () -> {
            return "Classifier found max label value =" + " " + maxLabel + " but requires integers in range [0, ... " + Integer.MAX_VALUE + ")";
        });

        // The largest label plus one gives the class count.
        int numClass = ((int) maxLabel) + 1;
        if (map.contains("num_class")) {
            Predef$.MODULE$.require(BoxesRunTime.equals(map.apply("num_class"), BoxesRunTime.boxToInteger(numClass)), () -> {
                return "The number of classes in Dataset doesn't match 'num_class' in parameters.";
            });
        }
        // Only rank 0 adds to the accumulator.
        if (Rabit.getRank() == 0) {
            gpuWatches.accNumClass().add(numClass);
        }
    }

    /**
     * Synthetic lambda body: cleans the external checkpoint path unless the execution
     * parameters' checkpoint settings request that cleaning be skipped.
     *
     * @param xGBoostExecutionParams source of the {@code skipCleanCheckpoint} flag
     * @param sparkContext supplies the Hadoop configuration for the file system
     * @param externalCheckpointParams supplies the checkpoint path to clean
     */
    public static final /* synthetic */ void $anonfun$trainDistributedOnGpu$3(XGBoostExecutionParams xGBoostExecutionParams, SparkContext sparkContext, ExternalCheckpointParams externalCheckpointParams) {
        boolean skipClean = ((ExternalCheckpointParams) xGBoostExecutionParams.checkpointParam().get()).skipCleanCheckpoint();
        if (!skipClean) {
            ExternalCheckpointManager manager = new ExternalCheckpointManager(externalCheckpointParams.checkpointPath(), FileSystem.get(sparkContext.hadoopConfiguration()));
            manager.cleanPath();
        }
    }

    /**
     * Synthetic lambda body: allocates a float array of length {@code i} through Scala's
     * {@code Array.ofDim}. The second argument is ignored.
     */
    public static final /* synthetic */ float[] $anonfun$buildDistributedBooster$1(int i, int i2) {
        Object allocated = Array$.MODULE$.ofDim(i, ClassTag$.MODULE$.Float());
        return (float[]) allocated;
    }

    /** Synthetic lambda body: accepts every non-null tuple. */
    public static final /* synthetic */ boolean $anonfun$coPartitionForGpu$5(Tuple2 tuple2) {
        if (tuple2 == null) {
            return false;
        }
        return true;
    }

    /** Synthetic lambda body: returns the batch's array interface for the single column index {@code i}. */
    public static final /* synthetic */ String $anonfun$buildDMatrixIncrementally$2(GpuColumnBatch gpuColumnBatch, int i) {
        int[] singleColumn = {i};
        return gpuColumnBatch.getArrayInterface(singleColumn);
    }

    /** Synthetic lambda body: returns the batch's array interface for the single column index {@code i}. */
    public static final /* synthetic */ String $anonfun$buildDMatrixIncrementally$3(GpuColumnBatch gpuColumnBatch, int i) {
        int[] singleColumn = {i};
        return gpuColumnBatch.getArrayInterface(singleColumn);
    }

    /** Synthetic lambda body: keeps tuples whose second element is non-null. */
    public static final /* synthetic */ boolean $anonfun$buildWatchesWithEval$3(Tuple2 tuple2) {
        Object second = tuple2._2();
        return second != null;
    }

    /**
     * Private constructor for this module class: publishes the instance through {@code MODULE$}
     * and then initialises the {@code trainName} and {@code logger} fields.
     */
    private GpuXGBoost$() {
        // MODULE$ is assigned before the fields are initialised.
        MODULE$ = this;
        this.trainName = "train";
        this.logger = LogFactory.getLog("GpuXGBoostSpark");
    }
}
