package com.datastax.insight.ml.spark.ml.evaluator;

import com.alibaba.fastjson.JSON;
import com.datastax.insight.core.entity.ConfusionMatrix;
import com.datastax.insight.core.entity.CurvePoint;
import com.datastax.insight.core.entity.Metrics;
import com.datastax.insight.core.entity.StatisticD;
import com.datastax.insight.core.service.PersistService;
import com.datastax.insight.spec.DataSetOperator;
import com.google.common.base.Strings;
import org.apache.spark.ml.PipelineModel;
import org.apache.spark.ml.PredictionModel;
import org.apache.spark.ml.Transformer;
import org.apache.spark.ml.classification.BinaryLogisticRegressionSummary;
import org.apache.spark.ml.classification.BinaryLogisticRegressionSummaryExt;
import org.apache.spark.ml.classification.LogisticRegressionModel;
import org.apache.spark.ml.classification.LogisticRegressionTrainingSummary;
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;

import java.util.Arrays;
import java.util.List;
import java.util.Optional;
import java.util.stream.Collectors;

/**
 * Binary classification evaluation helpers built on Spark ML's
 * {@link BinaryClassificationEvaluator}.
 *
 * <p>The {@code evaluate} overloads that take a {@link Transformer} run the transformer on the
 * given data, collect the resulting figures into a {@link Metrics} object, hand that object to
 * {@link PersistService} as JSON and return it.
 */
public class BinaryClassificationEvaluatorWrapper implements DataSetOperator {

    /** Class name passed to {@link PersistService#invoke} when saving metrics. */
    private static final String METRICS_DAO = "com.datastax.insight.agent.dao.InsightDAO";

    /** Method name passed to {@link PersistService#invoke} when saving metrics. */
    private static final String SAVE_METRICS_METHOD = "saveModelMetrics";

    /** Threshold used when none is supplied and the classifier is not a logistic regression. */
    private static final double DEFAULT_THRESHOLD = 0.5;

    /**
     * Creates a binary classification evaluator.
     *
     * @param labelCol         label column name; ignored when {@code null} or empty
     * @param rawPredictionCol raw prediction column name; ignored when {@code null} or empty
     * @return a new evaluator with the non-empty column names applied
     */
    public static BinaryClassificationEvaluator getOperator(String labelCol, String rawPredictionCol) {
        return newEvaluator(labelCol, rawPredictionCol);
    }

    /**
     * Runs the given evaluator on a set of predictions.
     *
     * @param evaluator   the evaluator to run
     * @param predictions dataset passed to {@link BinaryClassificationEvaluator#evaluate}
     * @return the value returned by the evaluator
     */
    public static double evaluate(BinaryClassificationEvaluator evaluator, Dataset<Row> predictions) {
        return evaluator.evaluate(predictions);
    }

    /**
     * Transforms {@code testData} with {@code transformer}, computes areaUnderROC and areaUnderPR
     * on the result and, when a logistic regression model with a binary training summary is
     * available, adds the ROC and PR curves from that summary. The metrics are handed to
     * {@link PersistService} before being returned.
     *
     * @param transformer      a model, or a {@link PipelineModel} whose stages are searched for a
     *                         {@link LogisticRegressionModel}
     * @param testData         data to run the transformer on
     * @param labelCol         label column name; ignored when {@code null} or empty
     * @param rawPredictionCol raw prediction column name; ignored when {@code null} or empty
     * @return the collected metrics
     */
    public static Metrics evaluate(Transformer transformer, Dataset<Row> testData, String labelCol, String rawPredictionCol) {
        BinaryClassificationEvaluator evaluator = newEvaluator(labelCol, rawPredictionCol);

        Dataset<Row> predictions = transformer.transform(testData);
        Metrics metrics = new Metrics();
        fillAreaMetrics(evaluator, predictions, metrics);

        LogisticRegressionModel model = findLogisticRegressionModel(transformer);
        // summary() throws when the model carries no training summary, so check first.
        if (model != null && model.hasSummary()) {
            LogisticRegressionTrainingSummary summary = model.summary();
            if (summary instanceof BinaryLogisticRegressionSummary) {
                BinaryLogisticRegressionSummary binarySummary = (BinaryLogisticRegressionSummary) summary;
                // NOTE(review): this replaces the areaUnderROC computed on testData above with the
                // value taken from the model's training summary - confirm this is intended.
                metrics.getIndicator().setAreaUnderROC(binarySummary.areaUnderROC());
                metrics.setRoc(toCurvePoints(binarySummary.roc()));
                metrics.setPr(toCurvePoints(binarySummary.pr()));
            }
        }

        persistMetrics(metrics);
        return metrics;
    }

