/*
 * Decompiled with CFR 0.152.
 */
package com.datastax.insight.ml.spark.mllib.evaluator;

import com.alibaba.fastjson.JSON;
import com.datastax.insight.core.entity.CurvePoint;
import com.datastax.insight.core.entity.Metrics;
import com.datastax.insight.core.service.PersistService;
import com.datastax.insight.spec.RDDOperator;
import java.io.Serializable;
import java.util.List;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.mllib.classification.ClassificationModel;
import org.apache.spark.mllib.classification.LogisticRegressionModel;
import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.tree.model.DecisionTreeModel;
import org.apache.spark.mllib.tree.model.TreeEnsembleModel;
import org.apache.spark.mllib.util.Saveable;
import scala.Tuple2;

/**
 * Evaluates a binary classifier on labelled data with Spark's {@link BinaryClassificationMetrics}
 * and persists the resulting {@link Metrics} for the current flow.
 */
public class BinaryClassificationMetricsWrapper
implements RDDOperator {

    /**
     * Scores every point of {@code data} with {@code model}, computes area under the PR and ROC
     * curves plus both curves, persists the metrics as JSON, and returns them.
     *
     * <p>The metrics are handed to {@code InsightDAO.saveModelMetrics} through
     * {@link PersistService}, keyed by {@link PersistService#getFlowId()}.
     *
     * <p>For a {@link LogisticRegressionModel} the scores are raw predictions from a threshold-free
     * copy of the model. The caller's model is left untouched (earlier versions cleared its
     * threshold in place).
     *
     * @param model trained model; must be a LogisticRegressionModel, ClassificationModel,
     *     DecisionTreeModel or TreeEnsembleModel
     * @param data labelled points to evaluate against
     * @return the computed metrics (also persisted as a side effect)
     * @throws IllegalArgumentException if the model type is not one of the supported types
     */
    public Metrics evaluation(Saveable model, JavaRDD<LabeledPoint> data) {
        Function<LabeledPoint, Double> scorer = scorerFor(model);
        JavaRDD<Tuple2<Object, Object>> scoreAndLabels =
                data.map(point -> new Tuple2<Object, Object>(scorer.call(point), (Object) point.label()));

        BinaryClassificationMetrics evaluated = new BinaryClassificationMetrics(scoreAndLabels.rdd());
        Metrics metrics = new Metrics();
        metrics.getIndicator().setAreaUnderPR(evaluated.areaUnderPR());
        metrics.getIndicator().setAreaUnderROC(evaluated.areaUnderROC());
        metrics.setRoc(toCurve(evaluated.roc().toJavaRDD()));
        metrics.setPr(toCurve(evaluated.pr().toJavaRDD()));
        // Every result has been collected to the driver; release the intermediate RDDs the
        // metrics object keeps cached instead of leaving them on the executors.
        evaluated.unpersist();

        PersistService.invoke("com.datastax.insight.agent.dao.InsightDAO", "saveModelMetrics",
                new String[]{Long.class.getTypeName(), String.class.getTypeName()},
                new Object[]{PersistService.getFlowId(), JSON.toJSONString(metrics)});
        return metrics;
    }

    /**
     * Builds the per-point scoring function for a supported model type.
     * Branch order matters: LogisticRegressionModel is checked before the more general
     * ClassificationModel.
     *
     * @throws IllegalArgumentException if the model type is not supported
     */
    private static Function<LabeledPoint, Double> scorerFor(Saveable model) {
        if (model instanceof LogisticRegressionModel) {
            LogisticRegressionModel trained = (LogisticRegressionModel) model;
            // Clear the threshold on a copy so predict() yields raw scores for the curves, without
            // changing how the caller's model predicts afterwards.
            LogisticRegressionModel unthresholded = new LogisticRegressionModel(
                    trained.weights(), trained.intercept(), trained.numFeatures(), trained.numClasses());
            unthresholded.clearThreshold();
            return point -> unthresholded.predict(point.features());
        }
        if (model instanceof ClassificationModel) {
            // NOTE(review): thresholded classifiers (e.g. SVMModel) presumably return hard 0/1
            // labels here, which collapses the ROC/PR curves to one operating point — confirm.
            ClassificationModel classifier = (ClassificationModel) model;
            return point -> classifier.predict(point.features());
        }
        if (model instanceof DecisionTreeModel) {
            DecisionTreeModel tree = (DecisionTreeModel) model;
            return point -> tree.predict(point.features());
        }
        if (model instanceof TreeEnsembleModel) {
            TreeEnsembleModel ensemble = (TreeEnsembleModel) model;
            return point -> ensemble.predict(point.features());
        }
        String message = "[" + model.getClass().getTypeName() + "] is not supported, currently supports: LogisticRegressionModel, ClassificationModel, DecisionTreeModel, TreeEnsembleModel";
        throw new IllegalArgumentException(message);
    }

    /** Collects a Spark (x, y) curve RDD to the driver as a list of {@link CurvePoint}s. */
    private static List<CurvePoint> toCurve(JavaRDD<Tuple2<Object, Object>> curve) {
        return curve.map(xy -> new CurvePoint(asDouble(xy._1()), asDouble(xy._2()))).collect();
    }

    /** Unboxes a curve coordinate exposed to Java as {@code Object} (a boxed numeric value). */
    private static double asDouble(Object value) {
        return ((Number) value).doubleValue();
    }
}

