package com.datastax.insight.ml.spark.ml.regression;

import com.datastax.insight.spec.DataSetOperator;
import com.google.common.base.Strings;
import org.apache.spark.ml.regression.GeneralizedLinearRegression;
import org.apache.spark.ml.regression.GeneralizedLinearRegressionModel;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;

/**
 * Generalized linear regression.
 *
 * <p>Static helpers around Spark ML's {@link GeneralizedLinearRegression}: build a configured
 * estimator, fit it on a {@link Dataset}, and apply a fitted
 * {@link GeneralizedLinearRegressionModel} to a {@link Dataset}.
 */
public class GeneralizedLinearRegressionWrapper implements DataSetOperator {

    /**
     * Creates a new {@link GeneralizedLinearRegression} estimator and configures it from the
     * supplied settings.
     *
     * <p>Every argument is optional. For a {@code String} argument, {@code null} or {@code ""}
     * means the matching setter is not invoked. For a boxed numeric or boolean argument,
     * {@code null} means the same. In both cases that parameter keeps whatever value a freshly
     * constructed estimator has.
     *
     * @param labelCol          label column name, forwarded to {@code setLabelCol}
     * @param featuresCol       features column name, forwarded to {@code setFeaturesCol}
     * @param family            family name, forwarded to {@code setFamily}
     * @param link              link function name, forwarded to {@code setLink}
     * @param maxIterations     iteration cap, forwarded to {@code setMaxIter}
     * @param regParam          regularization parameter, forwarded to {@code setRegParam}
     * @param tol               convergence tolerance, forwarded to {@code setTol}
     * @param fitIntercept      intercept flag, forwarded to {@code setFitIntercept}
     * @param weightCol         weight column name, forwarded to {@code setWeightCol}
     * @param solver            solver name, forwarded to {@code setSolver}
     * @param linkPredictionCol link-prediction output column, forwarded to
     *                          {@code setLinkPredictionCol}
     * @return the newly created, configured (not yet fitted) estimator
     */
    public static GeneralizedLinearRegression getOperator(
            String labelCol,
            String featuresCol,
            String family,
            String link,
            Integer maxIterations,
            Double regParam,
            Double tol,
            Boolean fitIntercept,
            String weightCol,
            String solver,
            String linkPredictionCol) {
        final GeneralizedLinearRegression glr = new GeneralizedLinearRegression();

        // Setters are applied in a fixed order; each one runs only when its argument was supplied.
        if (isProvided(featuresCol)) {
            glr.setFeaturesCol(featuresCol);
        }
        if (isProvided(labelCol)) {
            glr.setLabelCol(labelCol);
        }
        if (isProvided(family)) {
            glr.setFamily(family);
        }
        if (isProvided(link)) {
            glr.setLink(link);
        }
        if (maxIterations != null) {
            glr.setMaxIter(maxIterations);
        }
        if (regParam != null) {
            glr.setRegParam(regParam);
        }
        if (tol != null) {
            glr.setTol(tol);
        }
        if (fitIntercept != null) {
            glr.setFitIntercept(fitIntercept);
        }
        if (isProvided(weightCol)) {
            glr.setWeightCol(weightCol);
        }
        if (isProvided(solver)) {
            glr.setSolver(solver);
        }
        if (isProvided(linkPredictionCol)) {
            glr.setLinkPredictionCol(linkPredictionCol);
        }

        return glr;
    }

    /**
     * Builds an estimator via {@link #getOperator} and fits it on {@code data}.
     *
     * @param data the training dataset
     * @return the fitted model
     * @see #getOperator for the meaning of the remaining (optional) arguments
     */
    public static GeneralizedLinearRegressionModel fit(
            Dataset<Row> data,
            String labelCol,
            String featuresCol,
            String family,
            String link,
            Integer maxIterations,
            Double regParam,
            Double tol,
            Boolean fitIntercept,
            String weightCol,
            String solver,
            String linkPredictionCol) {
        // Configure first, then delegate the actual training to the two-argument overload.
        final GeneralizedLinearRegression configured = getOperator(
                labelCol, featuresCol, family, link, maxIterations, regParam, tol,
                fitIntercept, weightCol, solver, linkPredictionCol);
        return fit(configured, data);
    }

    /**
     * Fits an already configured estimator on {@code data}.
     *
     * @param regressor the estimator to train
     * @param data      the training dataset
     * @return the fitted model returned by {@code regressor.fit(data)}
     */
    public static GeneralizedLinearRegressionModel fit(GeneralizedLinearRegression regressor, Dataset<Row> data) {
        return regressor.fit(data);
    }

    /**
     * Applies a fitted model to {@code data}.
     *
     * @param model a fitted generalized linear regression model
     * @param data  the dataset to transform
     * @return the dataset produced by {@code model.transform(data)}
     */
    public static Dataset<Row> transform(GeneralizedLinearRegressionModel model, Dataset<Row> data) {
        return model.transform(data);
    }

    /** Returns {@code true} when {@code value} is neither {@code null} nor the empty string. */
    private static boolean isProvided(String value) {
        return !Strings.isNullOrEmpty(value);
    }
}