    /**
     * Transforms {@code testData} with {@code transformer} and computes areaUnderROC, areaUnderPR,
     * a confusion matrix, precision, recall and F1. For logistic regression models the curve data
     * produced by {@link #curveData} is added as well. The metrics are handed to
     * {@link PersistService} before being returned.
     *
     * <p>Side effect: when the classifier is a {@link LogisticRegressionModel}, its threshold is
     * set to the effective threshold before the data is transformed.
     *
     * @param transformer      a {@link PredictionModel}, or a {@link PipelineModel} containing one
     * @param testData         data to run the transformer on
     * @param threshold        decision threshold; when {@code null}, the logistic regression
     *                         model's own threshold is used, or 0.5 for other classifiers
     * @param label            label column for the evaluator; ignored when {@code null} or empty
     * @param rawPredictionCol raw prediction column for the evaluator; ignored when {@code null}
     *                         or empty
     * @return the collected metrics
     * @throws IllegalArgumentException if no {@link PredictionModel} can be found in
     *                                  {@code transformer}
     */
    public static Metrics evaluate(Transformer transformer, Dataset<Row> testData, Double threshold, String label, String rawPredictionCol) {
        PredictionModel classifier = findPredictionModel(transformer);
        if (classifier == null) {
            throw new IllegalArgumentException(
                    "transformer must be a PredictionModel or a PipelineModel containing one: " + transformer);
        }

        String labelCol = classifier.getLabelCol();
        String predictionCol = classifier.getPredictionCol();

        boolean isLogisticRegression = classifier instanceof LogisticRegressionModel;
        if (threshold == null) {
            threshold = isLogisticRegression
                    ? ((LogisticRegressionModel) classifier).getThreshold()
                    : DEFAULT_THRESHOLD;
        }
        if (isLogisticRegression) {
            ((LogisticRegressionModel) classifier).setThreshold(threshold);
        }

        Dataset<Row> predictions = transformer.transform(testData);
        Metrics metrics = new Metrics();

        BinaryClassificationEvaluator evaluator = newEvaluator(label, rawPredictionCol);
        fillAreaMetrics(evaluator, predictions, metrics);

        // NOTE(review): both the label and the prediction column are compared against the
        // decision threshold; this presumably relies on both holding 0/1 values and on the
        // threshold lying in (0, 1] - confirm against callers.
        String labelPositive = labelCol + ">=" + threshold;
        String labelNegative = labelCol + "<" + threshold;
        String predictedPositive = predictionCol + ">=" + threshold;
        String predictedNegative = predictionCol + "<" + threshold;

        // True positive: label is positive and prediction is also positive.
        long tp = predictions.filter(labelPositive + " and " + predictedPositive).count();
        // True negative: label is negative and prediction is also negative.
        long tn = predictions.filter(labelNegative + " and " + predictedNegative).count();
        // False positive: label is negative but prediction is positive.
        long fp = predictions.filter(labelNegative + " and " + predictedPositive).count();
        // False negative: label is positive but prediction is negative.
        long fn = predictions.filter(labelPositive + " and " + predictedNegative).count();

        ConfusionMatrix matrix = new ConfusionMatrix();
        matrix.setTp(tp);
        matrix.setTn(tn);
        matrix.setFp(fp);
        matrix.setFn(fn);
        matrix.setThreshold(threshold);
        metrics.setConfusionMatrix(matrix);

        // Each ratio falls back to 0 when its denominator is 0 instead of producing NaN.
        double precision = (tp + fp) == 0 ? 0 : tp / (double) (tp + fp);
        metrics.getIndicator().setPrecision(precision);

        double recall = (tp + fn) == 0 ? 0 : tp / (double) (tp + fn);
        metrics.getIndicator().setRecall(recall);

        long f1Denominator = 2 * tp + fp + fn;
        double f1 = f1Denominator == 0 ? 0 : 2 * tp / (double) f1Denominator;
        metrics.getIndicator().setF1(f1);

        curveData(classifier, metrics);

        // TODO: the flowId is read from a cache; this is a problem when several users submit
        // flows at the same time.
        persistMetrics(metrics);
        return metrics;
    }

