package com.datastax.insight.ml.spark.ml.classification;

import com.datastax.insight.spec.DataSetOperator;
import com.google.common.base.Strings;
import org.apache.spark.ml.classification.LinearSVC;
import org.apache.spark.ml.classification.LinearSVCModel;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;

/**
 * Linear support vector machine classification, backed by Spark ML's {@link LinearSVC}.
 *
 * <p>Exposes three static steps: build a configured estimator ({@link #getOperator}), train it
 * ({@link #fit}), and apply the trained model to data ({@link #transform}).
 */
public class LinearSVCWrapper implements DataSetOperator {

    /**
     * Creates a {@link LinearSVC} estimator, applying only the parameters that were supplied.
     *
     * <p>A {@code null} numeric/boolean argument, or a {@code null}/empty column-name argument, skips
     * the corresponding setter, so the estimator keeps the value Spark gives it by default.
     *
     * @param aggregationDepth value for {@code setAggregationDepth}, or {@code null} to skip
     * @param fitIntercept     value for {@code setFitIntercept}, or {@code null} to skip
     * @param maxIter          value for {@code setMaxIter}, or {@code null} to skip
     * @param regParam         value for {@code setRegParam}, or {@code null} to skip
     * @param standardization  value for {@code setStandardization}, or {@code null} to skip
     * @param threshold        value for {@code setThreshold}, or {@code null} to skip
     * @param tol              value for {@code setTol}, or {@code null} to skip
     * @param weightCol        value for {@code setWeightCol}; {@code null} or empty to skip
     * @param featuresCol      value for {@code setFeaturesCol}; {@code null} or empty to skip
     * @param labelCol         value for {@code setLabelCol}; {@code null} or empty to skip
     * @param predictionCol    value for {@code setPredictionCol}; {@code null} or empty to skip
     * @return a newly created, untrained estimator; never {@code null}
     */
    public static LinearSVC getOperator(Integer aggregationDepth, Boolean fitIntercept, Integer maxIter,
                                        Double regParam, Boolean standardization, Double threshold, Double tol,
                                        String weightCol, String featuresCol, String labelCol, String predictionCol) {
        LinearSVC classifier = new LinearSVC();

        // Numeric and boolean hyper-parameters: only override when explicitly provided.
        if (aggregationDepth != null) {
            classifier.setAggregationDepth(aggregationDepth);
        }

        if (fitIntercept != null) {
            classifier.setFitIntercept(fitIntercept);
        }

        if (maxIter != null) {
            classifier.setMaxIter(maxIter);
        }

        if (regParam != null) {
            classifier.setRegParam(regParam);
        }

        if (standardization != null) {
            classifier.setStandardization(standardization);
        }

        if (threshold != null) {
            classifier.setThreshold(threshold);
        }

        if (tol != null) {
            classifier.setTol(tol);
        }

        // Column names: an empty string is treated the same as "not provided".
        if (!Strings.isNullOrEmpty(weightCol)) {
            classifier.setWeightCol(weightCol);
        }

        if (!Strings.isNullOrEmpty(featuresCol)) {
            classifier.setFeaturesCol(featuresCol);
        }

        if (!Strings.isNullOrEmpty(labelCol)) {
            classifier.setLabelCol(labelCol);
        }

        if (!Strings.isNullOrEmpty(predictionCol)) {
            classifier.setPredictionCol(predictionCol);
        }

        return classifier;
    }

    /**
     * Trains the given estimator on a dataset.
     *
     * <p>Typed as {@code Dataset<?>} (rather than the raw {@code Dataset}) to match Spark's
     * {@code Predictor.fit} signature without unchecked warnings; raw-typed callers still compile.
     *
     * @param classifier the estimator to train, typically from {@link #getOperator}
     * @param dataset    training data, passed unchanged to {@code classifier.fit}
     * @return the fitted model
     */
    public static LinearSVCModel fit(LinearSVC classifier, Dataset<?> dataset) {
        return classifier.fit(dataset);
    }

    /**
     * Applies a fitted model to a dataset.
     *
     * <p>Returns {@code Dataset<Row>}, the type Spark's {@code transform} produces, instead of the
     * raw {@code Dataset}; callers assigning to a raw {@code Dataset} are unaffected.
     *
     * @param model   the fitted model, typically from {@link #fit}
     * @param dataset input data, passed unchanged to {@code model.transform}
     * @return the result of {@code model.transform(dataset)}
     */
    public static Dataset<Row> transform(LinearSVCModel model, Dataset<?> dataset) {
        return model.transform(dataset);
    }
}
