package com.datastax.insight.ml.spark.ml.tuning;

import com.alibaba.fastjson.JSON;
import com.datastax.insight.spec.DataSetOperator;
import org.apache.spark.ml.Estimator;
import org.apache.spark.ml.evaluation.Evaluator;
import org.apache.spark.ml.param.ParamMap;
import org.apache.spark.ml.tuning.TrainValidationSplit;
import org.apache.spark.ml.tuning.TrainValidationSplitModel;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;

import java.util.Objects;

/**
 * Train-validation split: static helpers that build, fit and apply a Spark ML
 * {@link TrainValidationSplit} hyper-parameter search.
 */
public class TrainValidationSplitWrapper implements DataSetOperator {

    /**
     * Builds a {@link TrainValidationSplit}, applying only the arguments that are non-null.
     *
     * <p>A {@code null} argument leaves the corresponding param as {@code TrainValidationSplit}
     * initialises it, so callers may configure it later on the returned instance.
     *
     * @param estimator  estimator whose hyper-parameters are searched, or {@code null}
     * @param paramMaps  candidate param maps to evaluate, or {@code null}
     * @param evaluator  evaluator used to score each candidate, or {@code null}
     * @param trainRatio training-set ratio passed to {@code setTrainRatio}, or {@code null}
     * @param seed       random seed passed to {@code setSeed}, or {@code null}
     * @return a new, possibly partially configured, {@code TrainValidationSplit}
     */
    public static TrainValidationSplit getOperator(Estimator<?> estimator,
                                                   ParamMap[] paramMaps,
                                                   Evaluator evaluator,
                                                   Double trainRatio,
                                                   Long seed) {
        TrainValidationSplit split = new TrainValidationSplit();

        if (estimator != null) {
            split.setEstimator(estimator);
        }

        if (paramMaps != null) {
            split.setEstimatorParamMaps(paramMaps);
        }

        if (evaluator != null) {
            split.setEvaluator(evaluator);
        }

        if (seed != null) {
            split.setSeed(seed);
        }

        if (trainRatio != null) {
            split.setTrainRatio(trainRatio);
        }

        return split;
    }

    /**
     * Builds a {@link TrainValidationSplit} via
     * {@link #getOperator(Estimator, ParamMap[], Evaluator, Double, Long)} and fits it on
     * {@code dataset}.
     *
     * @param dataset    data to run the search on
     * @param estimator  estimator whose hyper-parameters are searched, or {@code null}
     * @param paramMaps  candidate param maps to evaluate, or {@code null}
     * @param evaluator  evaluator used to score each candidate, or {@code null}
     * @param trainRatio training-set ratio, or {@code null}
     * @param seed       random seed, or {@code null}
     * @return the fitted {@link TrainValidationSplitModel}
     * @throws NullPointerException if {@code dataset} is {@code null}
     */
    public static TrainValidationSplitModel fit(Dataset<Row> dataset,
                                                Estimator<?> estimator,
                                                ParamMap[] paramMaps,
                                                Evaluator evaluator,
                                                Double trainRatio,
                                                Long seed) {
        TrainValidationSplit operator = getOperator(estimator, paramMaps, evaluator, trainRatio, seed);
        return fit(dataset, operator);
    }

    /**
     * Fits {@code operator} on {@code dataset}, prints the best model's param map to standard
     * output, and returns the fitted model.
     *
     * @param dataset  data to run the search on
     * @param operator configured train-validation split
     * @return the fitted {@link TrainValidationSplitModel}
     * @throws NullPointerException if {@code dataset} or {@code operator} is {@code null}
     */
    public static TrainValidationSplitModel fit(Dataset<Row> dataset, TrainValidationSplit operator) {
        Objects.requireNonNull(dataset, "dataset");
        Objects.requireNonNull(operator, "operator");

        TrainValidationSplitModel model = operator.fit(dataset);
        // ParamMap keeps its entries in a private field with no JavaBean getters, so JSON bean
        // serialisation cannot see them; its own toString() lists each param and its value.
        System.out.println(model.bestModel().extractParamMap());
        return model;
    }

    /**
     * Applies a fitted {@link TrainValidationSplitModel} to {@code dataset}.
     *
     * @param dataset data to transform
     * @param model   fitted train-validation split model
     * @return the result of {@code model.transform(dataset)}
     * @throws NullPointerException if {@code dataset} or {@code model} is {@code null}
     */
    public static Dataset<Row> transform(Dataset<Row> dataset, TrainValidationSplitModel model) {
        Objects.requireNonNull(dataset, "dataset");
        Objects.requireNonNull(model, "model");
        return model.transform(dataset);
    }
}
