package ml.dmlc.xgboost4j.scala.spark;

import ai.rapids.cudf.Table;
import java.lang.Thread;
import java.nio.file.Files;
import java.nio.file.attribute.FileAttribute;
import ml.dmlc.xgboost4j.LabeledPoint;
import ml.dmlc.xgboost4j.java.IRabitTracker;
import ml.dmlc.xgboost4j.java.XGBoostError;
import ml.dmlc.xgboost4j.scala.Booster;
import ml.dmlc.xgboost4j.scala.EvalTrait;
import ml.dmlc.xgboost4j.scala.ObjectiveTrait;
import ml.dmlc.xgboost4j.scala.rabit.RabitTracker;
import ml.dmlc.xgboost4j.scala.rabit.RabitTracker$;
import ml.dmlc.xgboost4j.scala.spark.XGBoost;
import ml.dmlc.xgboost4j.scala.spark.params.BoosterParams$;
import ml.dmlc.xgboost4j.scala.spark.rapids.GpuDeviceManager$;
import ml.dmlc.xgboost4j.scala.spark.rapids.GpuSampler;
import ml.dmlc.xgboost4j.scala.spark.rapids.PluginUtils$;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.spark.SparkContext;
import org.apache.spark.SparkParallelismTracker;
import org.apache.spark.TaskContext;
import org.apache.spark.TaskContext$;
import org.apache.spark.rdd.RDD;
import org.apache.spark.sql.Column;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.SparkSession$;
import org.apache.spark.sql.catalyst.encoders.RowEncoder$;
import org.apache.spark.sql.functions$;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.storage.StorageLevel$;
import scala.Array$;
import scala.Function0;
import scala.Function1;
import scala.Function2;
import scala.MatchError;
import scala.None$;
import scala.Option;
import scala.PartialFunction;
import scala.Predef$;
import scala.Predef$ArrowAssoc$;
import scala.Serializable;
import scala.Some;
import scala.Tuple10;
import scala.Tuple2;
import scala.Tuple5;
import scala.Tuple6;
import scala.collection.BufferedIterator;
import scala.collection.GenTraversableOnce;
import scala.collection.Iterable;
import scala.collection.Iterator;
import scala.collection.Seq;
import scala.collection.Seq$;
import scala.collection.Traversable;
import scala.collection.TraversableLike;
import scala.collection.TraversableOnce;
import scala.collection.generic.CanBuildFrom;
import scala.collection.immutable.IndexedSeq;
import scala.collection.immutable.List;
import scala.collection.immutable.Map;
import scala.collection.immutable.Map$;
import scala.collection.immutable.Nil$;
import scala.collection.immutable.Set;
import scala.collection.immutable.Stream;
import scala.collection.immutable.StringOps;
import scala.collection.immutable.Vector;
import scala.collection.mutable.ArrayBuilder;
import scala.collection.mutable.ArrayOps;
import scala.collection.mutable.Buffer;
import scala.collection.mutable.StringBuilder;
import scala.math.Numeric;
import scala.math.Ordering;
import scala.math.Ordering$Int$;
import scala.reflect.ClassTag;
import scala.reflect.ClassTag$;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;
import scala.runtime.Nothing$;
import scala.runtime.ObjectRef;
import scala.runtime.ScalaRunTime$;
import scala.util.Either;

/* compiled from: XGBoost.scala */
/* loaded from: input_file:ml/dmlc/xgboost4j/scala/spark/XGBoost$.class */
public final class XGBoost$ implements Serializable {
    // Holder for the single instance of this compiled Scala object. The field is not final and is
    // never assigned in the visible code; presumably the constructor stores `this` here — TODO confirm.
    public static XGBoost$ MODULE$;
    // Logger handed out by logger().
    private final Log logger;
    // Value handed out by trainName(); its initialisation is outside this chunk.
    private final String trainName;

    static {
        // Build the instance once when the class is loaded. The reference is discarded here, so
        // publishing it through MODULE$ has to happen inside the constructor.
        new XGBoost$();
    }

    /** Returns the commons-logging {@link Log} stored on this instance. */
    private Log logger() {
        return logger;
    }

    /** Returns the {@code trainName} string stored on this instance. */
    private String trainName() {
        return trainName;
    }

    /**
     * Rejects labeled points that carry explicit indices when the missing marker is not zero.
     *
     * <p>If {@code f} equals {@code 0.0f} the input iterator is returned as is. Otherwise the
     * returned iterator passes through every point whose {@code indices()} is {@code null} and
     * throws a {@link RuntimeException} for the first point that has indices.
     *
     * @param iterator labeled points to check
     * @param f the configured missing value
     * @return {@code iterator} itself, or a checking view mapped over it
     */
    private Iterator<LabeledPoint> verifyMissingSetting(Iterator<LabeledPoint> iterator, float f) {
        if (f == 0.0f) {
            return iterator;
        }
        return iterator.map(point -> {
            if (point.indices() == null) {
                return point;
            }
            String currentSetting = new StringBuilder(71).append(" set value ").append(f).append(") when you have SparseVector or Empty vector as your feature").toString();
            throw new RuntimeException(new StringBuilder(63).append("you can only specify missing value as 0.0 (the currently").append(currentSetting).append(" format").toString());
        });
    }

