package com.datastax.insight.ml.spark.ml.regression;

import com.datastax.insight.spec.DataSetOperator;
import com.google.common.base.Strings;
import org.apache.spark.ml.regression.LinearRegression;
import org.apache.spark.ml.regression.LinearRegressionModel;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;

/**
 * Linear regression.
 *
 * <p>Thin operator wrapper around Spark ML's {@link LinearRegression}. It offers static helpers
 * that build a configured estimator and fit it on a {@link Dataset}, plus an instance method that
 * applies an already-fitted {@link LinearRegressionModel} to a {@link Dataset}.
 */
public class LinearRegressionWrapper implements DataSetOperator {

    /**
     * Creates a new {@link LinearRegression} estimator and applies the supplied settings to it.
     *
     * <p>Each argument is optional. A {@code null} number or boolean, or a {@code null}/empty
     * string, means the corresponding setter is never called, so that parameter keeps whatever
     * value a freshly constructed {@link LinearRegression} has.
     *
     * @param maxIterations   passed to {@code setMaxIter} when non-null
     * @param regParam        passed to {@code setRegParam} when non-null
     * @param elasticNetParam passed to {@code setElasticNetParam} when non-null
     * @param tol             passed to {@code setTol} when non-null
     * @param fitIntercept    passed to {@code setFitIntercept} when non-null
     * @param standardization passed to {@code setStandardization} when non-null
     * @param weightCol       passed to {@code setWeightCol} when non-null and non-empty
     * @param solver          passed to {@code setSolver} when non-null and non-empty
     * @param featuresCol     passed to {@code setFeaturesCol} when non-null and non-empty
     * @param labelCol        passed to {@code setLabelCol} when non-null and non-empty
     * @param predictionCol   passed to {@code setPredictionCol} when non-null and non-empty
     * @return the configured, not-yet-fitted estimator
     */
    public static LinearRegression getOperator(Integer maxIterations,
                                               Double regParam,
                                               Double elasticNetParam,
                                               Double tol,
                                               Boolean fitIntercept,
                                               Boolean standardization,
                                               String weightCol,
                                               String solver,
                                               String featuresCol,
                                               String labelCol,
                                               String predictionCol) {
        LinearRegression estimator = new LinearRegression();

        // Optimisation / regularisation settings. The setters are applied in a fixed order on
        // purpose: any value validation happens at set time, so the order determines which
        // invalid argument is reported first.
        if (maxIterations != null) {
            estimator.setMaxIter(maxIterations);
        }
        if (regParam != null) {
            estimator.setRegParam(regParam);
        }
        if (elasticNetParam != null) {
            estimator.setElasticNetParam(elasticNetParam);
        }
        if (tol != null) {
            estimator.setTol(tol);
        }
        if (fitIntercept != null) {
            estimator.setFitIntercept(fitIntercept);
        }
        if (standardization != null) {
            estimator.setStandardization(standardization);
        }

        // String settings: an empty string is treated the same as "not provided".
        if (!Strings.isNullOrEmpty(weightCol)) {
            estimator.setWeightCol(weightCol);
        }
        if (!Strings.isNullOrEmpty(solver)) {
            estimator.setSolver(solver);
        }
        if (!Strings.isNullOrEmpty(featuresCol)) {
            estimator.setFeaturesCol(featuresCol);
        }
        if (!Strings.isNullOrEmpty(labelCol)) {
            estimator.setLabelCol(labelCol);
        }
        if (!Strings.isNullOrEmpty(predictionCol)) {
            estimator.setPredictionCol(predictionCol);
        }

        return estimator;
    }

    /**
     * Builds an estimator with {@link #getOperator} and fits it on {@code data}.
     *
     * <p>All configuration arguments have the same meaning, and the same "skip when
     * null/empty" rule, as in {@link #getOperator}.
     *
     * @param data            training data handed to {@link LinearRegression#fit}
     * @param maxIterations   see {@link #getOperator}
     * @param regParam        see {@link #getOperator}
     * @param elasticNetParam see {@link #getOperator}
     * @param tol             see {@link #getOperator}
     * @param fitIntercept    see {@link #getOperator}
     * @param standardization see {@link #getOperator}
     * @param weightCol       see {@link #getOperator}
     * @param solver          see {@link #getOperator}
     * @param featuresCol     see {@link #getOperator}
     * @param labelCol        see {@link #getOperator}
     * @param predictionCol   see {@link #getOperator}
     * @return the model produced by fitting the configured estimator
     */
    public static LinearRegressionModel fit(Dataset<Row> data,
                                            Integer maxIterations,
                                            Double regParam,
                                            Double elasticNetParam,
                                            Double tol,
                                            Boolean fitIntercept,
                                            Boolean standardization,
                                            String weightCol,
                                            String solver,
                                            String featuresCol,
                                            String labelCol,
                                            String predictionCol) {
        LinearRegression configured = getOperator(maxIterations, regParam, elasticNetParam, tol,
                fitIntercept, standardization, weightCol, solver, featuresCol, labelCol,
                predictionCol);
        // Delegate to the two-argument overload so both entry points share one fitting path.
        return fit(configured, data);
    }

    /**
     * Fits an already-configured estimator on {@code data}.
     *
     * @param regressor the estimator to fit
     * @param data      training data
     * @return the result of {@code regressor.fit(data)}
     */
    public static LinearRegressionModel fit(LinearRegression regressor, Dataset<Row> data) {
        return regressor.fit(data);
    }

    /**
     * Applies a fitted linear-regression model to {@code data}.
     *
     * @param model a model previously produced by one of the {@code fit} methods
     * @param data  the rows to score
     * @return the dataset returned by {@code model.transform(data)}
     */
    public Dataset<Row> transform(LinearRegressionModel model, Dataset<Row> data) {
        Dataset<Row> scored = model.transform(data);
        return scored;
    }
}
