package com.datastax.insight.ml.spark.ml.classification;

import com.datastax.insight.annonation.InsightComponent;
import com.datastax.insight.annonation.InsightComponentArg;
import com.datastax.insight.spec.DataSetOperator;
import com.google.common.base.Strings;
import ml.dmlc.xgboost4j.scala.spark.XGBoostClassificationModel;
import ml.dmlc.xgboost4j.scala.spark.XGBoostClassifier;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;

/**
 * Static helpers that configure, train and apply an XGBoost4J-Spark
 * {@link XGBoostClassifier}.
 *
 * <p>Every configuration argument is optional at this level: an argument that
 * is {@code null} (or, for string arguments, empty) is simply not applied, so
 * the classifier keeps whatever value a freshly constructed
 * {@code XGBoostClassifier} already carries for that parameter.
 */
public class XgboostClassifier implements DataSetOperator {

    /**
     * Creates an {@link XGBoostClassifier} and applies each supplied setting to it.
     *
     * @param labelCol    label column name; skipped when {@code null} or empty
     * @param featuresCol features column name; skipped when {@code null} or empty
     * @param predictCol  prediction column name; skipped when {@code null} or empty
     * @param eta         value passed to {@code setEta}; skipped when {@code null}
     * @param numClass    value passed to {@code setNumClass}; skipped when {@code null}
     * @param maxDepth    value passed to {@code setMaxDepth}; skipped when {@code null}
     * @param objective   value passed to {@code setObjective}; skipped when {@code null} or empty
     * @param numRound    value passed to {@code setNumRound}; skipped when {@code null}
     * @param num_workers value passed to {@code setNumWorkers}; skipped when {@code null}
     * @param gamma       value passed to {@code setGamma}; skipped when {@code null}
     * @return the configured, untrained classifier
     */
    @InsightComponent(name = "XGBoostClassifier", description = "XGBoost分类器")
    public static XGBoostClassifier getOperator(
            @InsightComponentArg(name = "标签列", description = "标签列",request = true)String labelCol,
            @InsightComponentArg(name = "特征列", description = "特征列",request = true)String featuresCol,
            @InsightComponentArg(name = "预测列", description = "预测列",request = true)String predictCol,
            @InsightComponentArg(name = "步长", description = "步长",request = true)Double eta,
            @InsightComponentArg(name = "类别数量", description = "类别数量",request = true)Integer numClass,
            @InsightComponentArg(name = "树最大深度", description = "树最大深度",request = true)Integer maxDepth,
            @InsightComponentArg(name = "目标函数", description = "目标函数",request = true)String objective,
            @InsightComponentArg(name = "迭代次数", description = "迭代次数",request = true)Integer numRound,
            @InsightComponentArg(name = "线程数", description = "线程数",request = true)Integer num_workers,
            @InsightComponentArg(name = "gamma系数", description = "gamma系数",request = true)Double gamma){
        XGBoostClassifier classifier = new XGBoostClassifier();

        if(!Strings.isNullOrEmpty(labelCol)) {
            classifier.setLabelCol(labelCol);
        }

        if(!Strings.isNullOrEmpty(featuresCol)) {
            classifier.setFeaturesCol(featuresCol);
        }

        if(!Strings.isNullOrEmpty(predictCol)) {
            classifier.setPredictionCol(predictCol);
        }

        if(eta != null) {
            classifier.setEta(eta);
        }

        if(numClass != null) {
            classifier.setNumClass(numClass);
        }

        if(maxDepth != null) {
            classifier.setMaxDepth(maxDepth);
        }

        // Treat an empty objective like the other string arguments: a blank
        // value must not overwrite the classifier's existing objective with "".
        if(!Strings.isNullOrEmpty(objective)) {
            classifier.setObjective(objective);
        }

        if(numRound != null) {
            classifier.setNumRound(numRound);
        }

        if(num_workers != null) {
            classifier.setNumWorkers(num_workers);
        }

        if(gamma != null) {
            classifier.setGamma(gamma);
        }

        return classifier;
    }

    /**
     * Builds a classifier through
     * {@link #getOperator(String, String, String, Double, Integer, Integer, String, Integer, Integer, Double)}
     * and fits it on {@code data}.
     *
     * <p>Note that the argument order here ({@code numClass, maxDepth, eta})
     * differs from {@code getOperator} ({@code eta, numClass, maxDepth}); the
     * values are forwarded by name, not by position.
     *
     * @param data        training data
     * @param labelCol    label column name; skipped when {@code null} or empty
     * @param featuresCol features column name; skipped when {@code null} or empty
     * @param predictCol  prediction column name; skipped when {@code null} or empty
     * @param numClass    value passed to {@code setNumClass}; skipped when {@code null}
     * @param maxDepth    value passed to {@code setMaxDepth}; skipped when {@code null}
     * @param eta         value passed to {@code setEta}; skipped when {@code null}
     * @param objective   value passed to {@code setObjective}; skipped when {@code null} or empty
     * @param numRound    value passed to {@code setNumRound}; skipped when {@code null}
     * @param num_workers value passed to {@code setNumWorkers}; skipped when {@code null}
     * @param gamma       value passed to {@code setGamma}; skipped when {@code null}
     * @return the model produced by {@code XGBoostClassifier.fit(data)}
     */
    public static XGBoostClassificationModel fit(Dataset<Row> data,
                                                 String labelCol,
                                                 String featuresCol,
                                                 String predictCol,
                                                 Integer numClass,
                                                 Integer maxDepth,
                                                 Double eta,
                                                 String objective,
                                                 Integer numRound,
                                                 Integer num_workers,
                                                 Double gamma){
        XGBoostClassifier xgBoostClassifier = getOperator(labelCol, featuresCol, predictCol, eta,numClass, maxDepth,
                objective, numRound, num_workers, gamma);
        return xgBoostClassifier.fit(data);
    }

    /**
     * Fits an already configured classifier on {@code data}.
     *
     * @param classifier the classifier to train
     * @param data       training data
     * @return the model produced by {@code classifier.fit(data)}
     */
    public static XGBoostClassificationModel fit(XGBoostClassifier classifier,Dataset<Row> data){
        return classifier.fit(data);
    }

    /**
     * Applies a trained model to {@code data}.
     *
     * @param model the trained classification model
     * @param data  the dataset to transform
     * @return the result of {@code model.transform(data)}
     */
    public static Dataset<Row> transform(XGBoostClassificationModel model,Dataset<Row> data){
        return model.transform(data);
    }

}
