package com.datastax.insight.ml.spark.ml.classification;

import com.datastax.insight.spec.DataSetOperator;
import com.google.common.base.Strings;
import org.apache.spark.ml.classification.RandomForestClassificationModel;
import org.apache.spark.ml.classification.RandomForestClassifier;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;

/**
 * Static helpers that build, fit and apply a Spark ML {@link RandomForestClassifier}.
 *
 * <p>Every configuration argument of {@link #getOperator} and of the long {@code fit} overload is
 * optional. A {@code null} value, or a {@code null}/empty string, means the matching setter is
 * never called. The freshly created classifier therefore keeps whatever value it already had for
 * that parameter.
 */
public class RandomForestClassifierWrapper implements DataSetOperator {

    /**
     * Creates a new {@link RandomForestClassifier} and applies each supplied setting to it.
     *
     * @param str  label column name, passed to {@code setLabelCol}; skipped when null or empty
     * @param str2 features column name, passed to {@code setFeaturesCol}; skipped when null or empty
     * @param num  passed to {@code setMaxDepth}; skipped when null
     * @param num2 passed to {@code setMaxBins}; skipped when null
     * @param num3 passed to {@code setMinInstancesPerNode}; skipped when null
     * @param d    passed to {@code setMinInfoGain}; skipped when null
     * @param num4 passed to {@code setMaxMemoryInMB}; skipped when null
     * @param bool passed to {@code setCacheNodeIds}; skipped when null
     * @param num5 passed to {@code setCheckpointInterval}; skipped when null
     * @param str3 passed to {@code setImpurity}; skipped when null or empty
     * @param d2   passed to {@code setSubsamplingRate}; skipped when null
     * @param num6 passed to {@code setNumTrees}; skipped when null
     * @param str4 passed to {@code setFeatureSubsetStrategy}; skipped when null or empty
     * @return the configured, not-yet-fitted classifier
     */
    public static RandomForestClassifier getOperator(String str, String str2, Integer num, Integer num2, Integer num3, Double d, Integer num4, Boolean bool, Integer num5, String str3, Double d2, Integer num6, String str4) {
        final RandomForestClassifier classifier = new RandomForestClassifier();

        // Input/label column bindings.
        if (isSet(str)) {
            classifier.setLabelCol(str);
        }
        if (isSet(str2)) {
            classifier.setFeaturesCol(str2);
        }

        // Per-tree growth limits.
        if (num != null) {
            classifier.setMaxDepth(num);
        }
        if (num2 != null) {
            classifier.setMaxBins(num2);
        }
        if (num3 != null) {
            classifier.setMinInstancesPerNode(num3);
        }
        if (d != null) {
            classifier.setMinInfoGain(d);
        }

        // Training resource / caching knobs.
        if (num4 != null) {
            classifier.setMaxMemoryInMB(num4);
        }
        if (bool != null) {
            classifier.setCacheNodeIds(bool);
        }
        if (num5 != null) {
            classifier.setCheckpointInterval(num5);
        }

        // Forest-level settings; applied in the same order as before.
        if (isSet(str3)) {
            classifier.setImpurity(str3);
        }
        if (isSet(str4)) {
            classifier.setFeatureSubsetStrategy(str4);
        }
        if (d2 != null) {
            classifier.setSubsamplingRate(d2);
        }
        if (num6 != null) {
            classifier.setNumTrees(num6);
        }

        return classifier;
    }

    /**
     * Builds a classifier via {@link #getOperator} and fits it on {@code dataset}.
     *
     * <p>All arguments after {@code dataset} have exactly the meaning documented on
     * {@link #getOperator}, in the same order.
     *
     * @param dataset training data
     * @return the model produced by fitting the configured classifier
     */
    public static RandomForestClassificationModel fit(Dataset<Row> dataset, String str, String str2, Integer num, Integer num2, Integer num3, Double d, Integer num4, Boolean bool, Integer num5, String str3, Double d2, Integer num6, String str4) {
        final RandomForestClassifier configured =
                getOperator(str, str2, num, num2, num3, d, num4, bool, num5, str3, d2, num6, str4);
        return fit(configured, dataset);
    }

    /**
     * Fits an already configured classifier on {@code dataset}.
     *
     * @param randomForestClassifier the estimator to train
     * @param dataset                training data
     * @return the fitted model returned by {@code randomForestClassifier.fit(dataset)}
     */
    public static RandomForestClassificationModel fit(RandomForestClassifier randomForestClassifier, Dataset<Row> dataset) {
        return randomForestClassifier.fit(dataset);
    }

    /**
     * Applies a fitted model to {@code dataset}.
     *
     * @param randomForestClassificationModel the fitted model
     * @param dataset                         rows to score
     * @return the result of {@code randomForestClassificationModel.transform(dataset)}
     */
    public static Dataset<Row> transform(RandomForestClassificationModel randomForestClassificationModel, Dataset<Row> dataset) {
        return randomForestClassificationModel.transform(dataset);
    }

    /** Returns {@code true} when {@code value} is neither {@code null} nor the empty string. */
    private static boolean isSet(String value) {
        return !Strings.isNullOrEmpty(value);
    }
}
