/*
 * Decompiled with CFR 0.152.
 */
package com.datastax.insight.ml.spark.mllib.evaluator;

import com.alibaba.fastjson.JSON;
import com.datastax.insight.core.entity.Metrics;
import com.datastax.insight.core.service.PersistService;
import com.datastax.insight.spec.RDDOperator;
import java.io.Serializable;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.mllib.classification.LogisticRegressionModel;
import org.apache.spark.mllib.evaluation.RegressionMetrics;
import org.apache.spark.mllib.regression.IsotonicRegressionModel;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.regression.RegressionModel;
import org.apache.spark.mllib.tree.model.DecisionTreeModel;
import org.apache.spark.mllib.tree.model.TreeEnsembleModel;
import org.apache.spark.mllib.util.Saveable;
import scala.Tuple2;
import scala.Tuple3;

/**
 * Evaluates Spark MLlib models with {@link RegressionMetrics} and hands the resulting
 * {@link Metrics} to {@link PersistService}.
 *
 * <p>Each public entry point turns the input data into an RDD of (prediction, label) pairs,
 * computes MSE, RMSE, MAE, R2 and explained variance from it, and passes the metrics
 * (serialized as JSON) to {@code PersistService.invoke} before returning them.
 */
public class RegressionMetricsWrapper
implements RDDOperator {
    /**
     * Evaluates an isotonic regression model.
     *
     * @param model the model to evaluate; must be an {@link IsotonicRegressionModel}
     * @param data  triples whose second element is fed to the model's {@code predict} and whose
     *              first element is used as the observed label; the third element is not read
     *              here (presumably the weight, as in MLlib's isotonic input format - confirm)
     * @return the computed metrics
     * @throws IllegalArgumentException if {@code model} is {@code null} or is not an
     *                                  {@link IsotonicRegressionModel}
     */
    public Metrics evaluation4Isotonic(Saveable model, JavaRDD<Tuple3<Double, Double, Double>> data) {
        if (!(model instanceof IsotonicRegressionModel)) {
            String message = "[" + RegressionMetricsWrapper.typeNameOf(model) + "] is not supported, currently supports: IsotonicRegressionModel";
            throw new IllegalArgumentException(message);
        }
        final IsotonicRegressionModel isotonicModel = (IsotonicRegressionModel)model;
        // The lambda references only the local model, never a member of this wrapper.
        JavaRDD<Tuple2<Object, Object>> scoreAndLabels = data.map(
                row -> new Tuple2<Object, Object>(isotonicModel.predict(row._2().doubleValue()), row._1()));
        return this.evaluation(scoreAndLabels);
    }

    /**
     * Evaluates a model that predicts from the feature vector of a {@link LabeledPoint}.
     *
     * <p>Supported model types are {@link LogisticRegressionModel}, {@link RegressionModel},
     * {@link DecisionTreeModel} and {@link TreeEnsembleModel}. Isotonic regression models are
     * handled by {@link #evaluation4Isotonic(Saveable, JavaRDD)} because they take a different
     * input shape.
     *
     * @param model the model to evaluate
     * @param data  labeled points; each point's features are passed to the model's
     *              {@code predict} and its label is used as the observed value
     * @return the computed metrics
     * @throws IllegalArgumentException if {@code model} is {@code null} or of an unsupported type
     */
    public Metrics evaluation(Saveable model, JavaRDD<LabeledPoint> data) {
        // The model types share no common predict(Vector) interface visible here, so each one
        // needs its own typed lambda.
        final JavaRDD<Tuple2<Object, Object>> scoreAndLabels;
        if (model instanceof LogisticRegressionModel) {
            final LogisticRegressionModel realModel = (LogisticRegressionModel)model;
            scoreAndLabels = data.map(
                    point -> new Tuple2<Object, Object>(realModel.predict(point.features()), point.label()));
        } else if (model instanceof RegressionModel) {
            final RegressionModel realModel = (RegressionModel)model;
            scoreAndLabels = data.map(
                    point -> new Tuple2<Object, Object>(realModel.predict(point.features()), point.label()));
        } else if (model instanceof DecisionTreeModel) {
            final DecisionTreeModel realModel = (DecisionTreeModel)model;
            scoreAndLabels = data.map(
                    point -> new Tuple2<Object, Object>(realModel.predict(point.features()), point.label()));
        } else if (model instanceof TreeEnsembleModel) {
            final TreeEnsembleModel realModel = (TreeEnsembleModel)model;
            scoreAndLabels = data.map(
                    point -> new Tuple2<Object, Object>(realModel.predict(point.features()), point.label()));
        } else {
            // List only the types this method really accepts.
            String message = "[" + RegressionMetricsWrapper.typeNameOf(model) + "] is not supported, currently supports: LogisticRegressionModel, RegressionModel, DecisionTreeModel, TreeEnsembleModel (use evaluation4Isotonic for IsotonicRegressionModel)";
            throw new IllegalArgumentException(message);
        }
        return this.evaluation(scoreAndLabels);
    }

    /**
     * Computes the regression metrics for the given (prediction, label) pairs, stores them in a
     * new {@link Metrics} object and passes that object, serialized as JSON together with the
     * current flow id, to {@code PersistService.invoke} targeting
     * {@code InsightDAO.saveModelMetrics}.
     *
     * @param scoreAndLabels pairs of (predicted value, observed label)
     * @return the populated metrics
     */
    private Metrics evaluation(JavaRDD<Tuple2<Object, Object>> scoreAndLabels) {
        Metrics metrics = new Metrics();
        RegressionMetrics regressionMetrics = new RegressionMetrics(scoreAndLabels.rdd());
        metrics.getIndicator().setMse(regressionMetrics.meanSquaredError());
        metrics.getIndicator().setRmse(regressionMetrics.rootMeanSquaredError());
        metrics.getIndicator().setMae(regressionMetrics.meanAbsoluteError());
        metrics.getIndicator().setR2(regressionMetrics.r2());
        metrics.getIndicator().setExplainedVariance(regressionMetrics.explainedVariance());
        // NOTE(review): presumably a reflective call into the DAO named by the first argument;
        // the parameter type names must line up with the argument array - confirm in PersistService.
        PersistService.invoke("com.datastax.insight.agent.dao.InsightDAO", "saveModelMetrics", new String[]{Long.class.getTypeName(), String.class.getTypeName()}, new Object[]{PersistService.getFlowId(), JSON.toJSONString(metrics)});
        return metrics;
    }

    /**
     * Returns the type name of {@code model} for use in error messages, or {@code "null"} when
     * no model was supplied, so that building the message cannot itself fail.
     */
    private static String typeNameOf(Saveable model) {
        return model == null ? "null" : model.getClass().getTypeName();
    }
}

