package com.datastax.insight.ml.spark.ml.evaluator;

import com.alibaba.fastjson.JSON;
import com.datastax.insight.core.entity.ConfusionMatrix;
import com.datastax.insight.core.entity.CurvePoint;
import com.datastax.insight.core.entity.Metrics;
import com.datastax.insight.core.entity.StatisticD;
import com.datastax.insight.core.service.PersistService;
import com.datastax.insight.spec.DataSetOperator;
import com.google.common.base.Strings;
import java.util.Arrays;
import java.util.List;
import java.util.Optional;
import java.util.stream.Collectors;
import org.apache.spark.ml.PipelineModel;
import org.apache.spark.ml.PredictionModel;
import org.apache.spark.ml.Transformer;
import org.apache.spark.ml.classification.BinaryLogisticRegressionSummary;
import org.apache.spark.ml.classification.LogisticRegressionModel;
import org.apache.spark.ml.classification.LogisticRegressionTrainingSummary;
import org.apache.spark.ml.classification.a;
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator;
import org.apache.spark.sql.Column;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;

/* loaded from: input_file:com/datastax/insight/ml/spark/ml/evaluator/BinaryClassificationEvaluatorWrapper.class */
/**
 * Static helpers that evaluate binary classifiers with Spark ML's
 * {@link BinaryClassificationEvaluator}, collect the results into a {@link Metrics} object and
 * persist that object through {@link PersistService}.
 */
public class BinaryClassificationEvaluatorWrapper implements DataSetOperator {
    private static final String METRIC_AREA_UNDER_ROC = "areaUnderROC";
    private static final String METRIC_AREA_UNDER_PR = "areaUnderPR";

    /** Threshold used when the caller passes none and the model is not a logistic regression. */
    private static final double DEFAULT_THRESHOLD = 0.5d;

    /**
     * Creates an evaluator, overriding its column names only when a non-empty name is given.
     *
     * @param str label column name; {@code null} or empty keeps the evaluator's default
     * @param str2 raw-prediction column name; {@code null} or empty keeps the evaluator's default
     * @return the configured evaluator
     */
    public static BinaryClassificationEvaluator getOperator(String str, String str2) {
        BinaryClassificationEvaluator evaluator = new BinaryClassificationEvaluator();
        if (!Strings.isNullOrEmpty(str)) {
            evaluator.setLabelCol(str);
        }
        if (!Strings.isNullOrEmpty(str2)) {
            evaluator.setRawPredictionCol(str2);
        }
        return evaluator;
    }

    /**
     * Runs the given evaluator on a dataset.
     *
     * @param binaryClassificationEvaluator evaluator to run
     * @param dataset dataset to evaluate
     * @return the value returned by the evaluator for its currently configured metric
     */
    public static double evaluate(BinaryClassificationEvaluator binaryClassificationEvaluator, Dataset<Row> dataset) {
        return binaryClassificationEvaluator.evaluate(dataset);
    }

    /**
     * Transforms {@code dataset} with {@code transformer}, computes the area under the ROC and PR
     * curves on the result, and persists the metrics.
     *
     * <p>When the transformer is (or, for a {@link PipelineModel}, contains) a
     * {@link LogisticRegressionModel} that has a summary, the ROC and PR curve points are taken
     * from that summary and the area under ROC is replaced by the summary's value.
     *
     * @param transformer fitted model or pipeline used to produce predictions
     * @param dataset data to transform and evaluate
     * @param str label column name; {@code null} or empty keeps the evaluator's default
     * @param str2 raw-prediction column name; {@code null} or empty keeps the evaluator's default
     * @return the metrics that were computed and persisted
     */
    public static Metrics evaluate(Transformer transformer, Dataset<Row> dataset, String str, String str2) {
        Dataset<Row> predictions = transformer.transform(dataset);
        Metrics metrics = new Metrics();
        fillAreaMetrics(metrics, getOperator(str, str2), predictions);

        // A model without a summary (hasSummary() == false) would make summary() throw, so the
        // summary-based curves are simply skipped in that case.
        findModel(transformer, LogisticRegressionModel.class)
                .filter(LogisticRegressionModel::hasSummary)
                .ifPresent(model -> {
                    BinaryLogisticRegressionSummary summary = model.summary();
                    // NOTE(review): this overwrites the value computed on `dataset` above with the
                    // summary's value - confirm that is intended.
                    metrics.getIndicator().setAreaUnderROC(Double.valueOf(summary.areaUnderROC()));
                    metrics.setRoc(toCurvePoints(summary.roc()));
                    metrics.setPr(toCurvePoints(summary.pr()));
                });

        saveMetrics(metrics);
        return metrics;
    }

