package com.datastax.insight.ml.spark.ml.evaluator;

import com.alibaba.fastjson.JSON;
import com.datastax.insight.core.entity.Metrics;
import com.datastax.insight.spec.DataSetOperator;
import com.datastax.insight.core.service.PersistService;
import com.google.common.base.Strings;
import org.apache.spark.ml.Transformer;
import org.apache.spark.ml.evaluation.RegressionEvaluator;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;

/**
 * Regression evaluation helpers built on Spark ML's {@link RegressionEvaluator}.
 *
 * <p>All members are static. The class offers a factory for a configured evaluator,
 * a thin single-metric evaluation call, and a combined "transform, evaluate, persist"
 * routine that reports RMSE, MSE, R2 and MAE.
 */
public class RegressionEvaluatorWrapper implements DataSetOperator {
    /**
     * Creates a regression evaluator.
     *
     * @param labelCol      name of the label column; when {@code null} or empty the
     *                      evaluator's own default is left untouched
     * @param predictionCol name of the prediction column; when {@code null} or empty the
     *                      evaluator's own default is left untouched
     * @return a new evaluator with the given column names applied
     */
    public static RegressionEvaluator getOperator(String labelCol, String predictionCol) {
        RegressionEvaluator evaluator = new RegressionEvaluator();

        if (!Strings.isNullOrEmpty(labelCol)) {
            evaluator.setLabelCol(labelCol);
        }
        if (!Strings.isNullOrEmpty(predictionCol)) {
            evaluator.setPredictionCol(predictionCol);
        }

        return evaluator;
    }

    /**
     * Evaluates the predictions with the metric currently configured on the evaluator.
     *
     * @param evaluator   the evaluator to use
     * @param predictions dataset holding the label and prediction columns the evaluator reads
     * @return the metric value computed by {@code evaluator}
     */
    public static double evaluate(RegressionEvaluator evaluator, Dataset<Row> predictions) {
        return evaluator.evaluate(predictions);
    }

    /**
     * Applies {@code transformer} to {@code testData}, computes the RMSE, MSE, R2 and MAE
     * of the result, and hands the metrics (serialized as JSON) together with the current
     * flow id to {@code PersistService.invoke}.
     *
     * <p>When {@code labelCol} is given, that column of {@code testData} is cast to
     * {@code double} before the transformation. When it is {@code null} or empty no cast is
     * performed and the evaluator keeps its default label column.
     *
     * <p>The method name is misspelled ("evalute"); it is kept as is because it is part of
     * the public interface.
     *
     * @param transformer   the fitted model (or any transformer) that produces predictions
     * @param testData      the data to evaluate on
     * @param labelCol      name of the label column, or {@code null}/empty for the default
     * @param predictionCol name of the prediction column, or {@code null}/empty for the default
     * @return the populated metrics object
     */
    public static Metrics evalute(Transformer transformer, Dataset<Row> testData, String labelCol, String predictionCol) {
        // Reuse the factory so the column handling lives in exactly one place.
        RegressionEvaluator evaluator = getOperator(labelCol, predictionCol);

        if (!Strings.isNullOrEmpty(labelCol)) {
            testData = testData.withColumn(labelCol, testData.col(labelCol).cast("double"));
        }

        Dataset<Row> predictions = transformer.transform(testData);

        Metrics metrics = new Metrics();

        // NOTE(review): every metric is a separate evaluate() call on the same dataset;
        // presumably each one recomputes the predictions - consider caching if this is slow.
        metrics.getIndicator().setRmse(evaluateMetric(evaluator, predictions, "rmse"));
        metrics.getIndicator().setMse(evaluateMetric(evaluator, predictions, "mse"));
        metrics.getIndicator().setR2(evaluateMetric(evaluator, predictions, "r2"));
        metrics.getIndicator().setMae(evaluateMetric(evaluator, predictions, "mae"));

        // Target class and method are passed by name along with the argument type names;
        // presumably PersistService resolves and calls them reflectively - confirm there.
        PersistService.invoke("com.datastax.insight.agent.dao.InsightDAO",
                "saveModelMetrics",
                new String[]{Long.class.getTypeName(), String.class.getTypeName()},
                new Object[]{PersistService.getFlowId(), JSON.toJSONString(metrics)});

        return metrics;
    }

    /**
     * Switches {@code evaluator} to {@code metricName} and evaluates {@code predictions}.
     * The evaluator is left configured with that metric afterwards.
     */
    private static double evaluateMetric(RegressionEvaluator evaluator, Dataset<Row> predictions, String metricName) {
        evaluator.setMetricName(metricName);
        return evaluator.evaluate(predictions);
    }
}
