/*
 * Decompiled with CFR 0.152.
 */
package com.datastax.insight.ml.spark.mllib.classification;

import com.datastax.insight.spec.RDDOperator;
import java.util.HashMap;
import java.util.Map;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.tree.GradientBoostedTrees;
import org.apache.spark.mllib.tree.configuration.BoostingStrategy;
import org.apache.spark.mllib.tree.model.GradientBoostedTreesModel;
import scala.Tuple2;

public class GradientBoosting
implements RDDOperator {
    public static GradientBoostedTreesModel trainClassifier(JavaRDD<LabeledPoint> data, int numIterations, int numClasses, int maxDepth) {
        return GradientBoosting.train(data, "Classification", numIterations, numClasses, maxDepth);
    }

    public static GradientBoostedTreesModel trainRegressor(JavaRDD<LabeledPoint> data, int numIterations, int numClasses, int maxDepth) {
        return GradientBoosting.train(data, "Regression", numIterations, numClasses, maxDepth);
    }

    public static GradientBoostedTreesModel train(JavaRDD<LabeledPoint> data, String defaultParams, int numIterations, int numClasses, int maxDepth) {
        BoostingStrategy boostingStrategy = BoostingStrategy.defaultParams((String)defaultParams);
        boostingStrategy.setNumIterations(numIterations);
        boostingStrategy.getTreeStrategy().setNumClasses(numClasses);
        boostingStrategy.getTreeStrategy().setMaxDepth(maxDepth);
        HashMap categoricalFeaturesInfo = new HashMap();
        boostingStrategy.treeStrategy().setCategoricalFeaturesInfo(categoricalFeaturesInfo);
        GradientBoostedTreesModel model = GradientBoostedTrees.train(data, (BoostingStrategy)boostingStrategy);
        return model;
    }

    public static JavaPairRDD<Double, Double> predict(JavaRDD<LabeledPoint> data, final GradientBoostedTreesModel model) {
        JavaPairRDD predictionAndLabel = data.mapToPair((PairFunction)new PairFunction<LabeledPoint, Double, Double>(){

            public Tuple2<Double, Double> a(LabeledPoint p2) {
                return new Tuple2((Object)model.predict(p2.features()), (Object)p2.label());
            }

            public /* synthetic */ Tuple2 call(Object object) throws Exception {
                return this.a((LabeledPoint)object);
            }
        });
        return predictionAndLabel;
    }
}