    /**
     * Maps every labeled point to a copy whose index and value arrays contain only the entries
     * that survive two filters.
     *
     * <p>The filters are the synthetic helpers {@code $anonfun$removeMissingValues$2} and
     * {@code $anonfun$removeMissingValues$3}, which are not part of this chunk. The second one
     * receives {@code function1}; presumably it applies that predicate to the value of each pair
     * — TODO confirm against the helper.
     *
     * @param iterator labeled points to rewrite
     * @param f the missing value; not referenced anywhere in the visible body
     * @param function1 predicate forwarded to the second filter helper
     * @return an iterator of rewritten copies
     */
    private Iterator<LabeledPoint> removeMissingValues(Iterator<LabeledPoint> iterator, float f, Function1<Object, Object> function1) {
        return iterator.map(labeledPoint -> {
            // Builders that collect the indices and values that are kept.
            ArrayBuilder.ofInt ofint = new ArrayBuilder.ofInt();
            ArrayBuilder.ofFloat offloat = new ArrayBuilder.ofFloat();
            // Pair each value with its position in values(), then run both filters.
            new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps((Object[]) new ArrayOps.ofFloat(Predef$.MODULE$.floatArrayOps(labeledPoint.values())).zipWithIndex(Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(Tuple2.class))))).withFilter(tuple2 -> {
                return BoxesRunTime.boxToBoolean($anonfun$removeMissingValues$2(tuple2));
            }).withFilter(tuple22 -> {
                return BoxesRunTime.boxToBoolean($anonfun$removeMissingValues$3(function1, tuple22));
            }).foreach(tuple23 -> {
                if (tuple23 == null) {
                    throw new MatchError(tuple23);
                }
                float unboxToFloat = BoxesRunTime.unboxToFloat(tuple23._1());
                int _2$mcI$sp = tuple23._2$mcI$sp();
                // Without an indices array the position itself is the index; otherwise translate
                // the position through the point's indices array.
                ofint.$plus$eq(labeledPoint.indices() == null ? _2$mcI$sp : labeledPoint.indices()[_2$mcI$sp]);
                return offloat.$plus$eq(unboxToFloat);
            });
            // Replace only the 2nd and 3rd copy arguments (indices, values); the rest keep their copy defaults.
            return labeledPoint.copy(labeledPoint.copy$default$1(), ofint.result(), offloat.result(), labeledPoint.copy$default$4(), labeledPoint.copy$default$5(), labeledPoint.copy$default$6());
        });
    }

    /**
     * Runs {@code verifyMissingSetting} over the points and then hands them to
     * {@code removeMissingValues} with a predicate chosen from the missing marker.
     *
     * <p>When {@code f} is NaN the predicate is "value is not NaN"; otherwise it is
     * "value differs from {@code f}".
     *
     * @param iterator labeled points to process
     * @param f the configured missing value
     * @return the iterator produced by {@code removeMissingValues}
     */
    public Iterator<LabeledPoint> processMissingValues(Iterator<LabeledPoint> iterator, float f) {
        Iterator<LabeledPoint> verified = verifyMissingSetting(iterator, f);
        if (Predef$.MODULE$.float2Float(f).isNaN()) {
            return removeMissingValues(verified, f, f3 -> {
                return !Predef$.MODULE$.float2Float(f3).isNaN();
            });
        }
        return removeMissingValues(verified, f, f2 -> {
            return f2 != f;
        });
    }

    /**
     * Group-wise variant of {@link #processMissingValues}.
     *
     * <p>When {@code f} is NaN the groups are returned untouched. Otherwise every group array is
     * run through {@code processMissingValues} and collected back into an array.
     *
     * @param iterator groups of labeled points
     * @param f the configured missing value
     * @return {@code iterator} itself, or a mapped iterator of processed groups
     */
    private Iterator<LabeledPoint[]> processMissingValuesWithGroup(Iterator<LabeledPoint[]> iterator, float f) {
        if (Predef$.MODULE$.float2Float(f).isNaN()) {
            return iterator;
        }
        return iterator.map(group -> {
            Iterator<LabeledPoint> points = new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps(group)).iterator();
            return (LabeledPoint[]) MODULE$.processMissingValues(points, f).toArray(ClassTag$.MODULE$.apply(LabeledPoint.class));
        });
    }

    /**
     * Creates a per-task temporary directory when requested.
     *
     * <p>The directory name prefix is {@code <stageId>-cache-<partitionId>}, both taken from the
     * current {@code TaskContext}. Nothing is created and the task context is not touched when
     * {@code z} is {@code false}.
     *
     * @param z whether a cache directory should be created
     * @return the absolute path of the new directory wrapped in {@code Some}, or {@code None}
     */
    private Option<String> getCacheDirName(boolean z) {
        if (!z) {
            return Option.empty();
        }
        String partitionId = BoxesRunTime.boxToInteger(TaskContext$.MODULE$.getPartitionId()).toString();
        String prefix = new StringBuilder(7).append(TaskContext$.MODULE$.get().stageId()).append("-cache-").append(partitionId).toString();
        String cacheDir = Files.createTempDirectory(prefix, new FileAttribute[0]).toAbsolutePath().toString();
        return new Some<>(cacheDir);
    }

    /**
     * Decompiler stub: the original body (758 bytecode instructions according to the dump note
     * below) could not be reconstructed, so this source form unconditionally throws
     * {@link UnsupportedOperationException}. The signature is the only reliable information here;
     * the real behaviour has to be read from the class file or the Scala source.
     *
     * <p>Within this chunk it is called from trainForNonRanking and trainForRanking with a
     * {@code Watches} instance, both parameter maps, an int, the custom objective/eval traits and
     * a previous booster.
     */
    /* JADX INFO: Access modifiers changed from: private */
    /* JADX WARN: Code restructure failed: missing block: B:43:0x0249, code lost:
    
        if (r0.equals("gpu_hist") != false) goto L42;
     */
    /*
        Code decompiled incorrectly, please refer to instructions dump.
        To view partially-correct add '--show-bad-code' argument
    */
    public scala.collection.Iterator<scala.Tuple2<ml.dmlc.xgboost4j.scala.Booster, scala.collection.immutable.Map<java.lang.String, float[]>>> buildDistributedBooster(ml.dmlc.xgboost4j.scala.spark.Watches r12, scala.collection.immutable.Map<java.lang.String, java.lang.Object> r13, java.util.Map<java.lang.String, java.lang.String> r14, int r15, ml.dmlc.xgboost4j.scala.ObjectiveTrait r16, ml.dmlc.xgboost4j.scala.EvalTrait r17, ml.dmlc.xgboost4j.scala.Booster r18) {
        /*
            Method dump skipped, instructions count: 758
            To view this dump add '--comments-level debug' option
        */
        throw new UnsupportedOperationException("Method not decompiled: ml.dmlc.xgboost4j.scala.spark.XGBoost$.buildDistributedBooster(ml.dmlc.xgboost4j.scala.spark.Watches, scala.collection.immutable.Map, java.util.Map, int, ml.dmlc.xgboost4j.scala.ObjectiveTrait, ml.dmlc.xgboost4j.scala.EvalTrait, ml.dmlc.xgboost4j.scala.Booster):scala.collection.Iterator");
    }

    /**
     * Reconciles the {@code nthread} parameter with Spark's {@code spark.task.cpus} setting
     * (default 1 when unset).
     *
     * <p>If {@code nthread} is absent it is added with the task CPU count. If it is present it
     * must not exceed the task CPU count, otherwise {@code require} fails; the map is then
     * returned unchanged.
     *
     * @param map training parameters
     * @param sparkContext context whose configuration supplies {@code spark.task.cpus}
     * @return the original map, or a copy with {@code nthread} added
     */
    private Map<String, Object> overrideParamsAccordingToTaskCPUs(Map<String, Object> map, SparkContext sparkContext) {
        int taskCpus = sparkContext.getConf().getInt("spark.task.cpus", 1);
        if (!map.contains("nthread")) {
            return map.$plus(Predef$ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc("nthread"), BoxesRunTime.boxToInteger(taskCpus)));
        }
        int nThread = new StringOps(Predef$.MODULE$.augmentString(map.apply("nthread").toString())).toInt();
        Predef$.MODULE$.require(nThread <= taskCpus, () -> {
            String limit = new StringBuilder(18).append("spark.task.cpus (").append(taskCpus).append(")").toString();
            return new StringBuilder(52).append("the nthread configuration (").append(nThread).append(") must be no larger than ").append(limit).toString();
        });
        return map;
    }

    /**
     * Creates and starts the Rabit tracker selected by {@code trackerConf.trackerImpl()}.
     *
     * <p>{@code "scala"} selects the Scala {@code RabitTracker} (with its default constructor
     * arguments); {@code "python"} and every other value select
     * {@code ml.dmlc.xgboost4j.java.RabitTracker}. The tracker is started with the configured
     * worker connection timeout and {@code require} fails if {@code start} returns false.
     *
     * @param i value passed as the first constructor argument of the tracker
     * @param trackerConf tracker implementation name and connection timeout
     * @return the started tracker
     */
    private IRabitTracker startTracker(int i, TrackerConf trackerConf) {
        IRabitTracker tracker;
        if ("scala".equals(trackerConf.trackerImpl())) {
            tracker = new RabitTracker(i, RabitTracker$.MODULE$.$lessinit$greater$default$2(), RabitTracker$.MODULE$.$lessinit$greater$default$3());
        } else {
            // "python" and any unrecognised name end up with the same tracker class.
            tracker = new ml.dmlc.xgboost4j.java.RabitTracker(i);
        }
        Predef$.MODULE$.require(tracker.start(trackerConf.workerConnectionTimeout()), () -> {
            return "FAULT: Failed to start tracker";
        });
        return tracker;
    }

    /**
     * Zips the training RDD and every evaluation RDD partition-by-partition so that each output
     * partition carries (dataset name, iterator over that dataset's rows) pairs.
     *
     * <p>Steps visible in the code: the training RDD is registered under the name "train" and
     * merged with {@code map}; every RDD whose partition count differs from {@code i} is
     * repartitioned to {@code i}; the datasets are then folded, starting from a seed RDD of
     * {@code i} null placeholders in {@code i} partitions, using {@code zipPartitions}.
     *
     * <p>NOTE(review): the decompiler reused names inside the fold lambda ({@code rdd2} and
     * {@code tuple22} are both lambda parameters and locals), so this text is not compilable Java
     * as is; read it as a description of the bytecode.
     *
     * @param rdd training data
     * @param map evaluation datasets keyed by name
     * @param i required number of partitions for every dataset
     * @return the folded RDD of (name, iterator) pairs
     */
    private RDD<Tuple2<String, Iterator<LabeledPoint>>> coPartitionNoGroupSets(RDD<LabeledPoint> rdd, Map<String, RDD<LabeledPoint>> map, int i) {
        // Map("train" -> rdd) ++ map, with each RDD brought to exactly i partitions.
        return (RDD) ((Map) Predef$.MODULE$.Map().apply(Predef$.MODULE$.wrapRefArray(new Tuple2[]{Predef$ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc("train"), rdd)})).$plus$plus(map).map(tuple2 -> {
            if (tuple2 == null) {
                throw new MatchError(tuple2);
            }
            String str = (String) tuple2._1();
            RDD rdd2 = (RDD) tuple2._2();
            return rdd2.getNumPartitions() != i ? new Tuple2(str, rdd2.repartition(i, rdd2.repartition$default$2(i))) : new Tuple2(str, rdd2);
        // Fold seed: an RDD of i null entries spread over i partitions.
        }, Map$.MODULE$.canBuildFrom())).foldLeft(rdd.sparkContext().parallelize(Predef$.MODULE$.wrapRefArray((Object[]) Array$.MODULE$.fill(i, () -> {
            return null;
        }, ClassTag$.MODULE$.apply(Tuple2.class))), i, ClassTag$.MODULE$.apply(Tuple2.class)), (rdd2, tuple22) -> {
            Tuple2 tuple22 = new Tuple2(rdd2, tuple22);
            if (tuple22 != null) {
                RDD rdd2 = (RDD) tuple22._1();
                Tuple2 tuple23 = (Tuple2) tuple22._2();
                if (tuple23 != null) {
                    String str = (String) tuple23._1();
                    // iterator: pairs accumulated so far for this partition; iterator2: rows of the dataset named str.
                    return rdd2.zipPartitions((RDD) tuple23._2(), (iterator, iterator2) -> {
                        if (iterator2.hasNext()) {
                            Tuple2[] tuple2Arr = (Tuple2[]) iterator.toArray(ClassTag$.MODULE$.apply(Tuple2.class));
                            // A null head means the accumulator still holds the seed placeholder, so start a
                            // fresh one-element array; otherwise append the new (name, iterator) pair.
                            return new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps(tuple2Arr)).head() != null ? new XGBoost.IteratorWrapper((Tuple2[]) new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps(tuple2Arr)).$colon$plus(Predef$ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc(str), iterator2), ClassTag$.MODULE$.apply(Tuple2.class))) : new XGBoost.IteratorWrapper(new Tuple2[]{Predef$ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc(str), iterator2)});
                        }
                        // The dataset contributed no rows to this partition: log and abort.
                        MODULE$.logger().error("when specifying eval sets as dataframes, you have to ensure that the number of elements in each dataframe is larger than the number of workers");
                        throw new Exception("too few elements in evaluation sets");
                    }, ClassTag$.MODULE$.apply(LabeledPoint.class), ClassTag$.MODULE$.apply(Tuple2.class));
                }
            }
            throw new MatchError(tuple22);
        });
    }

    /**
     * Refuses to run when Spark SSL is enabled unless the user explicitly opted out.
     *
     * <p>The flags {@code spark.ssl.enabled} and {@code xgboost.spark.ignoreSsl} (both defaulting
     * to false) are read from the active {@code SparkSession} when there is one, otherwise from
     * the configuration of {@code sparkContext}. With SSL enabled the method throws unless
     * {@code xgboost.spark.ignoreSsl} is true, in which case it only logs a warning.
     *
     * @param sparkContext fallback source of the two settings
     */
    private void validateSparkSslConf(SparkContext sparkContext) {
        boolean sslEnabled;
        boolean ignoreSsl;
        Option<SparkSession> activeSession = SparkSession$.MODULE$.getActiveSession();
        if (activeSession instanceof Some) {
            SparkSession session = ((Some<SparkSession>) activeSession).value();
            String sslSetting = (String) session.conf().getOption("spark.ssl.enabled").getOrElse(() -> {
                return "false";
            });
            sslEnabled = new StringOps(Predef$.MODULE$.augmentString(sslSetting)).toBoolean();
            String ignoreSetting = (String) session.conf().getOption("xgboost.spark.ignoreSsl").getOrElse(() -> {
                return "false";
            });
            ignoreSsl = new StringOps(Predef$.MODULE$.augmentString(ignoreSetting)).toBoolean();
        } else if (None$.MODULE$.equals(activeSession)) {
            sslEnabled = sparkContext.getConf().getBoolean("spark.ssl.enabled", false);
            ignoreSsl = sparkContext.getConf().getBoolean("xgboost.spark.ignoreSsl", false);
        } else {
            throw new MatchError(activeSession);
        }
        if (!sslEnabled) {
            return;
        }
        if (ignoreSsl) {
            logger().warn(new StringBuilder(147).append("spark-xgboost is being run without encrypting data in transit!  ").append("Spark Conf spark.ssl.enabled=true was overridden with xgboost.spark.ignoreSsl=true.").toString());
            return;
        }
        throw new Exception("xgboost-spark found spark.ssl.enabled=true to encrypt data in transit, but xgboost-spark sends non-encrypted data over the wire for efficiency. To override this protection and still use xgboost-spark at your own risk, you can set the SparkSession conf to use xgboost.spark.ignoreSsl=true.");
    }

    /**
     * Extracts the training settings from the parameter map and validates them.
     *
     * <p>Checks performed: the Spark SSL guard ({@code validateSparkSslConf}); {@code tree_method},
     * when present, must be one of {@code BoosterParams.supportedTreeMethods}; {@code num_workers}
     * must be positive; a custom objective requires {@code objective_type}. A warning is logged
     * when the deprecated {@code train_test_ratio} is present.
     *
     * <p>{@code tracker_conf} and {@code timeout_request_workers} are three-way matches: absent
     * yields the default ({@code TrackerConf()} / {@code 0L}), a value of the right type
     * ({@code TrackerConf} / {@code Long}) is used as is, and anything else is rejected. The
     * decompiled form of this method lost the jump after the "right type" case and therefore
     * rejected every supplied value; that control flow is restored here.
     *
     * @param map training parameters; {@code num_workers}, {@code num_round} and
     *     {@code use_external_memory} must be present
     * @param sparkContext context used for the SSL check
     * @return in order: num_workers, num_round, use_external_memory, custom_obj (may be null),
     *     custom_eval (may be null), missing (NaN when unset), tracker configuration,
     *     timeout_request_workers, and the String and int returned by
     *     {@code CheckpointManager.extractParams}
     * @throws IllegalArgumentException if {@code tracker_conf} is not a {@code TrackerConf}, if
     *     {@code timeout_request_workers} is not a {@code Long}, or if a {@code require} check fails
     */
    private Tuple10<Object, Object, Object, ObjectiveTrait, EvalTrait, Object, TrackerConf, Object, String, Object> parameterFetchAndValidation(Map<String, Object> map, SparkContext sparkContext) {
        int numWorkers = BoxesRunTime.unboxToInt(map.apply("num_workers"));
        int numRound = BoxesRunTime.unboxToInt(map.apply("num_round"));
        boolean useExternalMemory = BoxesRunTime.unboxToBoolean(map.apply("use_external_memory"));
        ObjectiveTrait objectiveTrait = (ObjectiveTrait) map.getOrElse("custom_obj", () -> {
            return null;
        });
        EvalTrait evalTrait = (EvalTrait) map.getOrElse("custom_eval", () -> {
            return null;
        });
        float missing = BoxesRunTime.unboxToFloat(map.getOrElse("missing", () -> {
            return Float.NaN;
        }));
        validateSparkSslConf(sparkContext);
        if (map.contains("tree_method")) {
            Predef$.MODULE$.require(BoosterParams$.MODULE$.supportedTreeMethods().contains((String) map.apply("tree_method")), () -> {
                return new StringBuilder(46).append("xgboost4j-spark only supports tree_method as [").append(new StringBuilder(1).append(BoosterParams$.MODULE$.supportedTreeMethods().mkString(", ")).append("]").toString()).toString();
            });
        }
        if (map.contains("train_test_ratio")) {
            logger().warn("train_test_ratio is deprecated since XGBoost 0.82, we recommend to explicitly pass a training and multiple evaluation datasets by passing 'eval_sets' as Map('name'->'Dataset')");
        }
        Predef$.MODULE$.require(numWorkers > 0, () -> {
            return "you must specify more than 0 workers";
        });
        if (objectiveTrait != null) {
            Predef$.MODULE$.require(map.get("objective_type").isDefined(), () -> {
                return "parameter \"objective_type\" is not defined, you have to specify the objective type as classification or regression with a customized objective function";
            });
        }

        // tracker_conf: None -> default, Some(TrackerConf) -> that value, anything else -> error.
        TrackerConf trackerConf;
        Option<Object> trackerConfOpt = map.get("tracker_conf");
        if (None$.MODULE$.equals(trackerConfOpt)) {
            trackerConf = TrackerConf$.MODULE$.apply();
        } else if (trackerConfOpt instanceof Some && ((Some<Object>) trackerConfOpt).value() instanceof TrackerConf) {
            trackerConf = (TrackerConf) ((Some<Object>) trackerConfOpt).value();
        } else {
            throw new IllegalArgumentException("parameter \"tracker_conf\" must be an instance of TrackerConf.");
        }

        // timeout_request_workers: None -> 0, Some(Long) -> that value, anything else -> error.
        long timeoutRequestWorkers;
        Option<Object> timeoutOpt = map.get("timeout_request_workers");
        if (None$.MODULE$.equals(timeoutOpt)) {
            timeoutRequestWorkers = 0L;
        } else if (timeoutOpt instanceof Some && ((Some<Object>) timeoutOpt).value() instanceof Long) {
            timeoutRequestWorkers = BoxesRunTime.unboxToLong(((Some<Object>) timeoutOpt).value());
        } else {
            throw new IllegalArgumentException("parameter \"timeout_request_workers\" must be an instance of Long.");
        }

        Tuple2<String, Object> extractParams = CheckpointManager$.MODULE$.extractParams(map);
        if (extractParams == null) {
            throw new MatchError(extractParams);
        }
        String checkpointString = (String) extractParams._1();
        int checkpointInt = extractParams._2$mcI$sp();
        return new Tuple10<>(BoxesRunTime.boxToInteger(numWorkers), BoxesRunTime.boxToInteger(numRound), BoxesRunTime.boxToBoolean(useExternalMemory), objectiveTrait, evalTrait, BoxesRunTime.boxToFloat(missing), trackerConf, BoxesRunTime.boxToLong(timeoutRequestWorkers), checkpointString, BoxesRunTime.boxToInteger(checkpointInt));
    }

    /**
     * Builds the cached RDD of trained boosters for non-ranking tasks.
     *
     * <p>Settings come from {@code parameterFetchAndValidation}. Without evaluation sets each
     * partition of {@code rdd} is cleaned with {@code processMissingValues}, wrapped by
     * {@code Watches.buildWatches} and passed to {@code buildDistributedBooster}. With evaluation
     * sets the data is first co-partitioned by {@code coPartitionNoGroupSets} and every named
     * iterator is cleaned the same way.
     *
     * @param rdd training points
     * @param map training parameters
     * @param map2 forwarded unchanged to {@code buildDistributedBooster}
     * @param i forwarded unchanged to {@code buildDistributedBooster}
     * @param booster forwarded unchanged to {@code buildDistributedBooster}
     * @param map3 evaluation datasets keyed by name; may be empty
     * @return the cached result RDD
     */
    private RDD<Tuple2<Booster, Map<String, float[]>>> trainForNonRanking(RDD<LabeledPoint> rdd, Map<String, Object> map, java.util.Map<String, String> map2, int i, Booster booster, Map<String, RDD<LabeledPoint>> map3) {
        Tuple10<Object, Object, Object, ObjectiveTrait, EvalTrait, Object, TrackerConf, Object, String, Object> settings = parameterFetchAndValidation(map, rdd.sparkContext());
        if (settings == null) {
            throw new MatchError(settings);
        }
        int numWorkers = BoxesRunTime.unboxToInt(settings._1());
        boolean useExternalMemory = BoxesRunTime.unboxToBoolean(settings._3());
        ObjectiveTrait objectiveTrait = (ObjectiveTrait) settings._4();
        EvalTrait evalTrait = (EvalTrait) settings._5();
        float missing = BoxesRunTime.unboxToFloat(settings._6());

        if (!map3.isEmpty()) {
            RDD<Tuple2<String, Iterator<LabeledPoint>>> coPartitioned = coPartitionNoGroupSets(rdd, map3, numWorkers);
            return coPartitioned.mapPartitions(partition -> {
                return MODULE$.buildDistributedBooster(Watches$.MODULE$.buildWatches(partition.map(namedSet -> {
                    if (namedSet == null) {
                        throw new MatchError(namedSet);
                    }
                    return new Tuple2((String) namedSet._1(), MODULE$.processMissingValues((Iterator) namedSet._2(), missing));
                }), MODULE$.getCacheDirName(useExternalMemory)), map, map2, i, objectiveTrait, evalTrait, booster);
            }, coPartitioned.mapPartitions$default$2(), ClassTag$.MODULE$.apply(Tuple2.class)).cache();
        }
        return rdd.mapPartitions(partition -> {
            return MODULE$.buildDistributedBooster(Watches$.MODULE$.buildWatches(map, MODULE$.processMissingValues(partition, missing), MODULE$.getCacheDirName(useExternalMemory)), map, map2, i, objectiveTrait, evalTrait, booster);
        }, rdd.mapPartitions$default$2(), ClassTag$.MODULE$.apply(Tuple2.class)).cache();
    }

    /**
     * Builds the cached RDD of trained boosters for ranking tasks, where the input is grouped.
     *
     * <p>Settings come from {@code parameterFetchAndValidation}. Without evaluation sets each
     * partition of {@code rdd} is cleaned with {@code processMissingValuesWithGroup}, wrapped by
     * {@code Watches.buildWatchesWithGroup} and passed to {@code buildDistributedBooster}. With
     * evaluation sets the data is first co-partitioned by {@code coPartitionGroupSets} and every
     * named iterator is cleaned the same way.
     *
     * @param rdd training points, one array per group
     * @param map training parameters
     * @param map2 forwarded unchanged to {@code buildDistributedBooster}
     * @param i forwarded unchanged to {@code buildDistributedBooster}
     * @param booster forwarded unchanged to {@code buildDistributedBooster}
     * @param map3 evaluation datasets keyed by name; may be empty
     * @return the cached result RDD
     */
    private RDD<Tuple2<Booster, Map<String, float[]>>> trainForRanking(RDD<LabeledPoint[]> rdd, Map<String, Object> map, java.util.Map<String, String> map2, int i, Booster booster, Map<String, RDD<LabeledPoint>> map3) {
        Tuple10<Object, Object, Object, ObjectiveTrait, EvalTrait, Object, TrackerConf, Object, String, Object> settings = parameterFetchAndValidation(map, rdd.sparkContext());
        if (settings == null) {
            throw new MatchError(settings);
        }
        int numWorkers = BoxesRunTime.unboxToInt(settings._1());
        boolean useExternalMemory = BoxesRunTime.unboxToBoolean(settings._3());
        ObjectiveTrait objectiveTrait = (ObjectiveTrait) settings._4();
        EvalTrait evalTrait = (EvalTrait) settings._5();
        float missing = BoxesRunTime.unboxToFloat(settings._6());

        if (!map3.isEmpty()) {
            RDD<Tuple2<String, Iterator<LabeledPoint[]>>> coPartitioned = coPartitionGroupSets(rdd, map3, numWorkers);
            return coPartitioned.mapPartitions(partition -> {
                return MODULE$.buildDistributedBooster(Watches$.MODULE$.buildWatchesWithGroup(partition.map(namedSet -> {
                    if (namedSet == null) {
                        throw new MatchError(namedSet);
                    }
                    return new Tuple2((String) namedSet._1(), MODULE$.processMissingValuesWithGroup((Iterator) namedSet._2(), missing));
                }), MODULE$.getCacheDirName(useExternalMemory)), map, map2, i, objectiveTrait, evalTrait, booster);
            }, coPartitioned.mapPartitions$default$2(), ClassTag$.MODULE$.apply(Tuple2.class)).cache();
        }
        return rdd.mapPartitions(partition -> {
            return MODULE$.buildDistributedBooster(Watches$.MODULE$.buildWatchesWithGroup(map, MODULE$.processMissingValuesWithGroup(partition, missing), MODULE$.getCacheDirName(useExternalMemory)), map, map2, i, objectiveTrait, evalTrait, booster);
        }, rdd.mapPartitions$default$2(), ClassTag$.MODULE$.apply(Tuple2.class)).cache();
    }

    /**
     * Optionally persists an RDD at the {@code MEMORY_AND_DISK} storage level.
     *
     * @param z whether to persist
     * @param rdd the RDD to (maybe) persist
     * @return the result of {@code persist} when {@code z} is true, otherwise {@code rdd} itself
     */
    private RDD<?> cacheData(boolean z, RDD<?> rdd) {
        if (!z) {
            return rdd;
        }
        return rdd.persist(StorageLevel$.MODULE$.MEMORY_AND_DISK());
    }

    /**
     * Repartitions the training data and wraps it in an {@code Either}.
     *
     * <p>When {@code z2} is true the data goes through {@code repartitionForTrainingGroup} and is
     * returned as {@code Left}; otherwise it goes through {@code repartitionForTraining} and is
     * returned as {@code Right}. In both cases the result is passed through {@code cacheData}.
     *
     * @param rdd training points
     * @param z forwarded to {@code cacheData} as the persist flag
     * @param z2 selects the grouped ({@code Left}) or plain ({@code Right}) path
     * @param i forwarded to the repartition helper
     * @return the prepared data
     */
    private Either<RDD<LabeledPoint[]>, RDD<LabeledPoint>> composeInputData(RDD<LabeledPoint> rdd, boolean z, boolean z2, int i) {
        if (!z2) {
            return scala.package$.MODULE$.Right().apply(cacheData(z, repartitionForTraining(rdd, i)));
        }
        return scala.package$.MODULE$.Left().apply(cacheData(z, repartitionForTrainingGroup(rdd, i)));
    }

    /**
     * Forces a GPU tree method into the parameter map.
     *
     * <p>If {@code tree_method} is absent or equal to {@code "auto"} it is set to
     * {@code "gpu_hist"}. Any other value must start with {@code "gpu_"}, otherwise
     * {@code require} fails with a message listing the supported GPU methods; the map is then
     * returned unchanged.
     *
     * <p>The decompiled condition compared the literal {@code "auto"} with the int {@code 0},
     * which is not valid Java; it is expressed here as a plain {@code equals} check. As before, a
     * {@code tree_method} entry mapped to {@code null} is not treated as "auto" and fails with a
     * {@link NullPointerException} on the prefix check.
     *
     * @param map training parameters
     * @return the original map, or a copy with {@code tree_method} set to {@code "gpu_hist"}
     */
    private Map<String, Object> parameterOverrideToUseGPU(Map<String, Object> map) {
        if (map.contains("tree_method")) {
            String treeMethod = (String) map.apply("tree_method");
            if (!"auto".equals(treeMethod)) {
                Predef$.MODULE$.require(treeMethod.startsWith("gpu_"), () -> {
                    return new StringBuilder(68).append("Now for training on GPU, xgboost-spark only supports tree_method as ").append(new StringBuilder(2).append("[").append(BoosterParams$.MODULE$.supportedTreeMethods().filter(str2 -> {
                        return BoxesRunTime.boxToBoolean(str2.startsWith("gpu_"));
                    }).mkString(", ")).append("]").toString()).append(new StringBuilder(14).append(", but found '").append(treeMethod).append("'").toString()).toString();
                });
                return map;
            }
        }
        // Missing or "auto": pick the GPU histogram method.
        return map.$plus(Predef$ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc("tree_method"), "gpu_hist"));
    }

    /* JADX INFO: Access modifiers changed from: private */
    public Dataset<Row> repartitionForGroup(String str, Dataset<Row> dataset, int i) {
        logger().info("LTR start groupBy");
        StructType schema = dataset.schema();
        Dataset agg = dataset.groupBy(str, Predef$.MODULE$.wrapRefArray(new String[0])).agg(functions$.MODULE$.collect_list(functions$.MODULE$.struct(Predef$.MODULE$.wrapRefArray((Object[]) new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps(schema.fieldNames())).map(str2 -> {
            return functions$.MODULE$.col(str2);
        }, Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(Column.class)))))).as("list"), Predef$.MODULE$.wrapRefArray(new Column[0]));
        return agg.repartition(i).mapPartitions(iterator -> {
            return new Iterator<Row>(iterator) { // from class: ml.dmlc.xgboost4j.scala.spark.XGBoost$$anon$1
                private Iterator<Object> iterInRow;
                private final Iterator iter$1;

                /* renamed from: seq, reason: merged with bridge method [inline-methods] */
                public Iterator<Row> m75seq() {
                    return Iterator.seq$(this);
                }

                public boolean isEmpty() {
                    return Iterator.isEmpty$(this);
                }

                public boolean isTraversableAgain() {
                    return Iterator.isTraversableAgain$(this);
                }

                public boolean hasDefiniteSize() {
                    return Iterator.hasDefiniteSize$(this);
                }

                public Iterator<Row> take(int i2) {
                    return Iterator.take$(this, i2);
                }

                public Iterator<Row> drop(int i2) {
                    return Iterator.drop$(this, i2);
                }

                public Iterator<Row> slice(int i2, int i3) {
                    return Iterator.slice$(this, i2, i3);
                }

                public Iterator<Row> sliceIterator(int i2, int i3) {
                    return Iterator.sliceIterator$(this, i2, i3);
                }

                public <B> Iterator<B> map(Function1<Row, B> function1) {
                    return Iterator.map$(this, function1);
                }

                public <B> Iterator<B> $plus$plus(Function0<GenTraversableOnce<B>> function0) {
                    return Iterator.$plus$plus$(this, function0);
                }

                public <B> Iterator<B> flatMap(Function1<Row, GenTraversableOnce<B>> function1) {
                    return Iterator.flatMap$(this, function1);
                }

                public Iterator<Row> filter(Function1<Row, Object> function1) {
                    return Iterator.filter$(this, function1);
                }

                public <B> boolean corresponds(GenTraversableOnce<B> genTraversableOnce, Function2<Row, B, Object> function2) {
                    return Iterator.corresponds$(this, genTraversableOnce, function2);
                }

                public Iterator<Row> withFilter(Function1<Row, Object> function1) {
                    return Iterator.withFilter$(this, function1);
                }

                public Iterator<Row> filterNot(Function1<Row, Object> function1) {
                    return Iterator.filterNot$(this, function1);
                }

                public <B> Iterator<B> collect(PartialFunction<Row, B> partialFunction) {
                    return Iterator.collect$(this, partialFunction);
                }

                public <B> Iterator<B> scanLeft(B b, Function2<B, Row, B> function2) {
                    return Iterator.scanLeft$(this, b, function2);
                }

                public <B> Iterator<B> scanRight(B b, Function2<Row, B, B> function2) {
                    return Iterator.scanRight$(this, b, function2);
                }

                public Iterator<Row> takeWhile(Function1<Row, Object> function1) {
                    return Iterator.takeWhile$(this, function1);
                }

                public Tuple2<Iterator<Row>, Iterator<Row>> partition(Function1<Row, Object> function1) {
                    return Iterator.partition$(this, function1);
                }

                public Tuple2<Iterator<Row>, Iterator<Row>> span(Function1<Row, Object> function1) {
                    return Iterator.span$(this, function1);
                }

                public Iterator<Row> dropWhile(Function1<Row, Object> function1) {
                    return Iterator.dropWhile$(this, function1);
                }

                public <B> Iterator<Tuple2<Row, B>> zip(Iterator<B> iterator) {
                    return Iterator.zip$(this, iterator);
                }

                public <A1> Iterator<A1> padTo(int i2, A1 a1) {
                    return Iterator.padTo$(this, i2, a1);
                }

                public Iterator<Tuple2<Row, Object>> zipWithIndex() {
                    return Iterator.zipWithIndex$(this);
                }

                public <B, A1, B1> Iterator<Tuple2<A1, B1>> zipAll(Iterator<B> iterator, A1 a1, B1 b1) {
                    return Iterator.zipAll$(this, iterator, a1, b1);
                }

                public <U> void foreach(Function1<Row, U> function1) {
                    Iterator.foreach$(this, function1);
                }

                public boolean forall(Function1<Row, Object> function1) {
                    return Iterator.forall$(this, function1);
                }

                public boolean exists(Function1<Row, Object> function1) {
                    return Iterator.exists$(this, function1);
                }

                public boolean contains(Object obj) {
                    return Iterator.contains$(this, obj);
                }

                public Option<Row> find(Function1<Row, Object> function1) {
                    return Iterator.find$(this, function1);
                }

                public int indexWhere(Function1<Row, Object> function1) {
                    return Iterator.indexWhere$(this, function1);
                }

                public int indexWhere(Function1<Row, Object> function1, int i2) {
                    return Iterator.indexWhere$(this, function1, i2);
                }

                public <B> int indexOf(B b) {
                    return Iterator.indexOf$(this, b);
                }

                public <B> int indexOf(B b, int i2) {
                    return Iterator.indexOf$(this, b, i2);
                }

                public BufferedIterator<Row> buffered() {
                    return Iterator.buffered$(this);
                }

                public <B> Iterator<Row>.GroupedIterator<B> grouped(int i2) {
                    return Iterator.grouped$(this, i2);
                }

                public <B> Iterator<Row>.GroupedIterator<B> sliding(int i2, int i3) {
                    return Iterator.sliding$(this, i2, i3);
                }

                public <B> int sliding$default$2() {
                    return Iterator.sliding$default$2$(this);
                }

                public int length() {
                    return Iterator.length$(this);
                }

                public Tuple2<Iterator<Row>, Iterator<Row>> duplicate() {
                    return Iterator.duplicate$(this);
                }

                public <B> Iterator<B> patch(int i2, Iterator<B> iterator, int i3) {
                    return Iterator.patch$(this, i2, iterator, i3);
                }

                public <B> void copyToArray(Object obj, int i2, int i3) {
                    Iterator.copyToArray$(this, obj, i2, i3);
                }

                public boolean sameElements(Iterator<?> iterator) {
                    return Iterator.sameElements$(this, iterator);
                }

                /* renamed from: toTraversable, reason: merged with bridge method [inline-methods] */
                public Traversable<Row> m74toTraversable() {
                    return Iterator.toTraversable$(this);
                }

                public Iterator<Row> toIterator() {
                    return Iterator.toIterator$(this);
                }

                public Stream<Row> toStream() {
                    return Iterator.toStream$(this);
                }

                public String toString() {
                    return Iterator.toString$(this);
                }

                public List<Row> reversed() {
                    return TraversableOnce.reversed$(this);
                }

                public int size() {
                    return TraversableOnce.size$(this);
                }

                public boolean nonEmpty() {
                    return TraversableOnce.nonEmpty$(this);
                }

                public int count(Function1<Row, Object> function1) {
                    return TraversableOnce.count$(this, function1);
                }

                public <B> Option<B> collectFirst(PartialFunction<Row, B> partialFunction) {
                    return TraversableOnce.collectFirst$(this, partialFunction);
                }

                public <B> B $div$colon(B b, Function2<B, Row, B> function2) {
                    return (B) TraversableOnce.$div$colon$(this, b, function2);
                }

                public <B> B $colon$bslash(B b, Function2<Row, B, B> function2) {
                    return (B) TraversableOnce.$colon$bslash$(this, b, function2);
                }

                public <B> B foldLeft(B b, Function2<B, Row, B> function2) {
                    return (B) TraversableOnce.foldLeft$(this, b, function2);
                }

                public <B> B foldRight(B b, Function2<Row, B, B> function2) {
                    return (B) TraversableOnce.foldRight$(this, b, function2);
                }

                public <B> B reduceLeft(Function2<B, Row, B> function2) {
                    return (B) TraversableOnce.reduceLeft$(this, function2);
                }

                public <B> B reduceRight(Function2<Row, B, B> function2) {
                    return (B) TraversableOnce.reduceRight$(this, function2);
                }

                public <B> Option<B> reduceLeftOption(Function2<B, Row, B> function2) {
                    return TraversableOnce.reduceLeftOption$(this, function2);
                }

                public <B> Option<B> reduceRightOption(Function2<Row, B, B> function2) {
                    return TraversableOnce.reduceRightOption$(this, function2);
                }

                public <A1> A1 reduce(Function2<A1, A1, A1> function2) {
                    return (A1) TraversableOnce.reduce$(this, function2);
                }

                public <A1> Option<A1> reduceOption(Function2<A1, A1, A1> function2) {
                    return TraversableOnce.reduceOption$(this, function2);
                }

                public <A1> A1 fold(A1 a1, Function2<A1, A1, A1> function2) {
                    return (A1) TraversableOnce.fold$(this, a1, function2);
                }

                public <B> B aggregate(Function0<B> function0, Function2<B, Row, B> function2, Function2<B, B, B> function22) {
                    return (B) TraversableOnce.aggregate$(this, function0, function2, function22);
                }

                public <B> B sum(Numeric<B> numeric) {
                    return (B) TraversableOnce.sum$(this, numeric);
                }

                public <B> B product(Numeric<B> numeric) {
                    return (B) TraversableOnce.product$(this, numeric);
                }

                public Object min(Ordering ordering) {
                    return TraversableOnce.min$(this, ordering);
                }

                public Object max(Ordering ordering) {
                    return TraversableOnce.max$(this, ordering);
                }

                public Object maxBy(Function1 function1, Ordering ordering) {
                    return TraversableOnce.maxBy$(this, function1, ordering);
                }

                public Object minBy(Function1 function1, Ordering ordering) {
                    return TraversableOnce.minBy$(this, function1, ordering);
                }

                public <B> void copyToBuffer(Buffer<B> buffer) {
                    TraversableOnce.copyToBuffer$(this, buffer);
                }

                public <B> void copyToArray(Object obj, int i2) {
                    TraversableOnce.copyToArray$(this, obj, i2);
                }

                public <B> void copyToArray(Object obj) {
                    TraversableOnce.copyToArray$(this, obj);
                }

                public <B> Object toArray(ClassTag<B> classTag) {
                    return TraversableOnce.toArray$(this, classTag);
                }

                public List<Row> toList() {
                    return TraversableOnce.toList$(this);
                }

                /* renamed from: toIterable, reason: merged with bridge method [inline-methods] */
                public Iterable<Row> m73toIterable() {
                    return TraversableOnce.toIterable$(this);
                }

                /* renamed from: toSeq, reason: merged with bridge method [inline-methods] */
                public Seq<Row> m72toSeq() {
                    return TraversableOnce.toSeq$(this);
                }

                public IndexedSeq<Row> toIndexedSeq() {
                    return TraversableOnce.toIndexedSeq$(this);
                }

                public <B> Buffer<B> toBuffer() {
                    return TraversableOnce.toBuffer$(this);
                }

                /* renamed from: toSet, reason: merged with bridge method [inline-methods] */
                public <B> Set<B> m71toSet() {
                    return TraversableOnce.toSet$(this);
                }

                public Vector<Row> toVector() {
                    return TraversableOnce.toVector$(this);
                }

                public <Col> Col to(CanBuildFrom<Nothing$, Row, Col> canBuildFrom) {
                    return (Col) TraversableOnce.to$(this, canBuildFrom);
                }

                /* renamed from: toMap, reason: merged with bridge method [inline-methods] */
                public <T, U> Map<T, U> m70toMap(Predef$.less.colon.less<Row, Tuple2<T, U>> lessVar) {
                    return TraversableOnce.toMap$(this, lessVar);
                }

                public String mkString(String str3, String str4, String str5) {
                    return TraversableOnce.mkString$(this, str3, str4, str5);
                }

                public String mkString(String str3) {
                    return TraversableOnce.mkString$(this, str3);
                }

                public String mkString() {
                    return TraversableOnce.mkString$(this);
                }

                public StringBuilder addString(StringBuilder stringBuilder, String str3, String str4, String str5) {
                    return TraversableOnce.addString$(this, stringBuilder, str3, str4, str5);
                }

                public StringBuilder addString(StringBuilder stringBuilder, String str3) {
                    return TraversableOnce.addString$(this, stringBuilder, str3);
                }

                public StringBuilder addString(StringBuilder stringBuilder) {
                    return TraversableOnce.addString$(this, stringBuilder);
                }

                public int sizeHintIfCheap() {
                    return GenTraversableOnce.sizeHintIfCheap$(this);
                }

                private Iterator<Object> iterInRow() {
                    return this.iterInRow;
                }

                private void iterInRow_$eq(Iterator<Object> iterator) {
                    this.iterInRow = iterator;
                }

                public boolean hasNext() {
                    if (this.iter$1.hasNext() && !iterInRow().hasNext()) {
                        iterInRow_$eq(((Row) this.iter$1.next()).getSeq(1).iterator());
                    }
                    return iterInRow().hasNext();
                }

                /* renamed from: next, reason: merged with bridge method [inline-methods] */
                public Row m76next() {
                    return (Row) iterInRow().next();
                }

                {
                    this.iter$1 = iterator;
                    GenTraversableOnce.$init$(this);
                    TraversableOnce.$init$(this);
                    Iterator.$init$(this);
                    this.iterInRow = scala.package$.MODULE$.Iterator().empty();
                }
            };
        }, RowEncoder$.MODULE$.apply(schema));
    }

    /**
     * Builds the name-to-data map consumed by the GPU pipeline, with every DataFrame repartitioned
     * to {@code i} partitions.
     *
     * <p>The training data is registered under {@code trainName()}, the entries of {@code map} are
     * appended, and each entry is replaced by a new {@code GDFColumnData} that keeps the original
     * column indices and group column name but wraps the repartitioned DataFrame.
     *
     * @param gDFColumnData training data, stored under the key {@code trainName()}
     * @param map additional named data sets merged with the training entry
     * @param i target number of partitions
     * @param map2 parameter map; only the {@code "cacheTrainingSet"} key is read here
     * @return map from data-set name to repartitioned column data
     */
    private Map<String, GDFColumnData> prepareDataForGpu(GDFColumnData gDFColumnData, Map<String, GDFColumnData> map, int i, Map<String, Object> map2) {
        // The cache flag (default false) only triggers a warning; nothing is cached in this method.
        if (BoxesRunTime.unboxToBoolean(map2.getOrElse("cacheTrainingSet", () -> {
            return false;
        }))) {
            logger().warn("Data cache is not support for Gpu pipeline!");
        }
        // Start from {trainName() -> training data}, append `map`, then rebuild every (name, data) pair.
        return (Map) Predef$.MODULE$.Map().apply(Predef$.MODULE$.wrapRefArray(new Tuple2[]{Predef$ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc(trainName()), gDFColumnData)})).$plus$plus(map).map(tuple2 -> {
            if (tuple2 == null) {
                throw new MatchError(tuple2);
            }
            String str = (String) tuple2._1();
            GDFColumnData gDFColumnData2 = (GDFColumnData) tuple2._2();
            // With a group column: delegate to repartitionForGroup; otherwise a plain repartition(i).
            return Predef$ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc(str), new GDFColumnData((Dataset) gDFColumnData2.groupColName().map(str2 -> {
                return MODULE$.repartitionForGroup(str2, gDFColumnData2.rawDF(), i);
            }).getOrElse(() -> {
                return gDFColumnData2.rawDF().repartition(i);
            }), gDFColumnData2.colsIndices(), gDFColumnData2.groupColName()));
        }, Map$.MODULE$.canBuildFrom());
    }

    /**
     * Zips the columnar RDDs of all named data sets partition-by-partition into a single RDD of
     * {@code (name, table iterator)} pairs.
     *
     * <p>The fold starts from an RDD of {@code i} {@code null} placeholders spread over {@code i}
     * partitions. For each entry of {@code map}, the accumulated RDD is zipped with the columnar
     * form of that entry's DataFrame; every resulting partition holds the pairs accumulated so far
     * plus the new {@code (name, iterator)} pair, with {@code null} placeholders filtered out.
     *
     * <p>NOTE(review): the fold lambda re-declares {@code tuple2} and {@code rdd}, shadowing its own
     * parameters. This looks like a decompiler artifact and is not valid Java as written.
     *
     * @param map named data sets to co-partition
     * @param sparkContext context used to create the seed RDD
     * @param i number of elements and partitions of the seed RDD
     * @return RDD whose partitions contain one {@code (name, tables)} pair per data set
     */
    private RDD<Tuple2<String, Iterator<Table>>> coPartitionForGpu(Map<String, GDFColumnData> map, SparkContext sparkContext, int i) {
        // Seed: i null elements in i partitions, so zipPartitions has something to line up against.
        return (RDD) map.foldLeft(sparkContext.parallelize(Predef$.MODULE$.wrapRefArray((Object[]) Array$.MODULE$.fill(i, () -> {
            return null;
        }, ClassTag$.MODULE$.apply(Tuple2.class))), i, ClassTag$.MODULE$.apply(Tuple2.class)), (rdd, tuple2) -> {
            Tuple2 tuple2 = new Tuple2(rdd, tuple2);
            if (tuple2 != null) {
                RDD rdd = (RDD) tuple2._1();
                Tuple2 tuple22 = (Tuple2) tuple2._2();
                if (tuple22 != null) {
                    String str = (String) tuple22._1();
                    return rdd.zipPartitions(PluginUtils$.MODULE$.toColumnarRdd(((GDFColumnData) tuple22._2()).rawDF()), (iterator, iterator2) -> {
                        // Materialize the accumulated pairs, append (name -> tables), then drop null entries.
                        return new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps((Object[]) new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps((Object[]) new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps((Object[]) iterator.toArray(ClassTag$.MODULE$.apply(Tuple2.class)))).$colon$plus(Predef$ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc(str), iterator2), ClassTag$.MODULE$.apply(Tuple2.class)))).filter(tuple23 -> {
                            return BoxesRunTime.boxToBoolean($anonfun$coPartitionForGpu$4(tuple23));
                        }))).toIterator();
                    }, ClassTag$.MODULE$.apply(Table.class), ClassTag$.MODULE$.apply(Tuple2.class));
                }
            }
            throw new MatchError(tuple2);
        });
    }

    /**
     * Looks up the GPU device id and adds it to the training parameters.
     *
     * @param map training parameters to extend
     * @param z flag forwarded unchanged to {@code GpuDeviceManager.getGpuId}
     * @return pair of the GPU id and {@code map} extended with a {@code "gpu_id"} entry whose value
     *     is the id rendered as a string
     */
    private Tuple2<Object, Map<String, Object>> appendGpuIdToParameters(Map<String, Object> map, boolean z) {
        final int deviceId = GpuDeviceManager$.MODULE$.getGpuId(z);
        final String message = "XGboost GPU training using device: " + deviceId;
        logger().info(message);
        // The parameter value is stored as text, not as a boxed integer.
        final String deviceIdText = BoxesRunTime.boxToInteger(deviceId).toString();
        return new Tuple2<>(BoxesRunTime.boxToInteger(deviceId), map.$plus(Predef$ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc("gpu_id"), deviceIdText)));
    }

    /**
     * Builds the cached RDD whose partitions each train a distributed booster on GPU data.
     *
     * <p>The parameters are first passed through {@code parameterOverrideToUseGPU} and
     * {@code parameterFetchAndValidation}. When {@code z} is true only the data set stored under
     * {@code trainName()} is used and watches are built with {@code buildWatches}; otherwise all
     * data sets are co-partitioned and watches are built with {@code buildWatchesWithEval}.
     * In both paths each partition resolves its GPU id via {@code appendGpuIdToParameters}.
     *
     * @param sparkContext Spark context (also queried for {@code isLocal()})
     * @param map named column data; must contain {@code trainName()} when {@code z} is true
     * @param z selects the training-only path (true) or the co-partitioned path (false)
     * @param map2 raw training parameters
     * @param map3 passed through to {@code buildDistributedBooster}
     * @param i passed through to {@code buildDistributedBooster}
     * @param booster passed through to {@code buildDistributedBooster}
     * @param z2 passed through to the watches builder
     * @param option passed through to the watches builder
     * @return cached RDD produced by {@code mapPartitions} over the training data
     * @throws XGBoostError declared by the original signature
     */
    private RDD<Tuple2<Booster, Map<String, float[]>>> trainPreferGpu(SparkContext sparkContext, Map<String, GDFColumnData> map, boolean z, Map<String, Object> map2, java.util.Map<String, String> map3, int i, Booster booster, boolean z2, Option<GpuSampler> option) throws XGBoostError {
        Map<String, Object> parameterOverrideToUseGPU = parameterOverrideToUseGPU(map2);
        Tuple10<Object, Object, Object, ObjectiveTrait, EvalTrait, Object, TrackerConf, Object, String, Object> parameterFetchAndValidation = parameterFetchAndValidation(parameterOverrideToUseGPU, sparkContext);
        if (parameterFetchAndValidation == null) {
            throw new MatchError(parameterFetchAndValidation);
        }
        // Only positions 1, 3, 4, 5 and 6 of the validated tuple are used here; the Tuple5
        // round-trip below merely re-boxes and re-extracts those same values.
        int unboxToInt = BoxesRunTime.unboxToInt(parameterFetchAndValidation._1());
        boolean unboxToBoolean = BoxesRunTime.unboxToBoolean(parameterFetchAndValidation._3());
        Tuple5 tuple5 = new Tuple5(BoxesRunTime.boxToInteger(unboxToInt), BoxesRunTime.boxToBoolean(unboxToBoolean), (ObjectiveTrait) parameterFetchAndValidation._4(), (EvalTrait) parameterFetchAndValidation._5(), BoxesRunTime.boxToFloat(BoxesRunTime.unboxToFloat(parameterFetchAndValidation._6())));
        int unboxToInt2 = BoxesRunTime.unboxToInt(tuple5._1());
        boolean unboxToBoolean2 = BoxesRunTime.unboxToBoolean(tuple5._2());
        ObjectiveTrait objectiveTrait = (ObjectiveTrait) tuple5._3();
        EvalTrait evalTrait = (EvalTrait) tuple5._4();
        float unboxToFloat = BoxesRunTime.unboxToFloat(tuple5._5());
        boolean isLocal = sparkContext.isLocal();
        if (z) {
            // Training-only path: a single columnar RDD built from the trainName() entry.
            Seq<int[]> colsIndices = ((GDFColumnData) map.apply(trainName())).colsIndices();
            RDD<Table> columnarRdd = PluginUtils$.MODULE$.toColumnarRdd(((GDFColumnData) map.apply(trainName())).rawDF());
            return columnarRdd.mapPartitions(iterator -> {
                // Resolved inside the partition function, i.e. once per partition.
                Tuple2<Object, Map<String, Object>> appendGpuIdToParameters = MODULE$.appendGpuIdToParameters(parameterOverrideToUseGPU, isLocal);
                if (appendGpuIdToParameters == null) {
                    throw new MatchError(appendGpuIdToParameters);
                }
                int _1$mcI$sp = appendGpuIdToParameters._1$mcI$sp();
                Tuple2 tuple2 = new Tuple2(BoxesRunTime.boxToInteger(_1$mcI$sp), (Map) appendGpuIdToParameters._2());
                int _1$mcI$sp2 = tuple2._1$mcI$sp();
                Map<String, Object> map4 = (Map) tuple2._2();
                return MODULE$.buildDistributedBooster(Watches$.MODULE$.buildWatches(MODULE$.getCacheDirName(unboxToBoolean2), _1$mcI$sp2, unboxToFloat, colsIndices, iterator, z2, option), map4, map3, i, objectiveTrait, evalTrait, booster);
            }, columnarRdd.mapPartitions$default$2(), ClassTag$.MODULE$.apply(Tuple2.class)).cache();
        }
        // Co-partitioned path: keep each data set's column indices by name, then zip all sets together.
        Map map4 = (Map) map.map(tuple2 -> {
            return new Tuple2(tuple2._1(), ((GDFColumnData) tuple2._2()).colsIndices());
        }, Map$.MODULE$.canBuildFrom());
        RDD<Tuple2<String, Iterator<Table>>> coPartitionForGpu = coPartitionForGpu(map, sparkContext, unboxToInt2);
        return coPartitionForGpu.mapPartitions(iterator2 -> {
            Tuple2<Object, Map<String, Object>> appendGpuIdToParameters = MODULE$.appendGpuIdToParameters(parameterOverrideToUseGPU, isLocal);
            if (appendGpuIdToParameters == null) {
                throw new MatchError(appendGpuIdToParameters);
            }
            int _1$mcI$sp = appendGpuIdToParameters._1$mcI$sp();
            Tuple2 tuple22 = new Tuple2(BoxesRunTime.boxToInteger(_1$mcI$sp), (Map) appendGpuIdToParameters._2());
            int _1$mcI$sp2 = tuple22._1$mcI$sp();
            Map<String, Object> map5 = (Map) tuple22._2();
            return MODULE$.buildDistributedBooster(Watches$.MODULE$.buildWatchesWithEval(MODULE$.getCacheDirName(unboxToBoolean2), _1$mcI$sp2, unboxToFloat, map4, iterator2, z2, option), map5, map3, i, objectiveTrait, evalTrait, booster);
        }, coPartitionForGpu.mapPartitions$default$2(), ClassTag$.MODULE$.apply(Tuple2.class)).cache();
    }

    /**
     * Default-value accessor (Scala {@code $default$N} naming) for the eighth parameter of
     * {@code trainPreferGpu}, the {@code boolean z2} flag.
     *
     * @return always {@code false}
     */
    private boolean trainPreferGpu$default$8() {
        return false;
    }

    /**
     * Default-value accessor (Scala {@code $default$N} naming) for the ninth parameter of
     * {@code trainPreferGpu}, the optional {@code GpuSampler}.
     *
     * @return the {@code None} singleton
     */
    private Option<GpuSampler> trainPreferGpu$default$9() {
        return None$.MODULE$;
    }

    /**
     * Runs distributed GPU training over the configured checkpoint rounds and returns the result
     * of the last round.
     *
     * <p>The parameters are validated, the data sets are repartitioned with
     * {@code prepareDataForGpu}, a {@code CheckpointManager} is created and higher-version
     * checkpoints are cleaned up. Every round returned by {@code getCheckpointRounds} is then
     * delegated to {@code $anonfun$trainDistributedPreferGpu$1}.
     *
     * @param gDFColumnData training data; its DataFrame also supplies the Spark context
     * @param map training parameters
     * @param map2 additional named data sets
     * @param z flag passed through to the per-round trainer
     * @param option optional sampler passed through to the per-round trainer
     * @return booster and metrics produced by the last checkpoint round
     * @throws XGBoostError declared by the original signature
     */
    public Tuple2<Booster, Map<String, float[]>> trainDistributedPreferGpu(GDFColumnData gDFColumnData, Map<String, Object> map, Map<String, GDFColumnData> map2, boolean z, Option<GpuSampler> option) throws XGBoostError {
        logger().info("Gpu Running XGBoost " + package$.MODULE$.VERSION() + " with parameters: " + "\n" + map.mkString("\n"));
        final SparkContext sc = gDFColumnData.rawDF().sparkSession().sparkContext();
        final Tuple10<Object, Object, Object, ObjectiveTrait, EvalTrait, Object, TrackerConf, Object, String, Object> validated = parameterFetchAndValidation(map, sc);
        if (validated == null) {
            throw new MatchError(validated);
        }
        // Positions 1, 2, 7, 8, 9 and 10 of the validated tuple are the only ones needed here.
        final int numPartitions = BoxesRunTime.unboxToInt(validated._1());
        final int roundLimit = BoxesRunTime.unboxToInt(validated._2());
        final TrackerConf trackerSettings = (TrackerConf) validated._7();
        final long parallelismTrackerArg = BoxesRunTime.unboxToLong(validated._8());
        final String checkpointTarget = (String) validated._9();
        final int checkpointRoundsArg = BoxesRunTime.unboxToInt(validated._10());

        final Map<String, GDFColumnData> gpuData = prepareDataForGpu(gDFColumnData, map2, numPartitions, map);
        final CheckpointManager checkpoints = new CheckpointManager(sc, checkpointTarget);
        checkpoints.cleanUpHigherVersions(roundLimit);
        // Mutable holder: the per-round trainer replaces the booster after each checkpoint.
        final ObjectRef boosterRef = ObjectRef.create(checkpoints.loadCheckpointAsBooster());

        final TraversableLike perRound = (TraversableLike) checkpoints.getCheckpointRounds(checkpointRoundsArg, roundLimit).map(
                obj -> $anonfun$trainDistributedPreferGpu$1(numPartitions, trackerSettings, map, sc, parallelismTrackerArg, gpuData, map2, boosterRef, z, option, roundLimit, checkpoints, BoxesRunTime.unboxToInt(obj)),
                Seq$.MODULE$.canBuildFrom());
        return (Tuple2) perRound.last();
    }

    /**
     * Runs distributed training over the configured checkpoint rounds and returns the result of
     * the last round.
     *
     * <p>After parameter validation a {@code CheckpointManager} is created, the input is prepared
     * with {@code composeInputData}, and each round from {@code getCheckpointRounds} is delegated
     * to {@code $anonfun$trainDistributed$2}. {@code uncacheTrainingData} is always invoked
     * afterwards, whether training succeeded or threw.
     *
     * @param rdd training data
     * @param map training parameters; {@code "cacheTrainingSet"} (default {@code false}) is read here
     * @param z flag passed to {@code composeInputData} and the per-round trainer
     * @param map2 additional named data sets passed to the per-round trainer
     * @return booster and metrics produced by the last checkpoint round
     * @throws XGBoostError declared by the original signature
     */
    public Tuple2<Booster, Map<String, float[]>> trainDistributed(RDD<LabeledPoint> rdd, Map<String, Object> map, boolean z, Map<String, RDD<LabeledPoint>> map2) throws XGBoostError {
        logger().info("Running XGBoost " + package$.MODULE$.VERSION() + " with parameters:\n" + map.mkString("\n"));
        final Tuple10<Object, Object, Object, ObjectiveTrait, EvalTrait, Object, TrackerConf, Object, String, Object> validated = parameterFetchAndValidation(map, rdd.sparkContext());
        if (validated == null) {
            throw new MatchError(validated);
        }
        // Positions 1, 2, 7, 8, 9 and 10 of the validated tuple are the only ones needed here.
        final int numPartitions = BoxesRunTime.unboxToInt(validated._1());
        final int roundLimit = BoxesRunTime.unboxToInt(validated._2());
        final TrackerConf trackerSettings = (TrackerConf) validated._7();
        final long parallelismTrackerArg = BoxesRunTime.unboxToLong(validated._8());
        final String checkpointTarget = (String) validated._9();
        final int checkpointRoundsArg = BoxesRunTime.unboxToInt(validated._10());

        final SparkContext sc = rdd.sparkContext();
        final CheckpointManager checkpoints = new CheckpointManager(sc, checkpointTarget);
        checkpoints.cleanUpHigherVersions(roundLimit);
        final boolean cacheRequested = BoxesRunTime.unboxToBoolean(map.getOrElse("cacheTrainingSet", () -> false));
        final Either<RDD<LabeledPoint[]>, RDD<LabeledPoint>> input = composeInputData(rdd, cacheRequested, z, numPartitions);
        // Mutable holder: the per-round trainer replaces the booster after each checkpoint.
        final ObjectRef boosterRef = ObjectRef.create(checkpoints.loadCheckpointAsBooster());
        try {
            final TraversableLike perRound = (TraversableLike) checkpoints.getCheckpointRounds(checkpointRoundsArg, roundLimit).map(
                    obj -> $anonfun$trainDistributed$2(numPartitions, trackerSettings, map, sc, parallelismTrackerArg, z, input, boosterRef, map2, roundLimit, checkpoints, BoxesRunTime.unboxToInt(obj)),
                    Seq$.MODULE$.canBuildFrom());
            return (Tuple2) perRound.last();
        } finally {
            // The flag is looked up again here, exactly as the original cleanup path does.
            uncacheTrainingData(BoxesRunTime.unboxToBoolean(map.getOrElse("cacheTrainingSet", () -> false)), input);
        }
    }

    /**
     * Default-value accessor (Scala {@code $default$N} naming) for the third parameter of
     * {@code trainDistributedPreferGpu}, the map of additional data sets.
     *
     * @return a map built from the empty list {@code Nil}, i.e. with no entries
     */
    public Map<String, GDFColumnData> trainDistributedPreferGpu$default$3() {
        return Predef$.MODULE$.Map().apply(Nil$.MODULE$);
    }

    /**
     * Default-value accessor (Scala {@code $default$N} naming) for the fourth parameter of
     * {@code trainDistributedPreferGpu}, the {@code boolean z} flag.
     *
     * @return always {@code false}
     */
    public boolean trainDistributedPreferGpu$default$4() {
        return false;
    }

    /**
     * Default-value accessor (Scala {@code $default$N} naming) for the fifth parameter of
     * {@code trainDistributedPreferGpu}, the optional {@code GpuSampler}.
     *
     * @return the {@code None} singleton
     */
    public Option<GpuSampler> trainDistributedPreferGpu$default$5() {
        return None$.MODULE$;
    }

    /**
     * Default-value accessor (Scala {@code $default$N} naming) for the third parameter of
     * {@code trainDistributed}, the {@code boolean z} flag.
     *
     * @return always {@code false}
     */
    public boolean trainDistributed$default$3() {
        return false;
    }

    /**
     * Default-value accessor (Scala {@code $default$N} naming) for the fourth parameter of
     * {@code trainDistributed}, the map of additional named RDDs.
     *
     * @return a map built from the empty list {@code Nil}, i.e. with no entries
     */
    public Map<String, RDD<LabeledPoint>> trainDistributed$default$4() {
        return Predef$.MODULE$.Map().apply(Nil$.MODULE$);
    }

    /**
     * Unpersists the training RDD held in {@code either} when caching was requested.
     *
     * @param z whether the training data was cached; when {@code false} nothing happens
     * @param either the training data, as either the left or the right RDD variant
     */
    private void uncacheTrainingData(boolean z, Either<RDD<LabeledPoint[]>, RDD<LabeledPoint>> either) {
        if (!z) {
            return;
        }
        // Both variants are released the same way, so pick whichever side is populated.
        final RDD cached = either.isLeft() ? (RDD) either.left().get() : (RDD) either.right().get();
        cached.unpersist(cached.unpersist$default$1());
    }

    /**
     * Ensures the training RDD has exactly {@code i} partitions.
     *
     * @param rdd training data
     * @param i required number of partitions
     * @return {@code rdd} itself when it already has {@code i} partitions, otherwise a
     *     repartitioned RDD
     */
    public RDD<LabeledPoint> repartitionForTraining(RDD<LabeledPoint> rdd, int i) {
        final boolean alreadyPartitioned = rdd.getNumPartitions() == i;
        if (!alreadyPartitioned) {
            logger().info("repartitioning training set to " + i + " partitions");
            return rdd.repartition(i, rdd.repartition$default$2(i));
        }
        return rdd;
    }

    /**
     * Converts an RDD of labeled points into an RDD of point arrays, one array per group.
     *
     * <p>The result is the union of two pipelines, each of which wraps every partition in a
     * {@code LabeledPointGroupIterator}:
     * <ol>
     *   <li>groups accepted by {@code $anonfun$aggByGroupInfo$2} (defined outside this view),
     *       mapped to their {@code points()};</li>
     *   <li>groups for which {@code isEdgeGroup()} is true, tagged with the partition id, grouped
     *       by the key computed in {@code $anonfun$aggByGroupInfo$7}, sorted by the tagged
     *       partition id and flattened through {@code $anonfun$aggByGroupInfo$10}.</li>
     * </ol>
     *
     * <p>NOTE(review): presumably pipeline 1 keeps the non-edge groups and pipeline 2 stitches
     * together groups split across partition boundaries — confirm against the synthetic helpers.
     *
     * @param rdd labeled points
     * @return RDD of {@code LabeledPoint[]} arrays
     */
    private RDD<LabeledPoint[]> aggByGroupInfo(RDD<LabeledPoint> rdd) {
        return rdd.mapPartitions(iterator -> {
            return new LabeledPointGroupIterator(iterator);
        }, rdd.mapPartitions$default$2(), ClassTag$.MODULE$.apply(XGBLabeledPointGroup.class)).filter(xGBLabeledPointGroup -> {
            return BoxesRunTime.boxToBoolean($anonfun$aggByGroupInfo$2(xGBLabeledPointGroup));
        }).map(xGBLabeledPointGroup2 -> {
            return xGBLabeledPointGroup2.points();
        }, ClassTag$.MODULE$.apply(ScalaRunTime$.MODULE$.arrayClass(LabeledPoint.class))).union(rdd.mapPartitions(iterator2 -> {
            return new LabeledPointGroupIterator(iterator2);
        }, rdd.mapPartitions$default$2(), ClassTag$.MODULE$.apply(XGBLabeledPointGroup.class)).filter(xGBLabeledPointGroup3 -> {
            return BoxesRunTime.boxToBoolean(xGBLabeledPointGroup3.isEdgeGroup());
        }).map(xGBLabeledPointGroup4 -> {
            // Tag each edge group with the id of the partition it was found in.
            return new Tuple2(BoxesRunTime.boxToInteger(TaskContext$.MODULE$.getPartitionId()), xGBLabeledPointGroup4);
        }, ClassTag$.MODULE$.apply(Tuple2.class)).groupBy(tuple2 -> {
            return BoxesRunTime.boxToInteger($anonfun$aggByGroupInfo$7(tuple2));
        }, ClassTag$.MODULE$.Int()).map(tuple22 -> {
            // Within one key: order the tagged groups by partition id, then flatten to a single array.
            return (LabeledPoint[]) new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps((Object[]) new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps((Object[]) ((Iterable) tuple22._2()).toArray(ClassTag$.MODULE$.apply(Tuple2.class)))).sortBy(tuple22 -> {
                return BoxesRunTime.boxToInteger(tuple22._1$mcI$sp());
            }, Ordering$Int$.MODULE$))).flatMap(tuple23 -> {
                return new ArrayOps.ofRef($anonfun$aggByGroupInfo$10(tuple23));
            }, Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(LabeledPoint.class)));
        }, ClassTag$.MODULE$.apply(ScalaRunTime$.MODULE$.arrayClass(LabeledPoint.class))));
    }

    /**
     * Aggregates the labeled points into groups and repartitions the grouped RDD.
     *
     * <p>Unlike {@code repartitionForTraining}, the repartition is unconditional.
     *
     * @param rdd labeled points
     * @param i required number of partitions
     * @return grouped data spread over {@code i} partitions
     */
    public RDD<LabeledPoint[]> repartitionForTrainingGroup(RDD<LabeledPoint> rdd, int i) {
        final RDD<LabeledPoint[]> grouped = aggByGroupInfo(rdd);
        final String message = "repartitioning training group set to " + i + " partitions";
        logger().info(message);
        return grouped.repartition(i, grouped.repartition$default$2(i));
    }

    /**
     * Zips the grouped training data with the grouped form of every additional data set,
     * partition by partition, into a single RDD of {@code (name, group iterator)} pairs.
     *
     * <p>The training RDD is registered under {@code "train"}. Each entry of {@code map} is
     * aggregated with {@code aggByGroupInfo} and repartitioned to {@code i} partitions when its
     * partition count differs. The fold starts from an RDD of {@code i} {@code null} placeholders
     * in {@code i} partitions.
     *
     * <p>If any zipped partition of a data set is empty, an error is logged and a plain
     * {@code Exception} is thrown from the partition function.
     *
     * <p>NOTE(review): the fold lambda re-declares {@code tuple22} and {@code rdd2}, shadowing its
     * own parameters. This looks like a decompiler artifact and is not valid Java as written.
     *
     * @param rdd grouped training data
     * @param map additional named data sets (ungrouped)
     * @param i number of partitions every data set must have
     * @return RDD whose partitions hold one {@code (name, groups)} pair per data set
     */
    private RDD<Tuple2<String, Iterator<LabeledPoint[]>>> coPartitionGroupSets(RDD<LabeledPoint[]> rdd, Map<String, RDD<LabeledPoint>> map, int i) {
        return (RDD) Predef$.MODULE$.Map().apply(Predef$.MODULE$.wrapRefArray(new Tuple2[]{Predef$ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc("train"), rdd)})).$plus$plus((GenTraversableOnce) map.map(tuple2 -> {
            if (tuple2 == null) {
                throw new MatchError(tuple2);
            }
            String str = (String) tuple2._1();
            RDD<LabeledPoint[]> aggByGroupInfo = MODULE$.aggByGroupInfo((RDD) tuple2._2());
            // Repartition only when the grouped set does not already have i partitions.
            return aggByGroupInfo.getNumPartitions() != i ? Predef$ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc(str), aggByGroupInfo.repartition(i, aggByGroupInfo.repartition$default$2(i))) : Predef$ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc(str), aggByGroupInfo);
        }, Map$.MODULE$.canBuildFrom())).foldLeft(rdd.sparkContext().parallelize(Predef$.MODULE$.wrapRefArray((Object[]) Array$.MODULE$.fill(i, () -> {
            return null;
        }, ClassTag$.MODULE$.apply(Tuple2.class))), i, ClassTag$.MODULE$.apply(Tuple2.class)), (rdd2, tuple22) -> {
            Tuple2 tuple22 = new Tuple2(rdd2, tuple22);
            if (tuple22 != null) {
                RDD rdd2 = (RDD) tuple22._1();
                Tuple2 tuple23 = (Tuple2) tuple22._2();
                if (tuple23 != null) {
                    String str = (String) tuple23._1();
                    return rdd2.zipPartitions((RDD) tuple23._2(), (iterator, iterator2) -> {
                        if (iterator2.hasNext()) {
                            Tuple2[] tuple2Arr = (Tuple2[]) iterator.toArray(ClassTag$.MODULE$.apply(Tuple2.class));
                            // A null head means only the seed placeholder is present: start a fresh
                            // array. Otherwise append the new (name -> groups) pair to the accumulated ones.
                            return new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps(tuple2Arr)).head() != null ? new XGBoost.IteratorWrapper((Tuple2[]) new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps(tuple2Arr)).$colon$plus(Predef$ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc(str), iterator2), ClassTag$.MODULE$.apply(Tuple2.class))) : new XGBoost.IteratorWrapper(new Tuple2[]{Predef$ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc(str), iterator2)});
                        }
                        // Empty partition for this data set: fail the task.
                        MODULE$.logger().error("when specifying eval sets as dataframes, you have to ensure that the number of elements in each dataframe is larger than the number of workers");
                        throw new Exception("too few elements in evaluation sets");
                    }, ClassTag$.MODULE$.apply(ScalaRunTime$.MODULE$.arrayClass(LabeledPoint.class)), ClassTag$.MODULE$.apply(Tuple2.class));
                }
            }
            throw new MatchError(tuple22);
        });
    }

    /**
     * Turns the tracker's exit code into either the trained model or a failure.
     *
     * <p>On exit code {@code 0} the Spark job thread is joined, the first element of {@code rdd}
     * is taken as the result and the RDD is unpersisted without blocking. On any other exit code
     * the job thread is interrupted if it is still alive and an {@code XGBoostError} is thrown.
     *
     * @param i tracker exit code; {@code 0} means success
     * @param rdd RDD holding the booster/metrics pairs
     * @param thread thread running the Spark job
     * @return booster and metrics taken from the first element of {@code rdd}
     */
    private Tuple2<Booster, Map<String, float[]>> postTrackerReturnProcessing(int i, RDD<Tuple2<Booster, Map<String, float[]>>> rdd, Thread thread) {
        if (i == 0) {
            thread.join();
            final Tuple2 firstResult = (Tuple2) rdd.first();
            if (firstResult == null) {
                throw new MatchError(firstResult);
            }
            // Extract both components before releasing the RDD.
            final Booster trainedBooster = (Booster) firstResult._1();
            final Map metrics = (Map) firstResult._2();
            rdd.unpersist(false);
            return new Tuple2<>(trainedBooster, metrics);
        }
        try {
            if (thread.isAlive()) {
                thread.interrupt();
            }
        } catch (InterruptedException unused) {
            logger().info("spark job thread is interrupted");
        }
        throw new XGBoostError("XGBoostModel training failed");
    }

    /**
     * Java serialization hook: substitutes the {@code MODULE$} singleton for any instance
     * produced by deserialization.
     *
     * @return the {@code MODULE$} singleton
     */
    private Object readResolve() {
        return MODULE$;
    }

    /**
     * Synthetic lambda body: null check on a tuple.
     *
     * @param tuple2 tuple to test
     * @return {@code true} when {@code tuple2} is not {@code null}
     */
    public static final /* synthetic */ boolean $anonfun$removeMissingValues$2(Tuple2 tuple2) {
        return tuple2 != null;
    }

    /**
     * Synthetic lambda body: applies a float predicate to the first element of a tuple.
     *
     * @param function1 predicate invoked through its float-to-boolean specialization
     * @param tuple2 tuple whose first element is unboxed to a float; a {@code null} tuple raises
     *     {@code MatchError}
     * @return the predicate's result for the tuple's first element
     */
    public static final /* synthetic */ boolean $anonfun$removeMissingValues$3(Function1 function1, Tuple2 tuple2) {
        if (tuple2 == null) {
            throw new MatchError(tuple2);
        }
        final float candidate = BoxesRunTime.unboxToFloat(tuple2._1());
        return function1.apply$mcZF$sp(candidate);
    }

    /**
     * Synthetic lambda body: parses the string form of {@code obj} as an integer using Scala's
     * {@code StringOps.toInt}.
     *
     * @param obj value whose {@code toString()} is parsed
     * @return the parsed integer
     */
    public static final /* synthetic */ int $anonfun$buildDistributedBooster$1(Object obj) {
        final String text = obj.toString();
        final StringOps parser = new StringOps(Predef$.MODULE$.augmentString(text));
        return parser.toInt();
    }

    /**
     * Synthetic lambda body: allocates a new float array.
     *
     * @param i length of the array to create
     * @param i2 not used by this method
     * @return a new {@code float[]} of length {@code i}
     */
    public static final /* synthetic */ float[] $anonfun$buildDistributedBooster$3(int i, int i2) {
        final Object allocated = Array$.MODULE$.ofDim(i, ClassTag$.MODULE$.Float());
        return (float[]) allocated;
    }

    /**
     * Synthetic lambda body: requests a garbage collection followed by a finalization pass.
     *
     * @param taskContext not used by this method
     */
    public static final /* synthetic */ void $anonfun$buildDistributedBooster$5(TaskContext taskContext) {
        // NOTE(review): presumably run as a task callback to release native resources promptly — confirm at the call site.
        System.gc();
        System.runFinalization();
    }

    /**
     * Synthetic filter predicate used by {@code coPartitionForGpu}: keeps non-null entries.
     *
     * @param tuple2 entry to test
     * @return {@code true} when {@code tuple2} is not {@code null}
     */
    public static final /* synthetic */ boolean $anonfun$coPartitionForGpu$4(Tuple2 tuple2) {
        return tuple2 != null;
    }

    /**
     * Synthetic lambda body: runs one checkpoint round of distributed GPU training.
     *
     * <p>Steps: start the tracker, adjust the parameters with
     * {@code overrideParamsAccordingToTaskCPUs}, build the training RDD with
     * {@code trainPreferGpu}, run a Spark action over it on a separate thread, wait for the
     * tracker through {@code SparkParallelismTracker.executeHonorForGpu}, and convert the exit
     * code into a result with {@code postTrackerReturnProcessing}. The tracker is always stopped
     * in the {@code finally} block.
     *
     * <p>NOTE(review): {@code startTracker} is typed as {@code Thread.UncaughtExceptionHandler}
     * yet {@code getWorkerEnvs()}, {@code waitFor()} and {@code stop()} are called on it. The
     * declared type looks like a decompiler artifact; the real type is presumably a tracker
     * interface — confirm.
     *
     * @param i passed to {@code startTracker} and to the {@code SparkParallelismTracker} constructor
     * @param trackerConf tracker configuration passed to {@code startTracker}
     * @param map training parameters before the per-task CPU override
     * @param sparkContext Spark context
     * @param j passed to the {@code SparkParallelismTracker} constructor
     * @param map2 named column data passed to {@code trainPreferGpu}
     * @param map3 only its emptiness is used, selecting the training-only path of {@code trainPreferGpu}
     * @param objectRef holder of the current booster; updated when a checkpoint is written
     * @param z passed through to {@code trainPreferGpu}
     * @param option passed through to {@code trainPreferGpu}
     * @param i2 bound compared against {@code i3} to decide whether to checkpoint
     * @param checkpointManager receives the new booster when {@code i3 < i2}
     * @param i3 current checkpoint round, also passed to {@code trainPreferGpu}
     * @return booster and metrics produced by this round
     */
    public static final /* synthetic */ Tuple2 $anonfun$trainDistributedPreferGpu$1(int i, TrackerConf trackerConf, Map map, SparkContext sparkContext, long j, Map map2, Map map3, ObjectRef objectRef, boolean z, Option option, int i2, CheckpointManager checkpointManager, int i3) {
        Thread.UncaughtExceptionHandler startTracker = MODULE$.startTracker(i, trackerConf);
        try {
            Map<String, Object> overrideParamsAccordingToTaskCPUs = MODULE$.overrideParamsAccordingToTaskCPUs(map, sparkContext);
            SparkParallelismTracker sparkParallelismTracker = new SparkParallelismTracker(sparkContext, j, i);
            final RDD<Tuple2<Booster, Map<String, float[]>>> trainPreferGpu = MODULE$.trainPreferGpu(sparkContext, map2, map3.isEmpty(), overrideParamsAccordingToTaskCPUs, startTracker.getWorkerEnvs(), i3, (Booster) objectRef.elem, z, option);
            // Background thread that runs a Spark action over the training RDD; the per-partition
            // body does nothing with the iterator, so the action only serves to execute the RDD.
            Thread thread = new Thread(trainPreferGpu) { // from class: ml.dmlc.xgboost4j.scala.spark.XGBoost$$anon$2
                private final RDD boostersAndMetrics$1;

                @Override // java.lang.Thread, java.lang.Runnable
                public void run() {
                    this.boostersAndMetrics$1.foreachPartition(iterator -> {
                        () -> {
                            return iterator;
                        };
                        return BoxedUnit.UNIT;
                    });
                }

                {
                    this.boostersAndMetrics$1 = trainPreferGpu;
                }
            };
            // The tracker object doubles as the handler for uncaught exceptions on the job thread.
            thread.setUncaughtExceptionHandler(startTracker);
            thread.start();
            int unboxToInt = BoxesRunTime.unboxToInt(sparkParallelismTracker.executeHonorForGpu(() -> {
                return startTracker.waitFor(0L);
            }));
            MODULE$.logger().info(new StringBuilder(33).append("Gpu Rabit returns with exit code ").append(unboxToInt).toString());
            Tuple2<Booster, Map<String, float[]>> postTrackerReturnProcessing = MODULE$.postTrackerReturnProcessing(unboxToInt, trainPreferGpu, thread);
            if (postTrackerReturnProcessing == null) {
                throw new MatchError(postTrackerReturnProcessing);
            }
            Tuple2 tuple2 = new Tuple2((Booster) postTrackerReturnProcessing._1(), (Map) postTrackerReturnProcessing._2());
            Booster booster = (Booster) tuple2._1();
            Map map4 = (Map) tuple2._2();
            // Rounds below the bound i2 store the booster and write a checkpoint.
            if (i3 < i2) {
                objectRef.elem = booster;
                checkpointManager.updateCheckpoint((Booster) objectRef.elem);
            }
            return new Tuple2(booster, map4);
        } finally {
            startTracker.stop();
        }
    }

    /**
     * Runs one training pass on the non-GPU path: starts a tracker, launches
     * the ranking or non-ranking training RDD on a background thread, waits for
     * the tracker to finish, and turns the outcome into a (booster, metrics) pair.
     * The tracker is always stopped on exit, whether or not training succeeded.
     *
     * @param i                 passed to {@code startTracker} and to the
     *                          {@code SparkParallelismTracker} constructor
     *                          (presumably the worker count; confirm against caller)
     * @param trackerConf       tracker configuration passed to {@code startTracker}
     * @param map               parameters handed to {@code overrideParamsAccordingToTaskCPUs}
     * @param sparkContext      Spark context used for the parallelism tracker
     * @param j                 second argument of the {@code SparkParallelismTracker}
     *                          constructor (presumably a timeout; confirm)
     * @param z                 selects {@code trainForRanking} (left side of
     *                          {@code either}) when true, otherwise
     *                          {@code trainForNonRanking} (right side)
     * @param either            carries the input RDD for whichever training variant is chosen
     * @param objectRef         holds the previous booster; overwritten with the new
     *                          booster when {@code i3 < i2}
     * @param map2              forwarded unchanged to the chosen training method
     * @param i2                upper bound compared with {@code i3} to decide whether
     *                          to checkpoint (presumably the total round count; confirm)
     * @param checkpointManager receives the new booster when {@code i3 < i2}
     * @param i3                forwarded to the training method and compared with
     *                          {@code i2} (presumably the current checkpoint round; confirm)
     * @return a pair of the trained booster and its metrics map
     * @throws MatchError when {@code postTrackerReturnProcessing} returns {@code null}
     */
    public static final /* synthetic */ Tuple2 $anonfun$trainDistributed$2(int i, TrackerConf trackerConf, Map map, SparkContext sparkContext, long j, boolean z, Either either, ObjectRef objectRef, Map map2, int i2, CheckpointManager checkpointManager, int i3) {
        // NOTE(review): declared as Thread.UncaughtExceptionHandler, yet getWorkerEnvs/waitFor/stop
        // are called on it below; presumably the real type is IRabitTracker (decompiler artefact) - confirm.
        Thread.UncaughtExceptionHandler startTracker = MODULE$.startTracker(i, trackerConf);
        try {
            Map<String, Object> overrideParamsAccordingToTaskCPUs = MODULE$.overrideParamsAccordingToTaskCPUs(map, sparkContext);
            SparkParallelismTracker sparkParallelismTracker = new SparkParallelismTracker(sparkContext, j, i);
            java.util.Map<String, String> workerEnvs = startTracker.getWorkerEnvs();
            // Despite the local's name, this holds the ranking OR the non-ranking RDD depending on z.
            final RDD<Tuple2<Booster, Map<String, float[]>>> trainForRanking = z ? MODULE$.trainForRanking((RDD) either.left().get(), overrideParamsAccordingToTaskCPUs, workerEnvs, i3, (Booster) objectRef.elem, map2) : MODULE$.trainForNonRanking((RDD) either.right().get(), overrideParamsAccordingToTaskCPUs, workerEnvs, i3, (Booster) objectRef.elem, map2);
            // Background thread that triggers the RDD via foreachPartition so this
            // thread is free to wait on the tracker.
            Thread thread = new Thread(trainForRanking) { // from class: ml.dmlc.xgboost4j.scala.spark.XGBoost$$anon$3
                private final RDD boostersAndMetrics$2;

                @Override // java.lang.Thread, java.lang.Runnable
                public void run() {
                    // The per-partition function returns BoxedUnit.UNIT and does not consume the iterator here.
                    // NOTE(review): the bare inner lambda looks like decompiler residue - confirm against the Scala source.
                    this.boostersAndMetrics$2.foreachPartition(iterator -> {
                        () -> {
                            return iterator;
                        };
                        return BoxedUnit.UNIT;
                    });
                }

                {
                    this.boostersAndMetrics$2 = trainForRanking;
                }
            };
            // The tracker object doubles as the thread's uncaught-exception handler.
            thread.setUncaughtExceptionHandler(startTracker);
            thread.start();
            // Wait on the tracker (argument 0L) under the parallelism tracker; the result is unboxed to an int.
            int unboxToInt = BoxesRunTime.unboxToInt(sparkParallelismTracker.execute(() -> {
                return startTracker.waitFor(0L);
            }));
            MODULE$.logger().info(new StringBuilder(29).append("Rabit returns with exit code ").append(unboxToInt).toString());
            Tuple2<Booster, Map<String, float[]>> postTrackerReturnProcessing = MODULE$.postTrackerReturnProcessing(unboxToInt, trainForRanking, thread);
            if (postTrackerReturnProcessing == null) {
                throw new MatchError(postTrackerReturnProcessing);
            }
            Tuple2 tuple2 = new Tuple2((Booster) postTrackerReturnProcessing._1(), (Map) postTrackerReturnProcessing._2());
            Booster booster = (Booster) tuple2._1();
            Map map3 = (Map) tuple2._2();
            // Only when i3 is below i2: remember the booster and hand it to the checkpoint manager.
            if (i3 < i2) {
                objectRef.elem = booster;
                checkpointManager.updateCheckpoint((Booster) objectRef.elem);
            }
            return new Tuple2(booster, map3);
        } finally {
            // Stop the tracker on both the success and the failure path.
            startTracker.stop();
        }
    }

    /**
     * Predicate that accepts groups which are not edge groups.
     *
     * @param xGBLabeledPointGroup the group to test
     * @return the negation of {@code xGBLabeledPointGroup.isEdgeGroup()}
     */
    public static final /* synthetic */ boolean $anonfun$aggByGroupInfo$2(XGBLabeledPointGroup xGBLabeledPointGroup) {
        final boolean onEdge = xGBLabeledPointGroup.isEdgeGroup();
        return !onEdge;
    }

    /**
     * Extracts the group id from the second element of a pair.
     *
     * @param tuple2 pair whose second element is an {@code XGBLabeledPointGroup}
     * @return that group's {@code groupId()}
     */
    public static final /* synthetic */ int $anonfun$aggByGroupInfo$7(Tuple2 tuple2) {
        final XGBLabeledPointGroup group = (XGBLabeledPointGroup) tuple2._2();
        return group.groupId();
    }

    /**
     * Passes the points of the pair's second element through
     * {@code Predef.refArrayOps}.
     *
     * @param tuple2 pair whose second element is an {@code XGBLabeledPointGroup}
     * @return the result of {@code refArrayOps} applied to that group's {@code points()}
     */
    public static final /* synthetic */ Object[] $anonfun$aggByGroupInfo$10(Tuple2 tuple2) {
        final XGBLabeledPointGroup group = (XGBLabeledPointGroup) tuple2._2();
        return Predef$.MODULE$.refArrayOps(group.points());
    }

    private XGBoost$() {
        MODULE$ = this;
        this.logger = LogFactory.getLog("XGBoostSpark");
        this.trainName = "train";
    }
}
