package com.datastax.insight.ml.spark.ml.classification;

import com.datastax.insight.spec.DataSetOperator;
import org.apache.spark.ml.classification.Classifier;
import org.apache.spark.ml.classification.LogisticRegression;
import org.apache.spark.ml.classification.OneVsRest;
import org.apache.spark.ml.classification.OneVsRestModel;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;

import java.util.Objects;

/**
 * OneVsRest classification training.
 *
 * <p>Static helpers that build, fit and apply Spark ML's {@link OneVsRest} meta-estimator, which
 * reduces a multiclass problem to one binary base classifier per class.
 *
 * <p>Why the column copying below: {@code OneVsRest.fit} reads the label/features columns from its
 * <em>own</em> params (defaults {@code "label"} / {@code "features"}) and overrides the base
 * classifier's {@code labelCol}/{@code featuresCol} for every per-class model. Columns configured
 * only on the base classifier would therefore be ignored. The factories copy them (and
 * {@code weightCol}, when the running Spark version supports it on OneVsRest) onto the
 * OneVsRest instance so the caller's configuration is actually honoured.
 */
public class OneVsRestClassifier implements DataSetOperator {

    /** Param name for instance weights; only present on OneVsRest in newer Spark versions. */
    private static final String WEIGHT_COL = "weightCol";

    /**
     * Builds an unfitted OneVsRest estimator whose base classifier is a logistic regression.
     *
     * <p>All arguments are forwarded to {@link LogisticRegressionWrapper#getOperator}; the
     * resulting classifier's feature/label/weight columns are then mirrored onto the OneVsRest.
     *
     * @param featuresCol     features column name
     * @param labelCol        label column name
     * @param maxIterations   maximum number of iterations for the base classifier
     * @param regParam        regularization parameter
     * @param elasticNetParam elastic-net mixing parameter
     * @param threshold       binary decision threshold
     * @param tol             convergence tolerance
     * @param fitIntercept    whether to fit an intercept term
     * @param standardization whether to standardize features before fitting
     * @param weightCol       instance-weight column name
     * @return a configured, unfitted OneVsRest estimator
     */
    public static OneVsRest getOperator(String featuresCol,
                                        String labelCol,
                                        Integer maxIterations,
                                        Double regParam,
                                        Double elasticNetParam,
                                        Double threshold,
                                        Double tol,
                                        Boolean fitIntercept,
                                        Boolean standardization,
                                        String weightCol) {

        LogisticRegression classifier = LogisticRegressionWrapper.getOperator(featuresCol, labelCol, maxIterations, regParam, elasticNetParam,
                threshold, tol, fitIntercept, standardization, weightCol);

        return getOperator(classifier);
    }

    /**
     * Builds an unfitted OneVsRest estimator around an existing base classifier.
     *
     * <p>The classifier's features and label columns, and its weight column if one is defined and
     * OneVsRest supports it, are copied onto the OneVsRest instance. Otherwise OneVsRest would use
     * its own defaults at fit time.
     *
     * @param regression the binary base classifier; must not be {@code null}
     * @return a configured, unfitted OneVsRest estimator
     * @throws NullPointerException if {@code regression} is {@code null}
     */
    public static OneVsRest getOperator(Classifier<?, ?, ?> regression) {
        Objects.requireNonNull(regression, "regression");
        OneVsRest ovr = new OneVsRest()
                .setClassifier(regression)
                .setFeaturesCol(regression.getFeaturesCol())
                .setLabelCol(regression.getLabelCol());
        copyWeightCol(regression, ovr);
        return ovr;
    }

    /**
     * Builds a logistic-regression-based OneVsRest estimator and fits it on {@code data}.
     *
     * @param data the training data; must not be {@code null}
     * @param featuresCol     features column name
     * @param labelCol        label column name
     * @param maxIterations   maximum number of iterations for the base classifier
     * @param regParam        regularization parameter
     * @param elasticNetParam elastic-net mixing parameter
     * @param threshold       binary decision threshold
     * @param tol             convergence tolerance
     * @param fitIntercept    whether to fit an intercept term
     * @param standardization whether to standardize features before fitting
     * @param weightCol       instance-weight column name
     * @return the fitted OneVsRest model
     * @throws NullPointerException if {@code data} is {@code null}
     */
    public static OneVsRestModel fit(Dataset<Row> data, String featuresCol,
                                     String labelCol,
                                     Integer maxIterations,
                                     Double regParam,
                                     Double elasticNetParam,
                                     Double threshold,
                                     Double tol,
                                     Boolean fitIntercept,
                                     Boolean standardization,
                                     String weightCol) {
        Objects.requireNonNull(data, "data");
        return getOperator(featuresCol, labelCol, maxIterations, regParam, elasticNetParam,
                threshold, tol, fitIntercept, standardization, weightCol).fit(data);
    }

    /**
     * Fits a OneVsRest estimator built around the given logistic regression.
     *
     * @param regression the base classifier; must not be {@code null}
     * @param data       the training data; must not be {@code null}
     * @return the fitted OneVsRest model
     * @throws NullPointerException if either argument is {@code null}
     */
    public static OneVsRestModel fit(LogisticRegression regression, Dataset<Row> data) {
        Objects.requireNonNull(data, "data");
        return getOperator(regression).fit(data);
    }

    /**
     * Applies a fitted OneVsRest model to {@code data}.
     *
     * @param model the fitted model; must not be {@code null}
     * @param data  the data to score
     * @return {@code data} with the model's output column(s) appended
     * @throws NullPointerException if {@code model} is {@code null}
     */
    public static Dataset<Row> transform(OneVsRestModel model, Dataset<Row> data) {
        Objects.requireNonNull(model, "model");
        return model.transform(data);
    }

    /**
     * Mirrors the source classifier's weight column onto the OneVsRest, if both support it.
     *
     * <p>Looked up by name so this compiles against Spark versions where OneVsRest has no
     * {@code weightCol}. In those versions nothing is copied. If the base classifier has a
     * weight column that OneVsRest does not forward, {@code OneVsRest.fit} drops that column
     * and the per-class fits cannot resolve it.
     */
    private static void copyWeightCol(Classifier<?, ?, ?> source, OneVsRest target) {
        if (!source.hasParam(WEIGHT_COL) || !target.hasParam(WEIGHT_COL)) {
            return;
        }
        if (!source.isDefined(source.getParam(WEIGHT_COL))) {
            return;
        }
        Object weightCol = source.getOrDefault(source.getParam(WEIGHT_COL));
        // An empty weight column means "unweighted" in Spark, so there is nothing to forward.
        if (weightCol instanceof String && !((String) weightCol).isEmpty()) {
            target.set(target.getParam(WEIGHT_COL), weightCol);
        }
    }
}