    /**
     * Adds ROC, PR, lift, gain and KS curve data plus the maximum-D statistic to {@code metrics}
     * when {@code model} is a {@link LogisticRegressionModel} that has a training summary.
     * Does nothing for other models.
     */
    private static void curveData(PredictionModel model, Metrics metrics) {
        if (!(model instanceof LogisticRegressionModel)) {
            return;
        }
        LogisticRegressionModel lrModel = (LogisticRegressionModel) model;
        // summary() throws when the model carries no training summary, so check first.
        if (!lrModel.hasSummary()) {
            return;
        }
        LogisticRegressionTrainingSummary summary = lrModel.summary();

        BinaryLogisticRegressionSummaryExt summaryExt = new BinaryLogisticRegressionSummaryExt(summary.predictions(),
                summary.probabilityCol(), summary.labelCol(), summary.featuresCol());

        // NOTE(review): this replaces any areaUnderROC already stored in metrics with the value
        // derived from the training summary's predictions - confirm this is intended.
        metrics.getIndicator().setAreaUnderROC(summaryExt.areaUnderROC());
        metrics.setRoc(toCurvePoints(summaryExt.roc()));
        metrics.setPr(toCurvePoints(summaryExt.pr()));
        metrics.setLift(toCurvePoints(summaryExt.lift()));
        metrics.setGain(toCurvePoints(summaryExt.gain()));

        Dataset<Row> ksMinus = summaryExt.ksMinus();
        Dataset<Row> ksPlus = summaryExt.ksPlus();
        metrics.setKsMinus(toCurvePoints(ksMinus));
        metrics.setKsPlus(toCurvePoints(ksPlus));

        // Join the two KS curves on "reach" and pick the row with the largest TPR - FPR gap.
        Dataset<Row> dSet = ksMinus.join(ksPlus, "reach")
                .selectExpr("reach", "FPR", "TPR", "(TPR - FPR) as D");
        Row dRow = dSet.sort(dSet.col("D").desc()).first();

        StatisticD d = new StatisticD();
        d.setX(dRow.getAs("reach"));
        d.setD(dRow.getAs("D"));
        d.setFpr(dRow.getAs("FPR"));
        d.setTpr(dRow.getAs("TPR"));
        metrics.setD(d);
    }

    /** Builds an evaluator, applying each column name only when it is non-null and non-empty. */
    private static BinaryClassificationEvaluator newEvaluator(String labelCol, String rawPredictionCol) {
        BinaryClassificationEvaluator evaluator = new BinaryClassificationEvaluator();
        if (!Strings.isNullOrEmpty(labelCol)) {
            evaluator.setLabelCol(labelCol);
        }
        if (!Strings.isNullOrEmpty(rawPredictionCol)) {
            evaluator.setRawPredictionCol(rawPredictionCol);
        }
        return evaluator;
    }

    /** Evaluates areaUnderROC and then areaUnderPR on {@code predictions} and stores both in {@code metrics}. */
    private static void fillAreaMetrics(BinaryClassificationEvaluator evaluator, Dataset<Row> predictions, Metrics metrics) {
        evaluator.setMetricName("areaUnderROC");
        metrics.getIndicator().setAreaUnderROC(evaluator.evaluate(predictions));

        evaluator.setMetricName("areaUnderPR");
        metrics.getIndicator().setAreaUnderPR(evaluator.evaluate(predictions));
    }

    /** Returns the transformer itself, or the first matching pipeline stage, as a logistic regression model; {@code null} if none. */
    private static LogisticRegressionModel findLogisticRegressionModel(Transformer transformer) {
        if (transformer instanceof LogisticRegressionModel) {
            return (LogisticRegressionModel) transformer;
        }
        if (transformer instanceof PipelineModel) {
            Optional<Transformer> stage = Arrays.stream(((PipelineModel) transformer).stages())
                    .filter(p -> p instanceof LogisticRegressionModel)
                    .findFirst();
            return stage.isPresent() ? (LogisticRegressionModel) stage.get() : null;
        }
        return null;
    }

    /** Returns the first prediction model stage of a pipeline, or the transformer itself when it is one; {@code null} if none. */
    private static PredictionModel findPredictionModel(Transformer transformer) {
        if (transformer instanceof PipelineModel) {
            Optional<Transformer> stage = Arrays.stream(((PipelineModel) transformer).stages())
                    .filter(p -> p instanceof PredictionModel)
                    .findFirst();
            return stage.isPresent() ? (PredictionModel) stage.get() : null;
        }
        if (transformer instanceof PredictionModel) {
            return (PredictionModel) transformer;
        }
        return null;
    }

    /** Collects a curve dataset to the driver and maps its first two double columns to curve points. */
    private static List<CurvePoint> toCurvePoints(Dataset<Row> curve) {
        return curve.collectAsList().stream()
                .map(row -> new CurvePoint(row.getDouble(0), row.getDouble(1)))
                .collect(Collectors.toList());
    }

    /** Passes the current flow id and the JSON form of {@code metrics} to the DAO's saveModelMetrics via {@link PersistService}. */
    private static void persistMetrics(Metrics metrics) {
        PersistService.invoke(METRICS_DAO,
                SAVE_METRICS_METHOD,
                new String[]{Long.class.getTypeName(), String.class.getTypeName()},
                new Object[]{PersistService.getFlowId(), JSON.toJSONString(metrics)});
    }
}
