/*
 * Decompiled with CFR 0.152.
 */
package org.apache.spark.ml.tree.impl;

import java.io.File;
import org.apache.spark.SparkContext;
import org.apache.spark.SparkFunSuite;
import org.apache.spark.ml.feature.LabeledPoint;
import org.apache.spark.ml.regression.DecisionTreeRegressionModel;
import org.apache.spark.ml.tree.impl.GradientBoostedTrees$;
import org.apache.spark.ml.tree.impl.GradientBoostedTreesSuite$;
import org.apache.spark.ml.util.TempDirectory$class;
import org.apache.spark.mllib.tree.GradientBoostedTreesSuite$;
import org.apache.spark.mllib.tree.configuration.Algo$;
import org.apache.spark.mllib.tree.configuration.BoostingStrategy;
import org.apache.spark.mllib.tree.configuration.BoostingStrategy$;
import org.apache.spark.mllib.tree.configuration.Strategy;
import org.apache.spark.mllib.tree.configuration.Strategy$;
import org.apache.spark.mllib.tree.impurity.Impurity;
import org.apache.spark.mllib.tree.impurity.Variance$;
import org.apache.spark.mllib.tree.loss.AbsoluteError$;
import org.apache.spark.mllib.tree.loss.LogLoss$;
import org.apache.spark.mllib.tree.loss.Loss;
import org.apache.spark.mllib.tree.loss.SquaredError$;
import org.apache.spark.mllib.util.MLlibTestSparkContext;
import org.apache.spark.mllib.util.MLlibTestSparkContext$class;
import org.apache.spark.rdd.RDD;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.SparkSession;
import org.scalactic.Bool;
import org.scalactic.Bool$;
import org.scalactic.Equality$;
import org.scalactic.TripleEqualsSupport;
import org.scalatest.BeforeAndAfterAll;
import org.scalatest.Tag;
import scala.Array$;
import scala.Enumeration;
import scala.Function0;
import scala.Function1;
import scala.MatchError;
import scala.Predef$;
import scala.Serializable;
import scala.Tuple2;
import scala.collection.GenIterable;
import scala.collection.Seq;
import scala.collection.immutable.Map;
import scala.reflect.ClassTag$;
import scala.reflect.ScalaSignature;
import scala.reflect.api.JavaMirrors;
import scala.reflect.api.JavaUniverse;
import scala.reflect.api.Mirror;
import scala.reflect.api.Symbols;
import scala.reflect.api.TypeCreator;
import scala.reflect.api.TypeTags;
import scala.reflect.api.Types;
import scala.reflect.api.Universe;
import scala.reflect.runtime.package$;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;

@ScalaSignature(bytes="\u0006\u0001\u001d2A!\u0001\u0002\u0001\u001f\tIrI]1eS\u0016tGOQ8pgR,G\r\u0016:fKN\u001cV/\u001b;f\u0015\t\u0019A!\u0001\u0003j[Bd'BA\u0003\u0007\u0003\u0011!(/Z3\u000b\u0005\u001dA\u0011AA7m\u0015\tI!\"A\u0003ta\u0006\u00148N\u0003\u0002\f\u0019\u00051\u0011\r]1dQ\u0016T\u0011!D\u0001\u0004_J<7\u0001A\n\u0005\u0001A!B\u0004\u0005\u0002\u0012%5\t\u0001\"\u0003\u0002\u0014\u0011\ti1\u000b]1sW\u001a+hnU;ji\u0016\u0004\"!\u0006\u000e\u000e\u0003YQ!a\u0006\r\u0002\tU$\u0018\u000e\u001c\u0006\u00033!\tQ!\u001c7mS\nL!a\u0007\f\u0003+5cE.\u001b2UKN$8\u000b]1sW\u000e{g\u000e^3yiB\u0011Q\u0004I\u0007\u0002=)\u0011q\u0004C\u0001\tS:$XM\u001d8bY&\u0011\u0011E\b\u0002\b\u0019><w-\u001b8h\u0011\u0015\u0019\u0003\u0001\"\u0001%\u0003\u0019a\u0014N\\5u}Q\tQ\u0005\u0005\u0002'\u00015\t!\u0001")
/*
 * Decompiled (CFR) Java view of a Scala ScalaTest suite that exercises
 * org.apache.spark.ml.tree.impl.GradientBoostedTrees.
 *
 * The Scala trait mix-ins (MLlibTestSparkContext, which in turn delegates to TempDirectory)
 * are flattened into this class:
 *   - trait state lives in the fields below;
 *   - that state is exposed through Scala-style accessors (`x()`) and setters (`x_$eq(...)`);
 *   - trait method bodies are reached through the static `Trait$class` implementation classes.
 *
 * NOTE(review): this is decompiler output. The anonymous classes below read synthetic
 * captured names (`$outer`, `numIterations$1`, ...) in instance initializers, where plain
 * Java has no such names in scope. The file is therefore presumably not compilable as-is.
 * Prefer editing the original Scala source.
 */
