/*
 * Decompiled with CFR 0.152.
 */
package com.datastax.insight.ml.spark.ml.evaluator;

import com.alibaba.fastjson.JSON;
import com.datastax.insight.core.entity.ConfusionMatrix;
import com.datastax.insight.core.entity.CurvePoint;
import com.datastax.insight.core.entity.Metrics;
import com.datastax.insight.core.entity.StatisticD;
import com.datastax.insight.core.service.PersistService;
import com.datastax.insight.spec.DataSetOperator;
import com.google.common.base.Strings;
import java.util.Arrays;
import java.util.List;
import java.util.Optional;
import java.util.stream.Collectors;
import org.apache.spark.ml.PipelineModel;
import org.apache.spark.ml.PredictionModel;
import org.apache.spark.ml.Transformer;
import org.apache.spark.ml.classification.BinaryLogisticRegressionSummary;
import org.apache.spark.ml.classification.LogisticRegressionModel;
import org.apache.spark.ml.classification.LogisticRegressionTrainingSummary;
import org.apache.spark.ml.classification.a;
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator;
import org.apache.spark.sql.Column;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;

/**
 * {@link DataSetOperator} that wraps Spark ML's {@link BinaryClassificationEvaluator}.
 *
 * <p>Besides building a plain evaluator, the {@code evaluate(Transformer, ...)} overloads score a
 * test set with a fitted model, gather binary-classification metrics into a {@link Metrics}
 * entity, and pass that entity (serialized with fastjson) to
 * {@code InsightDAO.saveModelMetrics} through {@link PersistService}.
 */