    /**
     * Transforms {@code dataset} with {@code transformer} and computes area under ROC/PR, a
     * confusion matrix, precision, recall and F1, plus summary-based curves for logistic
     * regression models; the metrics are then persisted.
     *
     * <p>Side effect: when the resolved model is a {@link LogisticRegressionModel}, its threshold
     * is set to the effective threshold before transforming.
     *
     * @param transformer a {@link PredictionModel}, or a {@link PipelineModel} containing one
     * @param dataset data to transform and evaluate
     * @param d decision threshold; when {@code null} the logistic regression model's own threshold
     *     is used, or {@code 0.5} for any other model
     * @param str label column name for the evaluator; {@code null} or empty keeps its default
     * @param str2 raw-prediction column name for the evaluator; {@code null} or empty keeps its default
     * @return the metrics that were computed and persisted
     * @throws IllegalArgumentException if no {@link PredictionModel} can be resolved from
     *     {@code transformer}
     */
    public static Metrics evaluate(Transformer transformer, Dataset<Row> dataset, Double d, String str, String str2) {
        PredictionModel<?, ?> model = findModel(transformer, PredictionModel.class)
                .orElseThrow(() -> new IllegalArgumentException(
                        "No PredictionModel found in transformer: " + transformer));
        // The confusion-matrix filters use the model's own column names, not str/str2.
        String labelCol = model.getLabelCol();
        String predictionCol = model.getPredictionCol();

        LogisticRegressionModel lrModel =
                model instanceof LogisticRegressionModel ? (LogisticRegressionModel) model : null;
        Double threshold = d;
        if (threshold == null) {
            threshold = Double.valueOf(lrModel != null ? lrModel.getThreshold() : DEFAULT_THRESHOLD);
        }
        if (lrModel != null) {
            lrModel.setThreshold(threshold.doubleValue());
        }

        Dataset<Row> predictions = transformer.transform(dataset);
        Metrics metrics = new Metrics();
        fillAreaMetrics(metrics, getOperator(str, str2), predictions);

        // NOTE(review): both the label and the prediction column are compared against the
        // threshold; this presumably assumes 0/1 values and a threshold in (0, 1] - confirm.
        long tp = countMatching(predictions, labelCol, ">=", predictionCol, ">=", threshold);
        long tn = countMatching(predictions, labelCol, "<", predictionCol, "<", threshold);
        long fp = countMatching(predictions, labelCol, "<", predictionCol, ">=", threshold);
        long fn = countMatching(predictions, labelCol, ">=", predictionCol, "<", threshold);

        ConfusionMatrix confusionMatrix = new ConfusionMatrix();
        confusionMatrix.setTp(Long.valueOf(tp));
        confusionMatrix.setTn(Long.valueOf(tn));
        confusionMatrix.setFp(Long.valueOf(fp));
        confusionMatrix.setFn(Long.valueOf(fn));
        confusionMatrix.setThreshold(threshold);
        metrics.setConfusionMatrix(confusionMatrix);

        // Ratios are computed in floating point; an empty denominator yields 0.0.
        metrics.getIndicator().setPrecision(Double.valueOf(ratio(tp, tp + fp)));
        metrics.getIndicator().setRecall(Double.valueOf(ratio(tp, tp + fn)));
        metrics.getIndicator().setF1(Double.valueOf(ratio(2 * tp, 2 * tp + fp + fn)));

        addSummaryCurves(model, metrics);
        saveMetrics(metrics);
        return metrics;
    }

    /**
     * Resolves a model of the requested type: the transformer itself when it is an instance of
     * {@code type}, otherwise the first matching stage of a {@link PipelineModel}.
     */
    private static <T> Optional<T> findModel(Transformer transformer, Class<T> type) {
        if (type.isInstance(transformer)) {
            return Optional.of(type.cast(transformer));
        }
        if (transformer instanceof PipelineModel) {
            return Arrays.stream(((PipelineModel) transformer).stages())
                    .filter(type::isInstance)
                    .map(type::cast)
                    .findFirst();
        }
        return Optional.empty();
    }

