package com.datastax.insight.ml.spark.mllib.evaluator;

import com.alibaba.fastjson.JSON;
import com.datastax.insight.core.entity.Metrics;
import com.datastax.insight.spec.RDDOperator;
import com.datastax.insight.core.service.PersistService;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.mllib.evaluation.MultilabelMetrics;
import scala.Tuple2;

/**
 * Operator that evaluates multilabel predictions with Spark MLlib's
 * {@link MultilabelMetrics} and copies the resulting figures into a
 * {@link Metrics} object.
 *
 * Created by huangping on 17-1-22.
 */
public class MultilabelMetricsWrapper implements RDDOperator {

    /**
     * Computes the multilabel evaluation figures for the given RDD, stores them on the
     * indicator of a new {@link Metrics} instance, hands a JSON rendering of that instance
     * to {@code PersistService.invoke}, and returns it.
     *
     * @param scoreAndLabels RDD of double-array pairs passed straight to the
     *                       {@link MultilabelMetrics} constructor (Spark documents each
     *                       pair as predictions first, labels second)
     * @return the populated {@link Metrics} instance (the same object that was serialized)
     */
    public Metrics evaluation(JavaRDD<Tuple2<double[], double[]>> scoreAndLabels) {

        final MultilabelMetrics evaluator = new MultilabelMetrics(scoreAndLabels.rdd());
        final Metrics result = new Metrics();

        // Document-averaged figures.
        result.getIndicator().setRecall(evaluator.recall());
        result.getIndicator().setPrecision(evaluator.precision());
        result.getIndicator().setF1(evaluator.f1Measure());
        result.getIndicator().setAccuracy(evaluator.accuracy());

        // Micro-averaged figures.
        result.getIndicator().setMicroRecall(evaluator.microRecall());
        result.getIndicator().setMicroPrecision(evaluator.microPrecision());
        result.getIndicator().setMicroF1Measure(evaluator.microF1Measure());

        // Hamming loss and subset accuracy.
        result.getIndicator().setHammingLoss(evaluator.hammingLoss());
        result.getIndicator().setSubsetAccuracy(evaluator.subsetAccuracy());

        // Target of the persistence call, named by class and method strings.
        // NOTE(review): presumably resolved reflectively by PersistService - confirm.
        final String daoClassName = "com.datastax.insight.agent.dao.InsightDAO";
        final String daoMethodName = "saveModelMetrics";

        // Parameter type names first, then the matching argument values (flow id, JSON payload).
        final String[] parameterTypeNames = {
                Long.class.getTypeName(),
                String.class.getTypeName()
        };
        final Object[] argumentValues = {
                PersistService.getFlowId(),
                JSON.toJSONString(result)
        };

        PersistService.invoke(daoClassName, daoMethodName, parameterTypeNames, argumentValues);

        return result;
    }
}