public class GradientBoostedTreesSuite
extends SparkFunSuite
implements MLlibTestSparkContext {
    // State required by the MLlibTestSparkContext trait.
    private transient SparkSession spark;
    private transient SparkContext sc;
    private transient String checkpointDir;
    // State required by the TempDirectory trait (name-mangled private trait field).
    private File org$apache$spark$ml$util$TempDirectory$$_tempDir;

    /** Returns the SparkSession managed by MLlibTestSparkContext. */
    @Override
    public SparkSession spark() {
        return this.spark;
    }

    /** Setter used by MLlibTestSparkContext to install the SparkSession. */
    @Override
    public void spark_$eq(SparkSession x$1) {
        this.spark = x$1;
    }

    /** Returns the SparkContext managed by MLlibTestSparkContext. */
    @Override
    public SparkContext sc() {
        return this.sc;
    }

    /** Setter used by MLlibTestSparkContext to install the SparkContext. */
    @Override
    public void sc_$eq(SparkContext x$1) {
        this.sc = x$1;
    }

    /** Returns the checkpoint directory path held for MLlibTestSparkContext. */
    @Override
    public String checkpointDir() {
        return this.checkpointDir;
    }

    /** Setter used by MLlibTestSparkContext for the checkpoint directory path. */
    @Override
    public void checkpointDir_$eq(String x$1) {
        this.checkpointDir = x$1;
    }

    /** MLlibTestSparkContext's {@code super.beforeAll()}: the next trait in line is TempDirectory. */
    @Override
    public void org$apache$spark$mllib$util$MLlibTestSparkContext$$super$beforeAll() {
        TempDirectory$class.beforeAll(this);
    }

    /** MLlibTestSparkContext's {@code super.afterAll()}: the next trait in line is TempDirectory. */
    @Override
    public void org$apache$spark$mllib$util$MLlibTestSparkContext$$super$afterAll() {
        TempDirectory$class.afterAll(this);
    }

    /** Suite set-up: runs the MLlibTestSparkContext trait body. */
    @Override
    public void beforeAll() {
        MLlibTestSparkContext$class.beforeAll(this);
    }

    /** Suite tear-down: runs the MLlibTestSparkContext trait body. */
    @Override
    public void afterAll() {
        MLlibTestSparkContext$class.afterAll(this);
    }

    /** Accessor for TempDirectory's private temp-dir field. */
    @Override
    public File org$apache$spark$ml$util$TempDirectory$$_tempDir() {
        return this.org$apache$spark$ml$util$TempDirectory$$_tempDir;
    }

    /** Setter for TempDirectory's private temp-dir field. */
    @Override
    public void org$apache$spark$ml$util$TempDirectory$$_tempDir_$eq(File x$1) {
        this.org$apache$spark$ml$util$TempDirectory$$_tempDir = x$1;
    }

    /**
     * TempDirectory's {@code super.beforeAll()}: delegates to the BeforeAndAfterAll trait body.
     * The decompiler rendered this as {@code BeforeAndAfterAll.class.beforeAll(...)}, which is
     * not a valid call. The static trait implementation class is {@code BeforeAndAfterAll$class},
     * following the same encoding as {@code TempDirectory$class} used above.
     */
    @Override
    public void org$apache$spark$ml$util$TempDirectory$$super$beforeAll() {
        org.scalatest.BeforeAndAfterAll$class.beforeAll((BeforeAndAfterAll)this);
    }

    /** TempDirectory's {@code super.afterAll()}: delegates to the superclass implementation. */
    @Override
    public void org$apache$spark$ml$util$TempDirectory$$super$afterAll() {
        super.afterAll();
    }

    /** Returns the temporary directory provided by the TempDirectory trait. */
    @Override
    public File tempDir() {
        return TempDirectory$class.tempDir(this);
    }

    /**
     * Initializes the mixed-in traits (TempDirectory first, then MLlibTestSparkContext).
     * Then registers the suite's single test case.
     */
    public GradientBoostedTreesSuite() {
        TempDirectory$class.$init$(this);
        MLlibTestSparkContext$class.$init$(this);
        // For each (algo, loss) pair, boosting with a validation set should stop before
        // numIterations. It should also do no worse on the validation data than boosting
        // without validation.
        this.test("runWithValidation stops early and performs better on a validation dataset", (Seq)Predef$.MODULE$.wrapRefArray((Object[])new Tag[0]), (Function0)new Serializable(this){
            public static final long serialVersionUID = 0L;
            private final /* synthetic */ GradientBoostedTreesSuite $outer;

            public final void apply() {
                this.apply$mcV$sp();
            }

            public void apply$mcV$sp() {
                // Maximum number of boosting iterations for every run below.
                int numIterations = 20;
                // Fixtures come from the mllib suite's companion object as mllib LabeledPoints.
                // They are converted to ml LabeledPoints via asML(). The companion is fully
                // qualified because a same-named class is also imported from ml.tree.impl.
                RDD trainRdd = this.$outer.sc().parallelize((Seq)Predef$.MODULE$.wrapRefArray((Object[])org.apache.spark.mllib.tree.GradientBoostedTreesSuite$.MODULE$.trainData()), 2, ClassTag$.MODULE$.apply(org.apache.spark.mllib.regression.LabeledPoint.class)).map((Function1)new Serializable(this){
                    public static final long serialVersionUID = 0L;

                    public final LabeledPoint apply(org.apache.spark.mllib.regression.LabeledPoint x$1) {
                        return x$1.asML();
                    }
                }, ClassTag$.MODULE$.apply(LabeledPoint.class));
                RDD validateRdd = this.$outer.sc().parallelize((Seq)Predef$.MODULE$.wrapRefArray((Object[])org.apache.spark.mllib.tree.GradientBoostedTreesSuite$.MODULE$.validateData()), 2, ClassTag$.MODULE$.apply(org.apache.spark.mllib.regression.LabeledPoint.class)).map((Function1)new Serializable(this){
                    public static final long serialVersionUID = 0L;

                    public final LabeledPoint apply(org.apache.spark.mllib.regression.LabeledPoint x$2) {
                        return x$2.asML();
                    }
                }, ClassTag$.MODULE$.apply(LabeledPoint.class));
                // DataFrame views of the fixtures. The TypeTags are compiler-generated
                // reflection plumbing for createDataFrame. The DataFrames are not used
                // further in this test.
                JavaUniverse $u = package$.MODULE$.universe();
                JavaMirrors.JavaMirror $m = package$.MODULE$.universe().runtimeMirror(GradientBoostedTreesSuite.class.getClassLoader());
                Dataset trainDF = this.$outer.spark().createDataFrame(trainRdd, ((TypeTags)$u).TypeTag().apply((Mirror)$m, new TypeCreator(this){

                    public <U extends Universe> Types.TypeApi apply(Mirror<U> $m$untyped) {
                        Universe $u = $m$untyped.universe();
                        Mirror<U> $m = $m$untyped;
                        return ((Symbols.TypeSymbolApi)((Symbols.TypeSymbolApi)$m.staticClass("org.apache.spark.ml.feature.LabeledPoint")).asType()).toTypeConstructor();
                    }
                }));
                JavaUniverse $u2 = package$.MODULE$.universe();
                JavaMirrors.JavaMirror $m2 = package$.MODULE$.universe().runtimeMirror(GradientBoostedTreesSuite.class.getClassLoader());
                Dataset validateDF = this.$outer.spark().createDataFrame(validateRdd, ((TypeTags)$u2).TypeTag().apply((Mirror)$m2, new TypeCreator(this){

                    public <U extends Universe> Types.TypeApi apply(Mirror<U> $m$untyped) {
                        Universe $u = $m$untyped.universe();
                        Mirror<U> $m = $m$untyped;
                        return ((Symbols.TypeSymbolApi)((Symbols.TypeSymbolApi)$m.staticClass("org.apache.spark.ml.feature.LabeledPoint")).asType()).toTypeConstructor();
                    }
                }));
                // Pairs tested: (Regression, SquaredError), (Regression, AbsoluteError),
                // (Classification, LogLoss).
                Enumeration.Value[] algos = (Enumeration.Value[])((Object[])new Enumeration.Value[]{Algo$.MODULE$.Regression(), Algo$.MODULE$.Regression(), Algo$.MODULE$.Classification()});
                Loss[] losses = (Loss[])((Object[])new Loss[]{SquaredError$.MODULE$, AbsoluteError$.MODULE$, LogLoss$.MODULE$});
                Predef$.MODULE$.refArrayOps((Object[])Predef$.MODULE$.refArrayOps((Object[])algos).zip((GenIterable)Predef$.MODULE$.wrapRefArray((Object[])losses), Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(Tuple2.class)))).foreach((Function1)new Serializable(this, numIterations, trainRdd, validateRdd){
                    public static final long serialVersionUID = 0L;
                    private final /* synthetic */ $anonfun$1 $outer;
                    private final int numIterations$1;
                    private final RDD trainRdd$1;
                    private final RDD validateRdd$1;

                    /** Runs every check for a single (algo, loss) pair. */
                    public final void apply(Tuple2<Enumeration.Value, Loss> x0$1) {
                        Tuple2<Enumeration.Value, Loss> tuple2 = x0$1;
                        if (tuple2 != null) {
                            Strategy treeStrategy;
                            Enumeration.Value algo = (Enumeration.Value)tuple2._1();
                            Loss loss = (Loss)tuple2._2();
                            // Tree strategy arguments: the given algo, Variance impurity,
                            // positional value 2 (presumably maxDepth; confirm against
                            // Strategy's constructor) and an empty Map (presumably
                            // categoricalFeaturesInfo). All other arguments use Strategy's defaults.
                            Enumeration.Value x$6 = algo;
                            Variance$ x$7 = Variance$.MODULE$;
                            int x$8 = 2;
                            Map x$9 = Predef$.MODULE$.Map().empty();
                            int x$10 = Strategy$.MODULE$.$lessinit$greater$default$4();
                            int x$11 = Strategy$.MODULE$.$lessinit$greater$default$5();
                            Enumeration.Value x$12 = Strategy$.MODULE$.$lessinit$greater$default$6();
                            int x$13 = Strategy$.MODULE$.$lessinit$greater$default$8();
                            double x$14 = Strategy$.MODULE$.$lessinit$greater$default$9();
                            int x$15 = Strategy$.MODULE$.$lessinit$greater$default$10();
                            double x$16 = Strategy$.MODULE$.$lessinit$greater$default$11();
                            boolean x$17 = Strategy$.MODULE$.$lessinit$greater$default$12();
                            int x$18 = Strategy$.MODULE$.$lessinit$greater$default$13();
                            Strategy x$19 = treeStrategy = new Strategy(x$6, (Impurity)x$7, x$8, x$10, x$11, x$12, x$9, x$13, x$14, x$15, x$16, x$17, x$18);
                            // Boosting strategy: numIterations rounds, the 4th constructor argument
                            // left at its default, and 0.0 as the last argument (presumably
                            // validationTol; confirm against BoostingStrategy).
                            Loss x$20 = loss;
                            int x$21 = this.numIterations$1;
                            double x$22 = 0.0;
                            double x$23 = BoostingStrategy$.MODULE$.$lessinit$greater$default$4();
                            BoostingStrategy boostingStrategy = new BoostingStrategy(x$19, x$20, x$21, x$23, x$22);
                            Tuple2 tuple22 = GradientBoostedTrees$.MODULE$.runWithValidation(this.trainRdd$1, this.validateRdd$1, boostingStrategy, 42L);
                            if (tuple22 != null) {
                                DecisionTreeRegressionModel[] validateTrees = (DecisionTreeRegressionModel[])tuple22._1();
                                double[] validateTreeWeights = (double[])tuple22._2();
                                int numTrees = validateTrees.length;
                                // Early stopping: the validated run must not produce exactly numIterations trees.
                                TripleEqualsSupport.Equalizer $org_scalatest_assert_macro_left = this.$outer.org$apache$spark$ml$tree$impl$GradientBoostedTreesSuite$$anonfun$$$outer().convertToEqualizer(BoxesRunTime.boxToInteger((int)numTrees));
                                int $org_scalatest_assert_macro_right = this.numIterations$1;
                                Bool $org_scalatest_assert_macro_expr = Bool$.MODULE$.binaryMacroBool((Object)$org_scalatest_assert_macro_left, "!==", (Object)BoxesRunTime.boxToInteger((int)$org_scalatest_assert_macro_right), $org_scalatest_assert_macro_left.$bang$eq$eq((Object)BoxesRunTime.boxToInteger((int)$org_scalatest_assert_macro_right), Equality$.MODULE$.default()));
                                this.$outer.org$apache$spark$ml$tree$impl$GradientBoostedTreesSuite$$anonfun$$$outer().assertionsHelper().macroAssert($org_scalatest_assert_macro_expr, (Object)"");
                                // Reference run on the same data and seed, without validation.
                                Tuple2 tuple25 = GradientBoostedTrees$.MODULE$.run(this.trainRdd$1, boostingStrategy, 42L);
                                if (tuple25 != null) {
                                    DecisionTreeRegressionModel[] trees = (DecisionTreeRegressionModel[])tuple25._1();
                                    double[] treeWeights = (double[])tuple25._2();
                                    Enumeration.Value classification = Algo$.MODULE$.Classification();
                                    boolean isClassification = algo != null ? algo.equals(classification) : classification == null;
                                    // (error without validation, error with validation). Both branches
                                    // assign this single tuple. Previously the Classification branch
                                    // wrote a different variable than the one read below.
                                    Tuple2.mcDD.sp errors;
                                    if (isClassification) {
                                        // Each label y is mapped to 2*y - 1 before computing the error
                                        // (e.g. {0, 1} -> {-1, +1} for binary labels).
                                        RDD remappedRdd = this.validateRdd$1.map((Function1)new Serializable(this){
                                            public static final long serialVersionUID = 0L;

                                            public final LabeledPoint apply(LabeledPoint x) {
                                                return new LabeledPoint((double)2 * x.label() - 1.0, x.features());
                                            }
                                        }, ClassTag$.MODULE$.apply(LabeledPoint.class));
                                        errors = new Tuple2.mcDD.sp(GradientBoostedTrees$.MODULE$.computeError(remappedRdd, trees, treeWeights, loss), GradientBoostedTrees$.MODULE$.computeError(remappedRdd, validateTrees, validateTreeWeights, loss));
                                    } else {
                                        errors = new Tuple2.mcDD.sp(GradientBoostedTrees$.MODULE$.computeError(this.validateRdd$1, trees, treeWeights, loss), GradientBoostedTrees$.MODULE$.computeError(this.validateRdd$1, validateTrees, validateTreeWeights, loss));
                                    }
                                    double errorWithoutValidation = errors._1$mcD$sp();
                                    double errorWithValidation = errors._2$mcD$sp();
                                    // The early-stopped model must do no worse on the validation data.
                                    Bool $org_scalatest_assert_macro_expr2 = Bool$.MODULE$.binaryMacroBool((Object)BoxesRunTime.boxToDouble((double)errorWithValidation), "<=", (Object)BoxesRunTime.boxToDouble((double)errorWithoutValidation), errorWithValidation <= errorWithoutValidation);
                                    this.$outer.org$apache$spark$ml$tree$impl$GradientBoostedTreesSuite$$anonfun$$$outer().assertionsHelper().macroAssert($org_scalatest_assert_macro_expr2, (Object)"");
                                    // Per-iteration validation error of the non-validated model: one entry per iteration.
                                    double[] evaluationArray = GradientBoostedTrees$.MODULE$.evaluateEachIteration(this.validateRdd$1, trees, treeWeights, loss, algo);
                                    TripleEqualsSupport.Equalizer $org_scalatest_assert_macro_left3 = this.$outer.org$apache$spark$ml$tree$impl$GradientBoostedTreesSuite$$anonfun$$$outer().convertToEqualizer(BoxesRunTime.boxToInteger((int)evaluationArray.length));
                                    int $org_scalatest_assert_macro_right3 = this.numIterations$1;
                                    Bool $org_scalatest_assert_macro_expr3 = Bool$.MODULE$.binaryMacroBool((Object)$org_scalatest_assert_macro_left3, "===", (Object)BoxesRunTime.boxToInteger((int)$org_scalatest_assert_macro_right3), $org_scalatest_assert_macro_left3.$eq$eq$eq((Object)BoxesRunTime.boxToInteger((int)$org_scalatest_assert_macro_right3), Equality$.MODULE$.default()));
                                    this.$outer.org$apache$spark$ml$tree$impl$GradientBoostedTreesSuite$$anonfun$$$outer().assertionsHelper().macroAssert($org_scalatest_assert_macro_expr3, (Object)"");
                                    // The error must rise right after the point where validation stopped ...
                                    double $org_scalatest_assert_macro_left4 = evaluationArray[numTrees];
                                    double $org_scalatest_assert_macro_right4 = evaluationArray[numTrees - 1];
                                    Bool $org_scalatest_assert_macro_expr4 = Bool$.MODULE$.binaryMacroBool((Object)BoxesRunTime.boxToDouble((double)$org_scalatest_assert_macro_left4), ">", (Object)BoxesRunTime.boxToDouble((double)$org_scalatest_assert_macro_right4), $org_scalatest_assert_macro_left4 > $org_scalatest_assert_macro_right4);
                                    this.$outer.org$apache$spark$ml$tree$impl$GradientBoostedTreesSuite$$anonfun$$$outer().assertionsHelper().macroAssert($org_scalatest_assert_macro_expr4, (Object)"");
                                    // ... and must be non-increasing up to that point.
                                    for (int i = 1; i < numTrees; ++i) {
                                        double $org_scalatest_assert_macro_left5 = evaluationArray[i];
                                        double $org_scalatest_assert_macro_right5 = evaluationArray[i - 1];
                                        Bool $org_scalatest_assert_macro_expr5 = Bool$.MODULE$.binaryMacroBool((Object)BoxesRunTime.boxToDouble((double)$org_scalatest_assert_macro_left5), "<=", (Object)BoxesRunTime.boxToDouble((double)$org_scalatest_assert_macro_right5), $org_scalatest_assert_macro_left5 <= $org_scalatest_assert_macro_right5);
                                        this.$outer.org$apache$spark$ml$tree$impl$GradientBoostedTreesSuite$$anonfun$$$outer().assertionsHelper().macroAssert($org_scalatest_assert_macro_expr5, (Object)"");
                                    }
                                    return;
                                }
                                throw new MatchError((Object)tuple25);
                            }
                            throw new MatchError((Object)tuple22);
                        }
                        throw new MatchError(tuple2);
                    }
                    {
                        if ($outer == null) {
                            throw new NullPointerException();
                        }
                        this.$outer = $outer;
                        this.numIterations$1 = numIterations$1;
                        this.trainRdd$1 = trainRdd$1;
                        this.validateRdd$1 = validateRdd$1;
                    }
                });
            }

            public /* synthetic */ GradientBoostedTreesSuite org$apache$spark$ml$tree$impl$GradientBoostedTreesSuite$$anonfun$$$outer() {
                return this.$outer;
            }
            {
                if ($outer == null) {
                    throw new NullPointerException();
                }
                this.$outer = $outer;
            }
        });
    }
}