    /** Evaluates area under ROC and area under PR on {@code predictions} and stores both. */
    private static void fillAreaMetrics(Metrics metrics, BinaryClassificationEvaluator evaluator, Dataset<Row> predictions) {
        evaluator.setMetricName(METRIC_AREA_UNDER_ROC);
        metrics.getIndicator().setAreaUnderROC(Double.valueOf(evaluator.evaluate(predictions)));
        evaluator.setMetricName(METRIC_AREA_UNDER_PR);
        metrics.getIndicator().setAreaUnderPR(Double.valueOf(evaluator.evaluate(predictions)));
    }

    /** Counts rows where both columns satisfy their comparison against {@code threshold}. */
    private static long countMatching(Dataset<Row> predictions, String labelCol, String labelOp,
            String predictionCol, String predictionOp, Double threshold) {
        String condition = labelCol + labelOp + threshold + " and " + predictionCol + predictionOp + threshold;
        return predictions.filter(condition).count();
    }

    /** Returns {@code numerator / denominator} as a double, or 0.0 when the denominator is zero. */
    private static double ratio(long numerator, long denominator) {
        return denominator == 0 ? 0.0d : (double) numerator / denominator;
    }

    /** Collects a two-column curve dataset into points built from columns 0 and 1 of each row. */
    private static List<CurvePoint> toCurvePoints(Dataset<Row> curve) {
        return curve.collectAsList().stream()
                .map(row -> new CurvePoint(Double.valueOf(row.getDouble(0)), Double.valueOf(row.getDouble(1))))
                .collect(Collectors.toList());
    }

    /** Reads a column of {@code row} by name as a primitive double. */
    private static double doubleField(Row row, String name) {
        return ((Double) row.getAs(name)).doubleValue();
    }

    /** Persists the metrics, serialized as JSON, for the current flow id. */
    private static void saveMetrics(Metrics metrics) {
        PersistService.invoke(
                "com.datastax.insight.agent.dao.InsightDAO",
                "saveModelMetrics",
                new String[]{Long.class.getTypeName(), String.class.getTypeName()},
                new Object[]{PersistService.getFlowId(), JSON.toJSONString(metrics)});
    }

    /**
     * For a logistic regression model that has a summary, fills the ROC, PR, lift, gain and KS
     * curves plus the row with the largest {@code TPR - FPR} from the summary's predictions.
     * Any other model, or one without a summary, leaves {@code metrics} untouched.
     */
    private static void addSummaryCurves(PredictionModel<?, ?> predictionModel, Metrics metrics) {
        if (!(predictionModel instanceof LogisticRegressionModel)) {
            return;
        }
        LogisticRegressionModel lrModel = (LogisticRegressionModel) predictionModel;
        if (!lrModel.hasSummary()) {
            // summary() throws when no summary is attached; skip the curves instead.
            return;
        }
        LogisticRegressionTrainingSummary summary = lrModel.summary();
        a curves = new a(summary.predictions(), summary.probabilityCol(), summary.labelCol(), summary.featuresCol());

        // NOTE(review): replaces the area under ROC computed on the evaluated dataset with the
        // value derived from the summary's predictions - confirm that is intended.
        metrics.getIndicator().setAreaUnderROC(Double.valueOf(curves.areaUnderROC()));
        metrics.setRoc(toCurvePoints(curves.roc()));
        metrics.setPr(toCurvePoints(curves.pr()));
        metrics.setLift(toCurvePoints(curves.f()));
        metrics.setGain(toCurvePoints(curves.e()));

        Dataset<Row> ksMinus = curves.g();
        Dataset<Row> ksPlus = curves.h();
        metrics.setKsMinus(toCurvePoints(ksMinus));
        metrics.setKsPlus(toCurvePoints(ksPlus));

        // Join both KS curves on "reach" and keep the row where TPR - FPR is largest.
        Dataset<Row> distances = ksMinus.join(ksPlus, "reach")
                .selectExpr("reach", "FPR", "TPR", "(TPR - FPR) as D");
        Row best = distances.sort(distances.col("D").desc()).first();

        StatisticD statisticD = new StatisticD();
        statisticD.setX(doubleField(best, "reach"));
        statisticD.setD(doubleField(best, "D"));
        statisticD.setFpr(doubleField(best, "FPR"));
        statisticD.setTpr(doubleField(best, "TPR"));
        metrics.setD(statisticD);
    }
}
