/*
 * Decompiled with CFR 0.152.
 */
package com.datastax.insight.ml.spark.ml.classification;

import com.datastax.insight.spec.DataSetOperator;
import com.google.common.base.Strings;
import org.apache.spark.ml.classification.GBTClassificationModel;
import org.apache.spark.ml.classification.GBTClassifier;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;

/**
 * Static helpers for configuring, fitting and applying a Spark ML
 * {@link GBTClassifier}.
 *
 * <p>Every tuning argument is optional: a {@code null} value (or, for string
 * arguments, a {@code null} or empty string) means the matching setter is not
 * called, so the freshly constructed classifier keeps whatever value it
 * already carries for that parameter.
 */
public class GradientBoostedTreeClassifier implements DataSetOperator {

    /**
     * Creates a new {@link GBTClassifier} and applies each supplied setting.
     *
     * @param labelCol            passed to {@code setLabelCol} when not null/empty
     * @param featuresCol         passed to {@code setFeaturesCol} when not null/empty
     * @param maxIterations       passed to {@code setMaxIter} when non-null
     * @param maxDepth            passed to {@code setMaxDepth} when non-null
     * @param maxBins             passed to {@code setMaxBins} when non-null
     * @param minInstancesPerNode passed to {@code setMinInstancesPerNode} when non-null
     * @param minInfoGain         passed to {@code setMinInfoGain} when non-null
     * @param maxMemoryInMB       passed to {@code setMaxMemoryInMB} when non-null
     * @param cacheNodeIds        passed to {@code setCacheNodeIds} when non-null
     * @param checkpointInterval  passed to {@code setCheckpointInterval} when non-null
     * @param impurity            passed to {@code setImpurity} when not null/empty
     * @param subsamplingRate     passed to {@code setSubsamplingRate} when non-null
     * @param seed                passed to {@code setSeed} when non-null
     * @param stepSize            passed to {@code setStepSize} when non-null
     * @param lossType            passed to {@code setLossType} when not null/empty
     * @return the configured, not yet fitted, classifier
     */
    public static GBTClassifier getOperator(String labelCol, String featuresCol, Integer maxIterations, Integer maxDepth, Integer maxBins, Integer minInstancesPerNode, Double minInfoGain, Integer maxMemoryInMB, Boolean cacheNodeIds, Integer checkpointInterval, String impurity, Double subsamplingRate, Long seed, Double stepSize, String lossType) {
        GBTClassifier classifier = new GBTClassifier();

        // Column names.
        if (!Strings.isNullOrEmpty(labelCol)) {
            classifier.setLabelCol(labelCol);
        }
        if (!Strings.isNullOrEmpty(featuresCol)) {
            classifier.setFeaturesCol(featuresCol);
        }

        // Numeric / boolean settings: boxed types let callers pass null to skip a setter.
        if (maxIterations != null) {
            classifier.setMaxIter(maxIterations);
        }
        if (maxDepth != null) {
            classifier.setMaxDepth(maxDepth);
        }
        if (maxBins != null) {
            classifier.setMaxBins(maxBins);
        }
        if (minInstancesPerNode != null) {
            classifier.setMinInstancesPerNode(minInstancesPerNode);
        }
        if (minInfoGain != null) {
            classifier.setMinInfoGain(minInfoGain);
        }
        if (maxMemoryInMB != null) {
            classifier.setMaxMemoryInMB(maxMemoryInMB);
        }
        if (cacheNodeIds != null) {
            classifier.setCacheNodeIds(cacheNodeIds);
        }
        if (checkpointInterval != null) {
            classifier.setCheckpointInterval(checkpointInterval);
        }
        if (!Strings.isNullOrEmpty(impurity)) {
            classifier.setImpurity(impurity);
        }
        // Applied exactly once; an earlier revision repeated this check-and-set verbatim.
        if (subsamplingRate != null) {
            classifier.setSubsamplingRate(subsamplingRate);
        }
        if (seed != null) {
            classifier.setSeed(seed);
        }
        if (stepSize != null) {
            classifier.setStepSize(stepSize);
        }
        if (!Strings.isNullOrEmpty(lossType)) {
            classifier.setLossType(lossType);
        }
        return classifier;
    }

    /**
     * Builds a classifier via
     * {@link #getOperator(String, String, Integer, Integer, Integer, Integer, Double, Integer, Boolean, Integer, String, Double, Long, Double, String)}
     * and fits it to {@code data}.
     *
     * <p>The tuning arguments have the same meaning, and the same
     * "null/empty means skip" handling, as in {@code getOperator}.
     *
     * @param data the dataset handed to the classifier's {@code fit} method
     * @return the model produced by fitting the configured classifier
     */
    public static GBTClassificationModel fit(Dataset<Row> data, String labelCol, String featuresCol, Integer maxIterations, Integer maxDepth, Integer maxBins, Integer minInstancesPerNode, Double minInfoGain, Integer maxMemoryInMB, Boolean cacheNodeIds, Integer checkpointInterval, String impurity, Double subsamplingRate, Long seed, Double stepSize, String lossType) {
        GBTClassifier classifier = GradientBoostedTreeClassifier.getOperator(labelCol, featuresCol, maxIterations, maxDepth, maxBins, minInstancesPerNode, minInfoGain, maxMemoryInMB, cacheNodeIds, checkpointInterval, impurity, subsamplingRate, seed, stepSize, lossType);
        return GradientBoostedTreeClassifier.fit(classifier, data);
    }

    /**
     * Fits an already configured classifier to {@code data}.
     *
     * @param classifier the classifier to fit
     * @param data       the dataset handed to the classifier's {@code fit} method
     * @return the resulting model
     */
    public static GBTClassificationModel fit(GBTClassifier classifier, Dataset<Row> data) {
        return (GBTClassificationModel)classifier.fit(data);
    }

    /**
     * Applies a fitted model to {@code data}.
     *
     * @param model the fitted model
     * @param data  the dataset passed to the model's {@code transform} method
     * @return the dataset returned by {@code model.transform(data)}
     */
    public static Dataset<Row> transform(GBTClassificationModel model, Dataset<Row> data) {
        return model.transform(data);
    }
}

