package com.datastax.insight.ml.spark.ml.classification;

import com.datastax.insight.spec.DataSetOperator;
import com.google.common.base.Strings;
import org.apache.spark.ml.classification.LogisticRegression;
import org.apache.spark.ml.classification.LogisticRegressionModel;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;

/**
 * Static helpers for configuring, training and applying a Spark ML
 * {@link LogisticRegression} classifier.
 */
public class LogisticRegressionWrapper implements DataSetOperator {

    /**
     * Creates a {@link LogisticRegression} estimator and applies every supplied setting.
     *
     * <p>Each argument is optional: a {@code null} value (or, for the column names, an
     * empty string) means the matching setter is not called, so the estimator keeps
     * whatever value a freshly constructed {@code LogisticRegression} carries.
     *
     * @param featuresCol     forwarded to {@code setFeaturesCol} when non-empty
     * @param labelCol        forwarded to {@code setLabelCol} when non-empty
     * @param maxIterations   forwarded to {@code setMaxIter} when non-null
     * @param regParam        forwarded to {@code setRegParam} when non-null
     * @param elasticNetParam forwarded to {@code setElasticNetParam} when non-null
     * @param threshold       forwarded to {@code setThreshold} when non-null
     * @param tol             forwarded to {@code setTol} when non-null
     * @param fitIntercept    forwarded to {@code setFitIntercept} when non-null
     * @param standardization forwarded to {@code setStandardization} when non-null
     * @param weightCol       forwarded to {@code setWeightCol} when non-empty
     * @return the configured, not yet fitted, estimator
     */
    public static LogisticRegression getOperator(String featuresCol,
                                                 String labelCol,
                                                 Integer maxIterations,
                                                 Double regParam,
                                                 Double elasticNetParam,
                                                 Double threshold,
                                                 Double tol,
                                                 Boolean fitIntercept,
                                                 Boolean standardization,
                                                 String weightCol) {
        final LogisticRegression lr = new LogisticRegression();

        // Column names: only applied when a non-empty name was given.
        if (!Strings.isNullOrEmpty(featuresCol)) {
            lr.setFeaturesCol(featuresCol);
        }
        if (!Strings.isNullOrEmpty(labelCol)) {
            lr.setLabelCol(labelCol);
        }
        if (!Strings.isNullOrEmpty(weightCol)) {
            lr.setWeightCol(weightCol);
        }

        // Numeric settings: only applied when a value was given.
        if (maxIterations != null) {
            lr.setMaxIter(maxIterations);
        }
        if (regParam != null) {
            lr.setRegParam(regParam);
        }
        if (elasticNetParam != null) {
            lr.setElasticNetParam(elasticNetParam);
        }
        if (threshold != null) {
            lr.setThreshold(threshold);
        }
        if (tol != null) {
            lr.setTol(tol);
        }

        // Boolean switches: only applied when a value was given.
        if (fitIntercept != null) {
            lr.setFitIntercept(fitIntercept);
        }
        if (standardization != null) {
            lr.setStandardization(standardization);
        }

        return lr;
    }

    /**
     * Configures an estimator through
     * {@link #getOperator(String, String, Integer, Double, Double, Double, Double, Boolean, Boolean, String)}
     * and fits it to {@code data} in a single call.
     *
     * @param data the training data handed to {@code LogisticRegression.fit}
     * @return the fitted model
     */
    public static LogisticRegressionModel fit(Dataset<Row> data,
                                              String featuresCol,
                                              String labelCol,
                                              Integer maxIterations,
                                              Double regParam,
                                              Double elasticNetParam,
                                              Double threshold,
                                              Double tol,
                                              Boolean fitIntercept,
                                              Boolean standardization,
                                              String weightCol) {
        final LogisticRegression configured = getOperator(
                featuresCol,
                labelCol,
                maxIterations,
                regParam,
                elasticNetParam,
                threshold,
                tol,
                fitIntercept,
                standardization,
                weightCol);
        return fit(configured, data);
    }

    /**
     * Trains a logistic regression model by fitting the given estimator to {@code data}.
     *
     * @param regression the already configured estimator
     * @param data       the training data
     * @return the fitted model
     */
    public static LogisticRegressionModel fit(LogisticRegression regression, Dataset<Row> data) {
        final LogisticRegressionModel fitted = regression.fit(data);
        return fitted;
    }

    /**
     * Runs prediction by applying a fitted model to {@code data}.
     *
     * @param model the fitted logistic regression model
     * @param data  the data set to transform
     * @return the result of {@code model.transform(data)}
     */
    public static Dataset<Row> transform(LogisticRegressionModel model, Dataset<Row> data) {
        final Dataset<Row> predictions = model.transform(data);
        return predictions;
    }
}
