package com.datastax.insight.ml.spark.ml.classification;

import com.datastax.insight.spec.DataSetOperator;
import com.google.common.base.Strings;
import org.apache.spark.ml.classification.KNNClassificationModel;
import org.apache.spark.ml.classification.KNNClassifier;
import org.apache.spark.sql.Dataset;

/**
 * Adapter exposing the spark-knn {@link KNNClassifier} as a {@link DataSetOperator}.
 *
 * <p>Every tuning argument is optional: a {@code null} numeric argument, or a {@code null}/empty
 * column-name argument, is simply not applied, so that parameter keeps whatever value a freshly
 * constructed {@link KNNClassifier} carries.
 */
public class KNNClassifierWrapper implements DataSetOperator {

    /**
     * Builds a new, unfitted {@link KNNClassifier} configured from the supplied arguments.
     *
     * @param balanceThreshold value for {@code setBalanceThreshold}; skipped when {@code null}
     * @param featuresCol      value for {@code setFeaturesCol}; skipped when {@code null} or empty
     * @param labelCol         value for {@code setLabelCol}; skipped when {@code null} or empty
     * @param k                value for {@code setK}; skipped when {@code null}
     * @param seed             value for {@code setSeed}; skipped when {@code null}
     * @param subTreeLeafSize  value for {@code setSubTreeLeafSize}; skipped when {@code null}
     * @param topTreeLeafSize  value for {@code setTopTreeLeafSize}; skipped when {@code null}
     * @param topTreeSize      value for {@code setTopTreeSize}; skipped when {@code null}
     * @param weightCol        value for {@code setWeightCol}; skipped when {@code null} or empty
     * @param predictionCol    value for {@code setPredictionCol}; skipped when {@code null} or empty
     * @param probabilityCol   value for {@code setProbabilityCol}; skipped when {@code null} or empty
     * @return the configured classifier (not yet fitted)
     */
    public KNNClassifier getOperator(Double balanceThreshold,
                                     String featuresCol,
                                     String labelCol,
                                     Integer k,
                                     Long seed,
                                     Integer subTreeLeafSize,
                                     Integer topTreeLeafSize,
                                     Integer topTreeSize,
                                     String weightCol,
                                     String predictionCol,
                                     String probabilityCol) {
        KNNClassifier classifier = new KNNClassifier();

        // Each setter writes an independent param, so applying them in argument order is safe.
        if (balanceThreshold != null) {
            classifier.setBalanceThreshold(balanceThreshold);
        }
        if (!Strings.isNullOrEmpty(featuresCol)) {
            classifier.setFeaturesCol(featuresCol);
        }
        if (!Strings.isNullOrEmpty(labelCol)) {
            classifier.setLabelCol(labelCol);
        }
        if (k != null) {
            classifier.setK(k);
        }
        if (seed != null) {
            classifier.setSeed(seed);
        }
        if (subTreeLeafSize != null) {
            classifier.setSubTreeLeafSize(subTreeLeafSize);
        }
        if (topTreeLeafSize != null) {
            classifier.setTopTreeLeafSize(topTreeLeafSize);
        }
        if (topTreeSize != null) {
            classifier.setTopTreeSize(topTreeSize);
        }
        if (!Strings.isNullOrEmpty(weightCol)) {
            classifier.setWeightCol(weightCol);
        }
        if (!Strings.isNullOrEmpty(predictionCol)) {
            classifier.setPredictionCol(predictionCol);
        }
        if (!Strings.isNullOrEmpty(probabilityCol)) {
            classifier.setProbabilityCol(probabilityCol);
        }

        return classifier;
    }

    /**
     * Configures a classifier via {@link #getOperator} and fits it on {@code data}.
     *
     * @param data training data handed to {@code KNNClassifier.fit}
     * @return the fitted model
     * @see #getOperator for the meaning of the remaining arguments
     */
    public KNNClassificationModel fit(Dataset<?> data,
                                      Double balanceThreshold,
                                      String featuresCol,
                                      String labelCol,
                                      Integer k,
                                      Long seed,
                                      Integer subTreeLeafSize,
                                      Integer topTreeLeafSize,
                                      Integer topTreeSize,
                                      String weightCol,
                                      String predictionCol,
                                      String probabilityCol) {
        KNNClassifier classifier = getOperator(balanceThreshold,
                featuresCol,
                labelCol,
                k,
                seed,
                subTreeLeafSize,
                topTreeLeafSize,
                topTreeSize,
                weightCol,
                predictionCol,
                probabilityCol);

        // Delegate so there is a single place where fitting happens.
        return fit(classifier, data);
    }

    /**
     * Fits an already-configured classifier on {@code data}.
     *
     * @param classifier the classifier to fit
     * @param data       training data
     * @return the fitted model
     */
    public KNNClassificationModel fit(KNNClassifier classifier, Dataset<?> data) {
        return classifier.fit(data);
    }

    /**
     * Applies a fitted model to {@code data}.
     *
     * <p>The return type is deliberately left raw so that existing callers compile unchanged. Only
     * the parameter was generified, and its erasure is identical.
     *
     * @param model the fitted model
     * @param data  the dataset to transform
     * @return the result of {@code model.transform(data)}
     */
    public Dataset transform(KNNClassificationModel model, Dataset<?> data) {
        return model.transform(data);
    }
}
