package org.apache.spark.examples.mllib;

import java.io.Serializable;
import org.apache.spark.SparkConf;
import org.apache.spark.SparkContext;
import org.apache.spark.examples.mllib.GradientBoostedTreesRunner;
import org.apache.spark.mllib.evaluation.MulticlassMetrics;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.tree.GradientBoostedTrees$;
import org.apache.spark.mllib.tree.configuration.Algo$;
import org.apache.spark.mllib.tree.configuration.BoostingStrategy;
import org.apache.spark.mllib.tree.configuration.BoostingStrategy$;
import org.apache.spark.mllib.tree.model.GradientBoostedTreesModel;
import org.apache.spark.rdd.RDD;
import org.apache.spark.rdd.RDD$;
import org.apache.spark.util.Utils$;
import scala.Function1;
import scala.MatchError;
import scala.Option;
import scala.Predef$;
import scala.Some;
import scala.Tuple2;
import scala.Tuple3;
import scala.reflect.ClassTag$;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;
import scala.sys.package$;
import scopt.OptionParser;
import scopt.Read$;

/* compiled from: GradientBoostedTreesRunner.scala */
/* loaded from: input_file:org/apache/spark/examples/mllib/GradientBoostedTreesRunner$.class */
public final class GradientBoostedTreesRunner$ {
    public static GradientBoostedTreesRunner$ MODULE$;

    static {
        new GradientBoostedTreesRunner$();
    }

    public void main(String[] strArr) {
        final GradientBoostedTreesRunner.Params params = new GradientBoostedTreesRunner.Params(GradientBoostedTreesRunner$Params$.MODULE$.apply$default$1(), GradientBoostedTreesRunner$Params$.MODULE$.apply$default$2(), GradientBoostedTreesRunner$Params$.MODULE$.apply$default$3(), GradientBoostedTreesRunner$Params$.MODULE$.apply$default$4(), GradientBoostedTreesRunner$Params$.MODULE$.apply$default$5(), GradientBoostedTreesRunner$Params$.MODULE$.apply$default$6(), GradientBoostedTreesRunner$Params$.MODULE$.apply$default$7());
        Some parse = new OptionParser<GradientBoostedTreesRunner.Params>(params) { // from class: org.apache.spark.examples.mllib.GradientBoostedTreesRunner$$anon$1
            public static final /* synthetic */ GradientBoostedTreesRunner.Params $anonfun$new$2(int i, GradientBoostedTreesRunner.Params params2) {
                return params2.copy(params2.copy$default$1(), params2.copy$default$2(), params2.copy$default$3(), params2.copy$default$4(), i, params2.copy$default$6(), params2.copy$default$7());
            }

            public static final /* synthetic */ GradientBoostedTreesRunner.Params $anonfun$new$3(int i, GradientBoostedTreesRunner.Params params2) {
                return params2.copy(params2.copy$default$1(), params2.copy$default$2(), params2.copy$default$3(), params2.copy$default$4(), params2.copy$default$5(), i, params2.copy$default$7());
            }

            public static final /* synthetic */ GradientBoostedTreesRunner.Params $anonfun$new$4(double d, GradientBoostedTreesRunner.Params params2) {
                return params2.copy(params2.copy$default$1(), params2.copy$default$2(), params2.copy$default$3(), params2.copy$default$4(), params2.copy$default$5(), params2.copy$default$6(), d);
            }

            {
                super("GradientBoostedTrees");
                head(Predef$.MODULE$.wrapRefArray(new String[]{"GradientBoostedTrees: an example decision tree app."}));
                opt("algo", Read$.MODULE$.stringRead()).text(new StringBuilder(23).append("algorithm (").append(Algo$.MODULE$.values().mkString(",")).append("), default: ").append(params.algo()).toString()).action((str, params2) -> {
                    return params2.copy(params2.copy$default$1(), params2.copy$default$2(), params2.copy$default$3(), str, params2.copy$default$5(), params2.copy$default$6(), params2.copy$default$7());
                });
                opt("maxDepth", Read$.MODULE$.intRead()).text(new StringBuilder(32).append("max depth of the tree, default: ").append(params.maxDepth()).toString()).action((obj, params3) -> {
                    return $anonfun$new$2(BoxesRunTime.unboxToInt(obj), params3);
                });
                opt("numIterations", Read$.MODULE$.intRead()).text(new StringBuilder(43).append("number of iterations of boosting,").append(" default: ").append(params.numIterations()).toString()).action((obj2, params4) -> {
                    return $anonfun$new$3(BoxesRunTime.unboxToInt(obj2), params4);
                });
                opt("fracTest", Read$.MODULE$.doubleRead()).text(new StringBuilder(103).append("fraction of data to hold out for testing.  If given option testInput, ").append("this option is ignored. default: ").append(params.fracTest()).toString()).action((obj3, params5) -> {
                    return $anonfun$new$4(BoxesRunTime.unboxToDouble(obj3), params5);
                });
                opt("testInput", Read$.MODULE$.stringRead()).text(new StringBuilder(76).append("input path to test dataset.  If given, option fracTest is ignored.").append(" default: ").append(params.testInput()).toString()).action((str2, params6) -> {
                    return params6.copy(params6.copy$default$1(), str2, params6.copy$default$3(), params6.copy$default$4(), params6.copy$default$5(), params6.copy$default$6(), params6.copy$default$7());
                });
                opt("dataFormat", Read$.MODULE$.stringRead()).text("data format: libsvm (default), dense (deprecated in Spark v1.1)").action((str3, params7) -> {
                    return params7.copy(params7.copy$default$1(), params7.copy$default$2(), str3, params7.copy$default$4(), params7.copy$default$5(), params7.copy$default$6(), params7.copy$default$7());
                });
                arg("<input>", Read$.MODULE$.stringRead()).text("input path to labeled examples").required().action((str4, params8) -> {
                    return params8.copy(str4, params8.copy$default$2(), params8.copy$default$3(), params8.copy$default$4(), params8.copy$default$5(), params8.copy$default$6(), params8.copy$default$7());
                });
                checkConfig(params9 -> {
                    return (params9.fracTest() < ((double) 0) || params9.fracTest() > ((double) 1)) ? this.failure(new StringBuilder(46).append("fracTest ").append(params9.fracTest()).append(" value incorrect; should be in [0,1].").toString()) : this.success();
                });
            }
        }.parse(Predef$.MODULE$.wrapRefArray(strArr), params);
        if (!(parse instanceof Some)) {
            throw package$.MODULE$.exit(1);
        }
        run((GradientBoostedTreesRunner.Params) parse.value());
        BoxedUnit boxedUnit = BoxedUnit.UNIT;
    }

