package com.datastax.insight.ml.spark.ml.regression;

import com.datastax.insight.spec.DataSetOperator;
import com.google.common.base.Strings;
import org.apache.spark.ml.regression.KNNRegression;
import org.apache.spark.ml.regression.KNNRegressionModel;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;

/**
 * Created by huangping on 20/04/2017.
 *
 * <p>{@link DataSetOperator} implementation around {@link KNNRegression}: it can
 * configure an estimator from optional arguments, fit it on a dataset, and run
 * a fitted {@link KNNRegressionModel} over a dataset.
 */
public class KNNRegressionWrapper implements DataSetOperator {

    /**
     * Creates a new {@link KNNRegression} and forwards each supplied argument to
     * the matching setter.
     *
     * <p>Every argument is optional. A {@code null} numeric argument, or a
     * {@code null}/empty string argument, is skipped, so the estimator keeps
     * whatever value it already holds for that setting.
     *
     * @param balanceThreshold forwarded to {@code setBalanceThreshold}; skipped when {@code null}
     * @param featuresCol      forwarded to {@code setFeaturesCol}; skipped when {@code null} or empty
     * @param k                forwarded to {@code setK}; skipped when {@code null}
     * @param labelCol         forwarded to {@code setLabelCol}; skipped when {@code null} or empty
     * @param seed             forwarded to {@code setSeed}; skipped when {@code null}
     * @param subTreeLeafSize  forwarded to {@code setSubTreeLeafSize}; skipped when {@code null}
     * @param topTreeLeafSize  forwarded to {@code setTopTreeLeafSize}; skipped when {@code null}
     * @param topTreeSize      forwarded to {@code setTopTreeSize}; skipped when {@code null}
     * @param weightCol        forwarded to {@code setWeightCol}; skipped when {@code null} or empty
     * @param predictionCol    forwarded to {@code setPredictionCol}; skipped when {@code null} or empty
     * @return the configured, not yet fitted, estimator
     */
    public KNNRegression getOperator(Double balanceThreshold,
                                     String featuresCol,
                                     Integer k,
                                     String labelCol,
                                     Long seed,
                                     Integer subTreeLeafSize,
                                     Integer topTreeLeafSize,
                                     Integer topTreeSize,
                                     String weightCol,
                                     String predictionCol) {
        final KNNRegression knn = new KNNRegression();

        // Setters are invoked in a fixed order; only supplied values are applied.
        if (topTreeSize != null) {
            knn.setTopTreeSize(topTreeSize);
        }
        if (balanceThreshold != null) {
            knn.setBalanceThreshold(balanceThreshold);
        }
        if (hasText(featuresCol)) {
            knn.setFeaturesCol(featuresCol);
        }
        if (k != null) {
            knn.setK(k);
        }
        if (hasText(labelCol)) {
            knn.setLabelCol(labelCol);
        }
        if (seed != null) {
            knn.setSeed(seed);
        }
        if (subTreeLeafSize != null) {
            knn.setSubTreeLeafSize(subTreeLeafSize);
        }
        if (topTreeLeafSize != null) {
            knn.setTopTreeLeafSize(topTreeLeafSize);
        }
        if (hasText(weightCol)) {
            knn.setWeightCol(weightCol);
        }
        if (hasText(predictionCol)) {
            knn.setPredictionCol(predictionCol);
        }

        return knn;
    }

    /**
     * Configures a fresh estimator via
     * {@link #getOperator(Double, String, Integer, String, Long, Integer, Integer, Integer, String, String)}
     * and fits it on {@code data}.
     *
     * <p>The optional arguments follow the same rules as in {@code getOperator}.
     *
     * @param data dataset handed to {@code KNNRegression.fit}
     * @return the model produced by fitting the configured estimator
     */
    public KNNRegressionModel fit(Dataset<Row> data,
                                  Double balanceThreshold,
                                  String featuresCol,
                                  Integer k,
                                  String labelCol,
                                  Long seed,
                                  Integer subTreeLeafSize,
                                  Integer topTreeLeafSize,
                                  Integer topTreeSize,
                                  String weightCol,
                                  String predictionCol) {
        return getOperator(
                balanceThreshold,
                featuresCol,
                k,
                labelCol,
                seed,
                subTreeLeafSize,
                topTreeLeafSize,
                topTreeSize,
                weightCol,
                predictionCol)
                .fit(data);
    }

    /**
     * Fits an already configured estimator on {@code data}.
     *
     * @param regression estimator to fit
     * @param data       dataset handed to {@code regression.fit}
     * @return the fitted model
     */
    public KNNRegressionModel fit(KNNRegression regression, Dataset<Row> data) {
        final KNNRegressionModel fitted = regression.fit(data);
        return fitted;
    }

    /**
     * Runs a fitted model over {@code data}.
     *
     * @param model fitted model whose {@code transform} is invoked
     * @param data  dataset passed to {@code model.transform}
     * @return the dataset returned by {@code model.transform}
     */
    public Dataset<Row> transform(KNNRegressionModel model, Dataset<Row> data) {
        final Dataset<Row> transformed = model.transform(data);
        return transformed;
    }

    /** Returns {@code true} when {@code value} is neither {@code null} nor the empty string. */
    private static boolean hasText(String value) {
        return !Strings.isNullOrEmpty(value);
    }
}
