package com.datastax.insight.ml.spark.ml.evaluator;

import com.alibaba.fastjson.JSON;
import com.datastax.insight.core.entity.Metrics;
import com.datastax.insight.core.service.PersistService;
import com.datastax.insight.spec.DataSetOperator;
import org.apache.spark.ml.Transformer;
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;

/* loaded from: input_file:com/datastax/insight/ml/spark/ml/evaluator/MulticlassClassificationEvaluatorWrapper.class */
/**
 * Static helpers around Spark ML's {@link MulticlassClassificationEvaluator}.
 *
 * <p>Provides a factory for a configured evaluator, a single-metric evaluation shortcut,
 * and a multi-metric evaluation of a fitted {@link Transformer}. The multi-metric evaluation
 * also hands the collected {@link Metrics} to {@link PersistService}.
 */
public class MulticlassClassificationEvaluatorWrapper implements DataSetOperator {

    /**
     * Creates a multiclass evaluator, optionally overriding its label and prediction columns.
     *
     * <p>A column is only set when the supplied name is non-null and non-empty. Otherwise the
     * evaluator's own default for that column is left unchanged.
     *
     * @param labelCol      name of the label column, or {@code null}/empty to keep the default
     * @param predictionCol name of the prediction column, or {@code null}/empty to keep the default
     * @return a newly created, configured evaluator
     */
    public static MulticlassClassificationEvaluator getOperator(String labelCol, String predictionCol) {
        MulticlassClassificationEvaluator evaluator = new MulticlassClassificationEvaluator();
        if (hasText(labelCol)) {
            evaluator.setLabelCol(labelCol);
        }
        if (hasText(predictionCol)) {
            evaluator.setPredictionCol(predictionCol);
        }
        return evaluator;
    }

    /**
     * Evaluates {@code dataset} with the evaluator's currently configured metric.
     *
     * @param multiclassClassificationEvaluator evaluator to run
     * @param dataset                           dataset containing the label and prediction columns
     * @return the metric value reported by the evaluator
     */
    public static double evaluate(MulticlassClassificationEvaluator multiclassClassificationEvaluator, Dataset<Row> dataset) {
        return multiclassClassificationEvaluator.evaluate(dataset);
    }

    /**
     * Applies {@code transformer} to {@code dataset} and computes several multiclass metrics
     * on the result.
     *
     * <p>The metrics computed are F1, weighted precision, weighted recall and accuracy. They are
     * stored in a new {@link Metrics} object, which is then handed to {@link PersistService}
     * together with the current flow id.
     *
     * @param transformer   fitted model used to produce predictions
     * @param dataset       input data to transform and score
     * @param labelCol      name of the label column, or {@code null}/empty to keep the default
     * @param predictionCol name of the prediction column, or {@code null}/empty to keep the default
     * @return the populated metrics
     */
    public static Metrics evaluate(Transformer transformer, Dataset<Row> dataset, String labelCol, String predictionCol) {
        // Reuse the factory so column handling stays identical to getOperator().
        MulticlassClassificationEvaluator evaluator = getOperator(labelCol, predictionCol);
        // NOTE(review): Datasets are lazy, so each evaluate() call below presumably re-runs the
        // transformer over the input. Consider persisting `predictions` if the model is
        // expensive; confirm against cluster memory budgets first.
        Dataset<Row> predictions = transformer.transform(dataset);

        Metrics metrics = new Metrics();
        metrics.getIndicator().setF1(score(evaluator, predictions, "f1"));
        metrics.getIndicator().setWeightedPrecision(score(evaluator, predictions, "weightedPrecision"));
        metrics.getIndicator().setWeightedRecall(score(evaluator, predictions, "weightedRecall"));
        metrics.getIndicator().setAccuracy(score(evaluator, predictions, "accuracy"));

        // Hand the JSON-serialized metrics and the current flow id to PersistService, targeting
        // InsightDAO.saveModelMetrics(Long, String).
        PersistService.invoke(
                "com.datastax.insight.agent.dao.InsightDAO",
                "saveModelMetrics",
                new String[]{Long.class.getTypeName(), String.class.getTypeName()},
                new Object[]{PersistService.getFlowId(), JSON.toJSONString(metrics)});
        return metrics;
    }

    /**
     * Switches {@code evaluator} to {@code metricName} and evaluates {@code predictions} with it.
     * This mutates the evaluator's metric name.
     */
    private static Double score(MulticlassClassificationEvaluator evaluator, Dataset<Row> predictions, String metricName) {
        evaluator.setMetricName(metricName);
        return Double.valueOf(evaluator.evaluate(predictions));
    }

    /** Returns {@code true} when {@code value} is non-null and non-empty. */
    private static boolean hasText(String value) {
        return value != null && !value.isEmpty();
    }
}
