/*
 * Decompiled with CFR 0.152.
 */
package com.datastax.insight.ml.spark.ml.regression;

import com.datastax.insight.spec.DataSetOperator;
import com.google.common.base.Strings;
import org.apache.spark.ml.regression.RandomForestRegressionModel;
import org.apache.spark.ml.regression.RandomForestRegressor;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;

/**
 * Static convenience entry points around Spark ML's {@link RandomForestRegressor}:
 * building a configured estimator, fitting it to a dataset, and applying a
 * fitted {@link RandomForestRegressionModel} to a dataset.
 */
public class RandomForestRegressionWrapper implements DataSetOperator {

    /**
     * Creates a new {@link RandomForestRegressor} and applies every argument that was supplied.
     *
     * <p>An argument counts as "not supplied" when it is {@code null} (or, for the
     * {@code String} arguments, {@code null} or empty). For such arguments the matching
     * setter is never invoked, so the regressor keeps whatever value it had right after
     * construction.
     *
     * @param labelCol              forwarded to {@code setLabelCol} when non-empty
     * @param featuresCol           forwarded to {@code setFeaturesCol} when non-empty
     * @param maxDepth              forwarded to {@code setMaxDepth} when non-null
     * @param maxBins               forwarded to {@code setMaxBins} when non-null
     * @param minInstancesPerNode   forwarded to {@code setMinInstancesPerNode} when non-null
     * @param minInfoGain           forwarded to {@code setMinInfoGain} when non-null
     * @param maxMemoryInMB         forwarded to {@code setMaxMemoryInMB} when non-null
     * @param cacheNodeIds          forwarded to {@code setCacheNodeIds} when non-null
     * @param checkpointInterval    forwarded to {@code setCheckpointInterval} when non-null
     * @param impurity              forwarded to {@code setImpurity} when non-empty
     * @param subsamplingRate       forwarded to {@code setSubsamplingRate} when non-null
     * @param numTrees              forwarded to {@code setNumTrees} when non-null
     * @param featureSubsetStrategy forwarded to {@code setFeatureSubsetStrategy} when non-empty
     * @return the newly created, configured regressor (not yet fitted)
     */
    public static RandomForestRegressor getOperator(String labelCol, String featuresCol, Integer maxDepth, Integer maxBins, Integer minInstancesPerNode, Double minInfoGain, Integer maxMemoryInMB, Boolean cacheNodeIds, Integer checkpointInterval, String impurity, Double subsamplingRate, Integer numTrees, String featureSubsetStrategy) {
        RandomForestRegressor forest = new RandomForestRegressor();

        // Column names.
        if (hasText(labelCol)) {
            forest.setLabelCol(labelCol);
        }
        if (hasText(featuresCol)) {
            forest.setFeaturesCol(featuresCol);
        }

        // Boxed values are auto-unboxed at the call site; the null guards keep that safe.
        if (maxDepth != null) {
            forest.setMaxDepth(maxDepth);
        }
        if (maxBins != null) {
            forest.setMaxBins(maxBins);
        }
        if (minInstancesPerNode != null) {
            forest.setMinInstancesPerNode(minInstancesPerNode);
        }
        if (minInfoGain != null) {
            forest.setMinInfoGain(minInfoGain);
        }
        if (maxMemoryInMB != null) {
            forest.setMaxMemoryInMB(maxMemoryInMB);
        }
        if (cacheNodeIds != null) {
            forest.setCacheNodeIds(cacheNodeIds);
        }
        if (checkpointInterval != null) {
            forest.setCheckpointInterval(checkpointInterval);
        }
        if (hasText(impurity)) {
            forest.setImpurity(impurity);
        }
        if (subsamplingRate != null) {
            forest.setSubsamplingRate(subsamplingRate);
        }
        if (numTrees != null) {
            forest.setNumTrees(numTrees);
        }
        if (hasText(featureSubsetStrategy)) {
            forest.setFeatureSubsetStrategy(featureSubsetStrategy);
        }
        return forest;
    }

    /**
     * Builds a regressor via
     * {@link #getOperator(String, String, Integer, Integer, Integer, Double, Integer, Boolean, Integer, String, Double, Integer, String)}
     * and fits it to {@code data}.
     *
     * @param data the dataset handed to the regressor's {@code fit} call
     * @return the model produced by fitting the configured regressor
     */
    public static RandomForestRegressionModel fit(Dataset<Row> data, String labelCol, String featuresCol, Integer maxDepth, Integer maxBins, Integer minInstancesPerNode, Double minInfoGain, Integer maxMemoryInMB, Boolean cacheNodeIds, Integer checkpointInterval, String impurity, Double subsamplingRate, Integer numTrees, String featureSubsetStrategy) {
        RandomForestRegressor configured = getOperator(
                labelCol,
                featuresCol,
                maxDepth,
                maxBins,
                minInstancesPerNode,
                minInfoGain,
                maxMemoryInMB,
                cacheNodeIds,
                checkpointInterval,
                impurity,
                subsamplingRate,
                numTrees,
                featureSubsetStrategy);
        return fit(configured, data);
    }

    /**
     * Fits an already configured regressor to {@code data}.
     *
     * @param regressor the estimator to fit
     * @param data      the dataset handed to {@code regressor.fit}
     * @return the model returned by {@code regressor.fit(data)}
     */
    public static RandomForestRegressionModel fit(RandomForestRegressor regressor, Dataset<Row> data) {
        RandomForestRegressionModel fitted = (RandomForestRegressionModel) regressor.fit(data);
        return fitted;
    }

    /**
     * Applies a fitted model to {@code data}.
     *
     * @param model the fitted model
     * @param data  the dataset handed to {@code model.transform}
     * @return the dataset returned by {@code model.transform(data)}
     */
    public static Dataset<Row> transform(RandomForestRegressionModel model, Dataset<Row> data) {
        Dataset<Row> transformed = model.transform(data);
        return transformed;
    }

    /** Returns {@code true} when {@code value} is neither {@code null} nor the empty string. */
    private static boolean hasText(String value) {
        return !Strings.isNullOrEmpty(value);
    }
}

