package com.datastax.insight.ml.spark.mllib.evaluator;

import com.alibaba.fastjson.JSON;
import com.datastax.insight.core.entity.Metrics;
import com.datastax.insight.core.service.PersistService;
import com.datastax.insight.spec.RDDOperator;
import java.util.Objects;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.mllib.classification.ClassificationModel;
import org.apache.spark.mllib.evaluation.MulticlassMetrics;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.tree.model.DecisionTreeModel;
import org.apache.spark.mllib.tree.model.TreeEnsembleModel;
import org.apache.spark.mllib.util.Saveable;
import scala.Tuple2;

/**
 * Evaluates a trained MLlib classifier with Spark's {@link MulticlassMetrics} and saves the
 * resulting {@link Metrics} for the current flow through {@link PersistService}.
 */
public class MulticlassMetricsWrapper implements RDDOperator {

    /**
     * Predicts every point in {@code data} with {@code model}, computes multiclass evaluation
     * indicators from the (prediction, label) pairs, and persists them as JSON via
     * {@code com.datastax.insight.agent.dao.InsightDAO#saveModelMetrics} for the current flow id.
     *
     * @param model a trained {@link ClassificationModel}, {@link DecisionTreeModel} or
     *              {@link TreeEnsembleModel}
     * @param data  labeled points to evaluate the model against
     * @return the computed metrics (the same object that was serialized and persisted)
     * @throws NullPointerException     if {@code model} or {@code data} is {@code null}
     * @throws IllegalArgumentException if {@code model} is not one of the supported types
     */
    public Metrics evaluation(Saveable model, JavaRDD<LabeledPoint> data) {
        Objects.requireNonNull(model, "model must not be null");
        Objects.requireNonNull(data, "data must not be null");

        JavaRDD<Tuple2<Object, Object>> predictionAndLabels = toPredictionAndLabels(model, data);

        // Each metric accessor below may trigger its own Spark job over the input RDD; cache the
        // predictions so the model is applied to the data only once.
        predictionAndLabels.cache();

        Metrics metrics = new Metrics();
        try {
            MulticlassMetrics multiclassMetrics = new MulticlassMetrics(predictionAndLabels.rdd());

            // The no-argument precision()/recall()/fMeasure() are deprecated since Spark 2.0
            // (where they are defined as accuracy) and removed in Spark 3.0, so use accuracy()
            // for all four overall values.
            double accuracy = multiclassMetrics.accuracy();
            metrics.getIndicator().setPrecision(accuracy);
            metrics.getIndicator().setRecall(accuracy);
            metrics.getIndicator().setfMeasure(accuracy);
            metrics.getIndicator().setAccuracy(accuracy);

            metrics.getIndicator().setWeightedPrecision(multiclassMetrics.weightedPrecision());
            metrics.getIndicator().setWeightedRecall(multiclassMetrics.weightedRecall());
            metrics.getIndicator().setWeightedFMeasure(multiclassMetrics.weightedFMeasure());
            metrics.getIndicator().setWeightedTruePositiveRate(multiclassMetrics.weightedTruePositiveRate());
            metrics.getIndicator().setWeightedFalsePositiveRate(multiclassMetrics.weightedFalsePositiveRate());
        } finally {
            // All metric values have been copied into plain fields, so the cache can be released.
            predictionAndLabels.unpersist();
        }

        PersistService.invoke("com.datastax.insight.agent.dao.InsightDAO",
                "saveModelMetrics",
                new String[]{Long.class.getTypeName(), String.class.getTypeName()},
                new Object[]{PersistService.getFlowId(), JSON.toJSONString(metrics)});

        return metrics;
    }

    /**
     * Maps each labeled point to a {@code (prediction, label)} pair using the given model.
     * Static so the Spark closures capture only the (serializable) model, never {@code this}.
     *
     * @throws IllegalArgumentException if {@code model} is not a supported model type
     */
    private static JavaRDD<Tuple2<Object, Object>> toPredictionAndLabels(Saveable model,
                                                                         JavaRDD<LabeledPoint> data) {
        if (model instanceof ClassificationModel) {
            ClassificationModel classifier = (ClassificationModel) model;
            return data.map(point -> new Tuple2<>(classifier.predict(point.features()), point.label()));
        }
        if (model instanceof DecisionTreeModel) {
            DecisionTreeModel tree = (DecisionTreeModel) model;
            return data.map(point -> new Tuple2<>(tree.predict(point.features()), point.label()));
        }
        if (model instanceof TreeEnsembleModel) {
            TreeEnsembleModel ensemble = (TreeEnsembleModel) model;
            return data.map(point -> new Tuple2<>(ensemble.predict(point.features()), point.label()));
        }
        String message = "[" + model.getClass().getTypeName() + "] is not supported, currently supports: ClassificationModel, DecisionTreeModel, TreeEnsembleModel";
        throw new IllegalArgumentException(message);
    }
}