public class BinaryClassificationEvaluatorWrapper
implements DataSetOperator {
    /** DAO class and method that receive the serialized metrics. */
    private static final String METRICS_DAO = "com.datastax.insight.agent.dao.InsightDAO";
    private static final String SAVE_METRICS_METHOD = "saveModelMetrics";

    /**
     * Class boundary applied to the label and hard-prediction columns when counting the confusion
     * matrix. Assumes binary labels/predictions encoded as 0.0/1.0 (Spark's convention), so any
     * value at or above 0.5 is the positive class. The decision threshold is deliberately NOT
     * used here: it only affects how probabilities become predictions, not how 0/1 values are read.
     */
    private static final double POSITIVE_CLASS_CUTOFF = 0.5;

    /**
     * Creates a {@link BinaryClassificationEvaluator}, overriding its label and raw-prediction
     * columns only when non-empty names are supplied.
     *
     * @param labelCol label column name, or null/empty to keep the evaluator's default
     * @param rawPredictionCol raw-prediction column name, or null/empty to keep the default
     * @return the configured evaluator
     */
    public static BinaryClassificationEvaluator getOperator(String labelCol, String rawPredictionCol) {
        BinaryClassificationEvaluator evaluator = new BinaryClassificationEvaluator();
        if (!Strings.isNullOrEmpty(labelCol)) {
            evaluator.setLabelCol(labelCol);
        }
        if (!Strings.isNullOrEmpty(rawPredictionCol)) {
            evaluator.setRawPredictionCol(rawPredictionCol);
        }
        return evaluator;
    }

    /**
     * Evaluates already-scored predictions with the evaluator's currently configured metric.
     *
     * @param evaluator configured evaluator
     * @param predictions dataset containing the label and raw-prediction columns
     * @return the metric value reported by {@code evaluator}
     */
    public static double evaluate(BinaryClassificationEvaluator evaluator, Dataset<Row> predictions) {
        return evaluator.evaluate(predictions);
    }

    /**
     * Scores {@code testData} with {@code transformer} and records areaUnderROC and areaUnderPR.
     *
     * <p>If the transformer is (or its pipeline contains) a {@link LogisticRegressionModel} that
     * still carries a binary training summary, the ROC/PR curves and areaUnderROC from that summary
     * are copied into the result. The metrics are then persisted for the current flow.
     *
     * @param transformer fitted model or pipeline used to produce predictions
     * @param testData data to score
     * @param labelCol evaluator label column, or null/empty for the default
     * @param rawPredictionCol evaluator raw-prediction column, or null/empty for the default
     * @return the computed (and persisted) metrics
     */
    public static Metrics evaluate(Transformer transformer, Dataset<Row> testData, String labelCol, String rawPredictionCol) {
        Dataset<Row> predictions = transformer.transform(testData);
        Metrics metrics = new Metrics();
        fillAreaMetrics(getOperator(labelCol, rawPredictionCol), predictions, metrics);

        LogisticRegressionModel model = findStage(transformer, LogisticRegressionModel.class);
        // A model loaded from storage has no training summary; summary() would throw in that case.
        if (model != null && model.hasSummary()) {
            LogisticRegressionTrainingSummary summary = model.summary();
            // Multinomial models do not produce a binary summary; skip the curves for them.
            if (summary instanceof BinaryLogisticRegressionSummary) {
                BinaryLogisticRegressionSummary binarySummary = (BinaryLogisticRegressionSummary) summary;
                // NOTE(review): this replaces the test-set AUC computed above with the training-set
                // AUC (matching the training ROC curve stored below) — confirm this is intended.
                metrics.getIndicator().setAreaUnderROC(binarySummary.areaUnderROC());
                metrics.setRoc(toCurve(binarySummary.roc()));
                metrics.setPr(toCurve(binarySummary.pr()));
            }
        }
        persist(metrics);
        return metrics;
    }

    /**
     * Scores {@code testData} and computes AUCs, a confusion matrix, precision, recall and F1.
     *
     * <p>When the classifier is a {@link LogisticRegressionModel}, its threshold is set to
     * {@code threshold} before scoring (this mutates the caller's model), and, if a training
     * summary is available, training-set ROC/PR/lift/gain/KS curves and the KS statistic are added.
     * The metrics are then persisted for the current flow.
     *
     * @param transformer a {@link PredictionModel} or a {@link PipelineModel} containing one
     * @param testData data to score
     * @param threshold decision threshold; null means the logistic-regression model's own
     *     threshold, or 0.5 for other classifiers
     * @param label label column for the evaluator and the confusion matrix, or null/empty to use
     *     the classifier's label column
     * @param rawPredictionCol evaluator raw-prediction column, or null/empty for the default
     * @return the computed (and persisted) metrics
     * @throws IllegalArgumentException if no {@link PredictionModel} can be found in
     *     {@code transformer}
     */
    public static Metrics evaluate(Transformer transformer, Dataset<Row> testData, Double threshold, String label, String rawPredictionCol) {
        PredictionModel<?, ?> classifier = findStage(transformer, PredictionModel.class);
        if (classifier == null) {
            // Fail fast with a descriptive error instead of an NPE on the getters below.
            throw new IllegalArgumentException("Expected a PredictionModel or a PipelineModel containing one, but got "
                    + (transformer == null ? "null" : transformer.getClass().getName()));
        }
        // The confusion matrix must be scored against the same label column as the evaluator.
        String labelCol = Strings.isNullOrEmpty(label) ? classifier.getLabelCol() : label;
        String predictionCol = classifier.getPredictionCol();

        Double effectiveThreshold = threshold;
        if (effectiveThreshold == null) {
            effectiveThreshold = classifier instanceof LogisticRegressionModel
                    ? ((LogisticRegressionModel) classifier).getThreshold()
                    : 0.5;
        }
        if (classifier instanceof LogisticRegressionModel) {
            // Mutates the caller's model (also when it is a pipeline stage) so that the prediction
            // column produced below reflects the requested threshold.
            ((LogisticRegressionModel) classifier).setThreshold(effectiveThreshold);
        }

        Dataset<Row> predictions = transformer.transform(testData);
        Metrics metrics = new Metrics();
        fillAreaMetrics(getOperator(label, rawPredictionCol), predictions, metrics);

        // Column expressions avoid building SQL strings from column names.
        Column actualPositive = predictions.col(labelCol).geq(POSITIVE_CLASS_CUTOFF);
        Column actualNegative = predictions.col(labelCol).lt(POSITIVE_CLASS_CUTOFF);
        Column predictedPositive = predictions.col(predictionCol).geq(POSITIVE_CLASS_CUTOFF);
        Column predictedNegative = predictions.col(predictionCol).lt(POSITIVE_CLASS_CUTOFF);
        long tp = predictions.filter(actualPositive.and(predictedPositive)).count();
        long tn = predictions.filter(actualNegative.and(predictedNegative)).count();
        long fp = predictions.filter(actualNegative.and(predictedPositive)).count();
        long fn = predictions.filter(actualPositive.and(predictedNegative)).count();

        ConfusionMatrix matrix = new ConfusionMatrix();
        matrix.setTp(tp);
        matrix.setTn(tn);
        matrix.setFp(fp);
        matrix.setFn(fn);
        matrix.setThreshold(effectiveThreshold);
        metrics.setConfusionMatrix(matrix);

        // Each ratio falls back to 0.0 when its denominator is zero instead of producing NaN.
        double precision = tp + fp == 0L ? 0.0 : (double) tp / (double) (tp + fp);
        metrics.getIndicator().setPrecision(precision);
        double recall = tp + fn == 0L ? 0.0 : (double) tp / (double) (tp + fn);
        metrics.getIndicator().setRecall(recall);
        long f1Denominator = 2L * tp + fp + fn;
        double f1 = f1Denominator == 0L ? 0.0 : (double) (2L * tp) / (double) f1Denominator;
        metrics.getIndicator().setF1(f1);

        applyLogisticRegressionTrainingCurves(classifier, metrics);
        persist(metrics);
        return metrics;
    }

    /**
     * Adds training-set curves and the KS statistic for a logistic-regression model.
     * Does nothing for other model types or when the model has no training summary.
     */
    private static void applyLogisticRegressionTrainingCurves(PredictionModel<?, ?> model, Metrics metrics) {
        if (!(model instanceof LogisticRegressionModel) || !((LogisticRegressionModel) model).hasSummary()) {
            return;
        }
        LogisticRegressionTrainingSummary summary = ((LogisticRegressionModel) model).summary();
        // `a` is a summary extension class that is not visible here; its e()/f()/g()/h() datasets
        // feed the gain/lift/KS-minus/KS-plus setters respectively.
        a summaryExt = new a((Dataset<Row>) summary.predictions(), summary.probabilityCol(), summary.labelCol(), summary.featuresCol());
        // NOTE(review): overwrites the test-set AUC with the training-set AUC — confirm intended.
        metrics.getIndicator().setAreaUnderROC(summaryExt.areaUnderROC());
        metrics.setRoc(toCurve(summaryExt.roc()));
        metrics.setPr(toCurve(summaryExt.pr()));
        metrics.setLift(toCurve(summaryExt.f()));
        metrics.setGain(toCurve(summaryExt.e()));
        metrics.setKsMinus(toCurve(summaryExt.g()));
        metrics.setKsPlus(toCurve(summaryExt.h()));

        // KS statistic: the "reach" point with the largest TPR - FPR gap.
        Dataset<Row> ks = summaryExt.g().join(summaryExt.h(), "reach")
                .selectExpr("reach", "FPR", "TPR", "(TPR - FPR) as D");
        Row maxGap = ks.sort(ks.col("D").desc()).first();
        StatisticD statistic = new StatisticD();
        statistic.setX((Double) maxGap.getAs("reach"));
        statistic.setD((Double) maxGap.getAs("D"));
        statistic.setFpr((Double) maxGap.getAs("FPR"));
        statistic.setTpr((Double) maxGap.getAs("TPR"));
        metrics.setD(statistic);
    }

    /** Sets areaUnderROC and areaUnderPR on {@code metrics}, evaluated on {@code predictions}. */
    private static void fillAreaMetrics(BinaryClassificationEvaluator evaluator, Dataset<Row> predictions, Metrics metrics) {
        evaluator.setMetricName("areaUnderROC");
        metrics.getIndicator().setAreaUnderROC(evaluator.evaluate(predictions));
        evaluator.setMetricName("areaUnderPR");
        metrics.getIndicator().setAreaUnderPR(evaluator.evaluate(predictions));
    }

    /**
     * Returns {@code transformer} itself if it is a {@code type}, otherwise the first stage of a
     * {@link PipelineModel} that is a {@code type}, or null if none matches.
     */
    private static <T> T findStage(Transformer transformer, Class<T> type) {
        if (type.isInstance(transformer)) {
            return type.cast(transformer);
        }
        if (transformer instanceof PipelineModel) {
            return Arrays.stream(((PipelineModel) transformer).stages())
                    .filter(type::isInstance)
                    .map(type::cast)
                    .findFirst()
                    .orElse(null);
        }
        return null;
    }

    /** Collects a two-column curve dataset (x in column 0, y in column 1) into curve points. */
    private static List<CurvePoint> toCurve(Dataset<? extends Row> curve) {
        return curve.collectAsList().stream()
                .map(row -> new CurvePoint(row.getDouble(0), row.getDouble(1)))
                .collect(Collectors.toList());
    }

    /** Persists {@code metrics} as JSON for the current flow via {@code InsightDAO.saveModelMetrics}. */
    private static void persist(Metrics metrics) {
        PersistService.invoke(METRICS_DAO, SAVE_METRICS_METHOD,
                new String[]{Long.class.getTypeName(), String.class.getTypeName()},
                new Object[]{PersistService.getFlowId(), JSON.toJSONString(metrics)});
    }
}

