package com.datastax.insight.ml.spark.ml.classification;

import com.datastax.insight.annonation.InsightComponent;
import com.datastax.insight.annonation.InsightComponentArg;
import com.datastax.insight.spec.DataSetOperator;
import com.google.common.base.Strings;
import ml.dmlc.xgboost4j.scala.spark.XGBoostClassificationModel;
import ml.dmlc.xgboost4j.scala.spark.XGBoostClassifier;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;

/* loaded from: input_file:com/datastax/insight/ml/spark/ml/classification/XgboostClassifier.class */
/**
 * Insight component wrapper around the XGBoost4J-Spark {@link XGBoostClassifier}.
 *
 * <p>Provides static helpers to build a configured classifier from component arguments, to fit it
 * on a {@link Dataset}, and to apply a fitted {@link XGBoostClassificationModel} to new data.
 */
public class XgboostClassifier implements DataSetOperator {

    /**
     * Builds an {@link XGBoostClassifier} from the component arguments.
     *
     * <p>A {@code null} argument (or, for string arguments, a {@code null} or empty one) is not set
     * on the classifier, so the classifier's own default for that parameter stays in effect.
     *
     * @param str  label column name; forwarded to {@code setLabelCol}
     * @param str2 features column name; forwarded to {@code setFeaturesCol}
     * @param str3 prediction column name; forwarded to {@code setPredictionCol}
     * @param d    step size; forwarded to {@code setEta}
     * @param num  number of classes; forwarded to {@code setNumClass}
     * @param num2 maximum tree depth; forwarded to {@code setMaxDepth}
     * @param str4 objective function name; forwarded to {@code setObjective}
     * @param num3 number of boosting iterations; forwarded to {@code setNumRound}
     * @param num4 labelled "thread count" in the UI; forwarded to {@code setNumWorkers}
     * @param d2   gamma coefficient; forwarded to {@code setGamma}
     * @return a new, configured but unfitted classifier
     */
    @InsightComponent(name = "XGBoostClassifier", description = "XGBoost分类器")
    public static XGBoostClassifier getOperator(
            @InsightComponentArg(name = "标签列", description = "标签列", request = true) String str,
            @InsightComponentArg(name = "特征列", description = "特征列", request = true) String str2,
            @InsightComponentArg(name = "预测列", description = "预测列", request = true) String str3,
            @InsightComponentArg(name = "步长", description = "步长", request = true) Double d,
            @InsightComponentArg(name = "类别数量", description = "类别数量", request = true) Integer num,
            @InsightComponentArg(name = "树最大深度", description = "树最大深度", request = true) Integer num2,
            @InsightComponentArg(name = "目标函数", description = "目标函数", request = true) String str4,
            @InsightComponentArg(name = "迭代次数", description = "迭代次数", request = true) Integer num3,
            @InsightComponentArg(name = "线程数", description = "线程数", request = true) Integer num4,
            @InsightComponentArg(name = "gamma系数", description = "gamma系数", request = true) Double d2) {
        XGBoostClassifier xGBoostClassifier = new XGBoostClassifier();
        if (!Strings.isNullOrEmpty(str)) {
            xGBoostClassifier.setLabelCol(str);
        }
        if (!Strings.isNullOrEmpty(str2)) {
            xGBoostClassifier.setFeaturesCol(str2);
        }
        if (!Strings.isNullOrEmpty(str3)) {
            xGBoostClassifier.setPredictionCol(str3);
        }
        if (d != null) {
            xGBoostClassifier.setEta(d.doubleValue());
        }
        if (num != null) {
            xGBoostClassifier.setNumClass(num.intValue());
        }
        if (num2 != null) {
            xGBoostClassifier.setMaxDepth(num2.intValue());
        }
        // Same null-or-empty guard as the other string arguments. Previously only null was
        // rejected, so an empty value from the UI replaced the default with an empty objective.
        if (!Strings.isNullOrEmpty(str4)) {
            xGBoostClassifier.setObjective(str4);
        }
        if (num3 != null) {
            xGBoostClassifier.setNumRound(num3.intValue());
        }
        // NOTE(review): the UI label for this argument means "thread count", but the value sets
        // the number of XGBoost workers rather than threads per worker. Confirm which is intended.
        if (num4 != null) {
            xGBoostClassifier.setNumWorkers(num4.intValue());
        }
        if (d2 != null) {
            xGBoostClassifier.setGamma(d2.doubleValue());
        }
        return xGBoostClassifier;
    }

    /**
     * Builds a classifier via {@link #getOperator} and fits it on {@code dataset}.
     *
     * <p>The argument order differs from {@code getOperator}: here the number of classes
     * ({@code num}) and maximum depth ({@code num2}) come <em>before</em> the step size
     * ({@code d}). Each value is still routed to the same setter as in {@code getOperator}.
     *
     * @param dataset training data
     * @param str     label column name
     * @param str2    features column name
     * @param str3    prediction column name
     * @param num     number of classes
     * @param num2    maximum tree depth
     * @param d       step size (eta)
     * @param str4    objective function name
     * @param num3    number of boosting iterations
     * @param num4    value forwarded to {@code setNumWorkers}
     * @param d2      gamma coefficient
     * @return the fitted model
     */
    public static XGBoostClassificationModel fit(Dataset<Row> dataset, String str, String str2, String str3, Integer num, Integer num2, Double d, String str4, Integer num3, Integer num4, Double d2) {
        return getOperator(str, str2, str3, d, num, num2, str4, num3, num4, d2).fit(dataset);
    }

    /**
     * Fits an already-configured classifier on {@code dataset}.
     *
     * @param xGBoostClassifier the classifier to fit
     * @param dataset           training data
     * @return the fitted model
     */
    public static XGBoostClassificationModel fit(XGBoostClassifier xGBoostClassifier, Dataset<Row> dataset) {
        return xGBoostClassifier.fit(dataset);
    }

    /**
     * Applies a fitted model to {@code dataset}.
     *
     * @param xGBoostClassificationModel the fitted model
     * @param dataset                    data to score
     * @return the result of {@code transform} on {@code dataset}
     */
    public static Dataset<Row> transform(XGBoostClassificationModel xGBoostClassificationModel, Dataset<Row> dataset) {
        return xGBoostClassificationModel.transform(dataset);
    }
}
