package com.datastax.insight.ml.spark.ml.evaluator;

import com.alibaba.fastjson.JSON;
import com.datastax.insight.core.entity.Metrics;
import com.datastax.insight.spec.DataSetOperator;
import com.datastax.insight.core.service.PersistService;
import org.apache.spark.ml.Transformer;
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;

/**
 * Multiclass classification evaluation helpers built on Spark ML's
 * {@link MulticlassClassificationEvaluator}.
 */
public class MulticlassClassificationEvaluatorWrapper implements DataSetOperator {

    /**
     * Creates a multiclass classification evaluator.
     *
     * @param labelCol      name of the label column; when {@code null} or empty the
     *                      evaluator's default label column is left unchanged
     * @param predictionCol name of the prediction column; when {@code null} or empty the
     *                      evaluator's default prediction column is left unchanged
     * @return a new evaluator configured with the given columns
     */
    public static MulticlassClassificationEvaluator getOperator(String labelCol, String predictionCol) {
        MulticlassClassificationEvaluator evaluator = new MulticlassClassificationEvaluator();

        if (labelCol != null && !labelCol.isEmpty()) {
            evaluator.setLabelCol(labelCol);
        }
        if (predictionCol != null && !predictionCol.isEmpty()) {
            evaluator.setPredictionCol(predictionCol);
        }

        return evaluator;
    }

    /**
     * Evaluates a predictions dataset using whatever metric {@code evaluator} is currently
     * configured with.
     *
     * @param evaluator   the evaluator to run
     * @param predictions dataset holding the evaluator's label and prediction columns
     * @return the value returned by {@code evaluator.evaluate(predictions)}
     */
    public static double evaluate(MulticlassClassificationEvaluator evaluator, Dataset<Row> predictions) {
        return evaluator.evaluate(predictions);
    }

    /**
     * Applies {@code transformer} to {@code testData} and computes the "f1",
     * "weightedPrecision", "weightedRecall" and "accuracy" metrics on the result.
     *
     * <p>The metrics are serialized to JSON and handed to
     * {@code com.datastax.insight.agent.dao.InsightDAO#saveModelMetrics} through
     * {@link PersistService#invoke}, together with {@link PersistService#getFlowId()}.
     *
     * @param transformer   fitted model/transformer producing predictions
     * @param testData      dataset to score
     * @param labelCol      label column name; {@code null}/empty keeps the evaluator default
     * @param predictionCol prediction column name; {@code null}/empty keeps the evaluator default
     * @return the computed metrics (also passed to the persistence call above)
     */
    public static Metrics evaluate(Transformer transformer, Dataset<Row> testData, String labelCol, String predictionCol) {
        MulticlassClassificationEvaluator evaluator = getOperator(labelCol, predictionCol);

        // Each evaluator.evaluate(...) call launches a separate Spark job over its input.
        // Caching the transformed dataset avoids re-running the transformer once per metric.
        Dataset<Row> predictions = transformer.transform(testData).cache();

        Metrics metrics = new Metrics();
        try {
            metrics.getIndicator().setF1(score(evaluator, predictions, "f1"));
            metrics.getIndicator().setWeightedPrecision(score(evaluator, predictions, "weightedPrecision"));
            metrics.getIndicator().setWeightedRecall(score(evaluator, predictions, "weightedRecall"));
            metrics.getIndicator().setAccuracy(score(evaluator, predictions, "accuracy"));
        } finally {
            // Release the cached blocks even if one of the evaluations fails.
            predictions.unpersist();
        }

        PersistService.invoke("com.datastax.insight.agent.dao.InsightDAO",
                "saveModelMetrics",
                new String[]{Long.class.getTypeName(), String.class.getTypeName()},
                new Object[]{PersistService.getFlowId(), JSON.toJSONString(metrics)});

        return metrics;
    }

    /**
     * Switches {@code evaluator} to {@code metricName} and evaluates {@code predictions} with it.
     * Note: this mutates the evaluator's metric name.
     */
    private static double score(MulticlassClassificationEvaluator evaluator, Dataset<Row> predictions,
                                String metricName) {
        evaluator.setMetricName(metricName);
        return evaluator.evaluate(predictions);
    }
}
