package com.datastax.insight.ml.spark.ml.regression;

import com.datastax.insight.spec.DataSetOperator;
import com.datastax.insight.core.Consts;
import com.google.common.base.Strings;
import org.apache.spark.ml.regression.AFTSurvivalRegression;
import org.apache.spark.ml.regression.AFTSurvivalRegressionModel;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;

/**
 * Survival regression: a thin wrapper around Spark ML's {@link AFTSurvivalRegression}
 * estimator and the {@link AFTSurvivalRegressionModel} it produces.
 */
public class SurvivalRegressionWrapper implements DataSetOperator {

    /**
     * Builds an {@link AFTSurvivalRegression} estimator from optional settings.
     *
     * <p>Every argument is optional: a {@code null} value (or an empty string for the
     * {@code String} arguments) leaves the estimator's own default for that parameter
     * untouched.
     *
     * @param labelCol      name of the label column
     * @param censorCol     name of the censor column
     * @param predictionCol name of the output prediction column
     * @param qp            quantile probabilities as a single string, separated by
     *                      {@code Consts.DELIMITER}
     * @param quantilesCol  name of the output quantiles column
     * @param maxIterations maximum number of iterations
     * @param tol           convergence tolerance
     * @param fitIntercept  whether to fit an intercept term
     * @return the configured, not yet fitted, estimator
     * @throws NumberFormatException if any token of {@code qp} is not a parseable double
     */
    public static AFTSurvivalRegression getOperator(String labelCol,
                                                    String censorCol,
                                                    String predictionCol,
                                                    String qp,
                                                    String quantilesCol,
                                                    Integer maxIterations,
                                                    Double tol,
                                                    Boolean fitIntercept) {

        AFTSurvivalRegression regressor = new AFTSurvivalRegression();

        if (!Strings.isNullOrEmpty(labelCol)) {
            regressor.setLabelCol(labelCol);
        }

        if (!Strings.isNullOrEmpty(censorCol)) {
            regressor.setCensorCol(censorCol);
        }

        if (!Strings.isNullOrEmpty(predictionCol)) {
            regressor.setPredictionCol(predictionCol);
        }

        if (!Strings.isNullOrEmpty(quantilesCol)) {
            regressor.setQuantilesCol(quantilesCol);
        }

        if (!Strings.isNullOrEmpty(qp)) {
            regressor.setQuantileProbabilities(parseQuantileProbabilities(qp));
        }

        if (maxIterations != null) {
            regressor.setMaxIter(maxIterations);
        }

        if (tol != null) {
            regressor.setTol(tol);
        }

        if (fitIntercept != null) {
            regressor.setFitIntercept(fitIntercept);
        }

        return regressor;
    }

    /**
     * Parses a delimited list of quantile probabilities into a {@code double[]}.
     *
     * @param qp non-empty string of numbers separated by {@code Consts.DELIMITER}
     * @return one parsed value per token, in input order
     * @throws NumberFormatException if any token is not a parseable double
     */
    private static double[] parseQuantileProbabilities(String qp) {
        // String.split treats the delimiter as a regular expression and drops trailing
        // empty tokens; Double.parseDouble ignores leading/trailing whitespace per token.
        String[] tokens = qp.split(Consts.DELIMITER);
        double[] probabilities = new double[tokens.length];
        for (int i = 0; i < tokens.length; i++) {
            probabilities[i] = Double.parseDouble(tokens[i]);
        }
        return probabilities;
    }

    /**
     * Builds an estimator via {@link #getOperator} and fits it on {@code data}.
     *
     * @param data          training data
     * @param labelCol      name of the label column
     * @param censorCol     name of the censor column
     * @param predictionCol name of the output prediction column
     * @param qp            quantile probabilities, separated by {@code Consts.DELIMITER}
     * @param quantilesCol  name of the output quantiles column
     * @param maxIterations maximum number of iterations
     * @param tol           convergence tolerance
     * @param fitIntercept  whether to fit an intercept term
     * @return the fitted model
     * @throws NumberFormatException if any token of {@code qp} is not a parseable double
     */
    public static AFTSurvivalRegressionModel fit(Dataset<Row> data,
                                                 String labelCol,
                                                 String censorCol,
                                                 String predictionCol,
                                                 String qp,
                                                 String quantilesCol,
                                                 Integer maxIterations,
                                                 Double tol,
                                                 Boolean fitIntercept) {
        AFTSurvivalRegression regressor = getOperator(labelCol,
                censorCol,
                predictionCol,
                qp,
                quantilesCol,
                maxIterations,
                tol,
                fitIntercept);
        return fit(regressor, data);
    }

    /**
     * Fits an already configured estimator on {@code data}.
     *
     * @param regressor configured estimator
     * @param data      training data
     * @return the fitted model
     */
    public static AFTSurvivalRegressionModel fit(AFTSurvivalRegression regressor, Dataset<Row> data) {
        return regressor.fit(data);
    }

    /**
     * Applies a fitted model to {@code data}.
     *
     * <p>NOTE(review): this is an instance method, unlike the static helpers above.
     * It may be required by {@code DataSetOperator}; confirm before making it static.
     *
     * @param model fitted survival-regression model
     * @param data  input rows
     * @return the result of {@code model.transform(data)}
     */
    public Dataset<Row> transform(AFTSurvivalRegressionModel model, Dataset<Row> data) {
        return model.transform(data);
    }
}