    public void run(GradientBoostedTreesRunner.Params params) {
        SparkContext sparkContext = new SparkContext(new SparkConf().setAppName(new StringBuilder(32).append("GradientBoostedTreesRunner with ").append(params).toString()));
        Predef$.MODULE$.println(new StringBuilder(44).append("GradientBoostedTreesRunner with parameters:\n").append(params).toString());
        Tuple3<RDD<LabeledPoint>, RDD<LabeledPoint>, Object> loadDatasets = DecisionTreeRunner$.MODULE$.loadDatasets(sparkContext, params.input(), params.dataFormat(), params.testInput(), Algo$.MODULE$.withName(params.algo()), params.fracTest());
        if (loadDatasets == null) {
            throw new MatchError(loadDatasets);
        }
        Tuple3 tuple3 = new Tuple3((RDD) loadDatasets._1(), (RDD) loadDatasets._2(), BoxesRunTime.boxToInteger(BoxesRunTime.unboxToInt(loadDatasets._3())));
        RDD<LabeledPoint> rdd = (RDD) tuple3._1();
        RDD<LabeledPoint> rdd2 = (RDD) tuple3._2();
        int unboxToInt = BoxesRunTime.unboxToInt(tuple3._3());
        BoostingStrategy defaultParams = BoostingStrategy$.MODULE$.defaultParams(params.algo());
        defaultParams.treeStrategy().numClasses_$eq(unboxToInt);
        defaultParams.numIterations_$eq(params.numIterations());
        defaultParams.treeStrategy().maxDepth_$eq(params.maxDepth());
        Utils$.MODULE$.random().nextInt();
        String algo = params.algo();
        if (algo != null ? !algo.equals("Classification") : "Classification" != 0) {
            String algo2 = params.algo();
            if (algo2 != null ? algo2.equals("Regression") : "Regression" == 0) {
                long nanoTime = System.nanoTime();
                GradientBoostedTreesModel train = GradientBoostedTrees$.MODULE$.train(rdd, defaultParams);
                Predef$.MODULE$.println(new StringBuilder(23).append("Training time: ").append((System.nanoTime() - nanoTime) / 1.0E9d).append(" seconds").toString());
                if (train.totalNumNodes() < 30) {
                    Predef$.MODULE$.println(train.toDebugString());
                } else {
                    Predef$.MODULE$.println(train);
                }
                Predef$.MODULE$.println(new StringBuilder(27).append("Train mean squared error = ").append(meanSquaredError(train, rdd)).toString());
                Predef$.MODULE$.println(new StringBuilder(26).append("Test mean squared error = ").append(meanSquaredError(train, rdd2)).toString());
            }
        } else {
            long nanoTime2 = System.nanoTime();
            GradientBoostedTreesModel train2 = GradientBoostedTrees$.MODULE$.train(rdd, defaultParams);
            Predef$.MODULE$.println(new StringBuilder(23).append("Training time: ").append((System.nanoTime() - nanoTime2) / 1.0E9d).append(" seconds").toString());
            if (train2.totalNumNodes() < 30) {
                Predef$.MODULE$.println(train2.toDebugString());
            } else {
                Predef$.MODULE$.println(train2);
            }
            Predef$.MODULE$.println(new StringBuilder(17).append("Train accuracy = ").append(new MulticlassMetrics(rdd.map(labeledPoint -> {
                return new Tuple2.mcDD.sp(train2.predict(labeledPoint.features()), labeledPoint.label());
            }, ClassTag$.MODULE$.apply(Tuple2.class))).accuracy()).toString());
            Predef$.MODULE$.println(new StringBuilder(16).append("Test accuracy = ").append(new MulticlassMetrics(rdd2.map(labeledPoint2 -> {
                return new Tuple2.mcDD.sp(train2.predict(labeledPoint2.features()), labeledPoint2.label());
            }, ClassTag$.MODULE$.apply(Tuple2.class))).accuracy()).toString());
        }
        sparkContext.stop();
    }

    public double meanSquaredError(GradientBoostedTreesModel gradientBoostedTreesModel, RDD<LabeledPoint> rdd) {
        return RDD$.MODULE$.doubleRDDToDoubleRDDFunctions(rdd.map(labeledPoint -> {
            return BoxesRunTime.boxToDouble($anonfun$meanSquaredError$1(gradientBoostedTreesModel, labeledPoint));
        }, ClassTag$.MODULE$.Double())).mean();
    }

    public static final /* synthetic */ double $anonfun$meanSquaredError$1(GradientBoostedTreesModel gradientBoostedTreesModel, LabeledPoint labeledPoint) {
        double predict = gradientBoostedTreesModel.predict(labeledPoint.features()) - labeledPoint.label();
        return predict * predict;
    }

    private GradientBoostedTreesRunner$() {
        MODULE$ = this;
    }
}
