/*
 * Decompiled with CFR 0.152.
 */
package com.datastax.insight.ml.spark.ml.model;

import com.datastax.insight.core.entity.Cache;
import com.datastax.insight.core.entity.Model;
import com.datastax.insight.core.service.PersistService;
import com.datastax.insight.spec.DataSetOperator;
import com.datastax.util.lang.ReflectUtil;
import com.google.common.base.Strings;
import java.io.IOException;
import java.nio.file.Paths;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import org.apache.spark.ml.PipelineModel;
import org.apache.spark.ml.PredictionModel;
import org.apache.spark.ml.Transformer;
import org.apache.spark.ml.attribute.Attribute;
import org.apache.spark.ml.attribute.AttributeGroup;
import org.apache.spark.ml.attribute.NumericAttribute;
import org.apache.spark.ml.feature.StringIndexerModel;
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.ml.linalg.Vectors;
import org.apache.spark.ml.param.shared.HasFeaturesCol;
import org.apache.spark.ml.param.shared.HasInputCol;
import org.apache.spark.ml.param.shared.HasOutputCol;
import org.apache.spark.ml.util.MLWritable;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

/**
 * Loads, persists and applies Spark ML models whose metadata is stored through the Insight DAO.
 *
 * <p>Model records are fetched with {@code PersistService.invoke} against
 * {@code com.datastax.insight.agent.dao.InsightDAO}. Models are re-instantiated by reflectively
 * calling the model class's static {@code load(String)} method. Feature vectors passed to the
 * {@code predict} methods are encoded as comma-separated numbers, e.g. {@code "1.0,2.5,3"}.
 */
public class ModelHandler implements DataSetOperator {
    /** Fully-qualified name of the DAO that stores model records and system settings. */
    private static final String INSIGHT_DAO = "com.datastax.insight.agent.dao.InsightDAO";

    /**
     * Loads the model registered under {@code modelId}.
     *
     * @param modelId id of the model record in the Insight DAO
     * @return the model loaded from the record's path using the record's model class
     * @throws IllegalArgumentException if the DAO returns no record for {@code modelId}
     */
    public static MLWritable load(String modelId) {
        Model model = (Model) PersistService.invoke(INSIGHT_DAO, "getModel",
                new String[]{String.class.getTypeName()}, new Object[]{modelId});
        if (model == null) {
            throw new IllegalArgumentException("No model found for id: " + modelId);
        }
        return loadModel(model.getModelClass(), model.getPath());
    }

    /**
     * Saves {@code writable} to {@code path}/{@code modelName} and registers it for the current flow.
     *
     * @param writable  the model to persist
     * @param modelName name of the model; also the last path segment
     * @param path      base directory or URL the model is written under
     * @param replace   whether to overwrite an existing model at the target path
     * @throws IOException if Spark fails to write the model
     * @deprecated use {@link #save(MLWritable, String, boolean)}, which resolves the storage
     *     location from the user's system settings and also records batch and flow-version ids
     */
    @Deprecated
    public static void save(MLWritable writable, String modelName, String path, boolean replace) throws IOException {
        String modelPath = joinWithSlash(path, modelName);
        write(writable, modelPath, replace);
        PersistService.invoke(INSIGHT_DAO, "saveModel",
                new String[]{Long.class.getTypeName(), String.class.getTypeName(), String.class.getTypeName(), String.class.getTypeName()},
                new Object[]{PersistService.getFlowId(), modelName, writable.getClass().getName(), modelPath});
    }

    /**
     * Saves {@code writable} under the current user's model store and registers it.
     *
     * <p>The target path is {@code store.path} / {@code store.model} / {@code modelName}, with the
     * two settings read from the system settings of the user held in {@code Cache} under
     * {@code "userId"}.
     *
     * @param writable  the model to persist
     * @param modelName name of the model; also the last path segment
     * @param replace   whether to overwrite an existing model at the target path
     * @throws IOException           if Spark fails to write the model
     * @throws IllegalStateException if no user id is cached or a required setting is missing
     * @throws NumberFormatException if the cached user id is not a valid long
     */
    public static void save(MLWritable writable, String modelName, boolean replace) throws IOException {
        Object cachedUserId = Cache.getCache("userId");
        if (cachedUserId == null) {
            throw new IllegalStateException("No userId available in Cache; cannot resolve the model store");
        }
        Long userId = Long.valueOf(cachedUserId.toString());
        Map<String, String> settings = systemSettings(userId);
        String storeHome = requireSetting(settings, "store.path");
        String modelHome = requireSetting(settings, "store.model");
        String modelPath = resolvePath(resolvePath(storeHome, modelHome), modelName);
        write(writable, modelPath, replace);
        PersistService.invoke(INSIGHT_DAO, "saveModel",
                new String[]{Long.class.getTypeName(), Long.class.getTypeName(), Long.class.getTypeName(), String.class.getTypeName(), String.class.getTypeName(), String.class.getTypeName()},
                new Object[]{PersistService.getFlowId(), PersistService.getBatchId(), PersistService.getFlowVersionId(), modelName, writable.getClass().getName(), modelPath});
    }

    /**
     * Scores one comma-separated feature vector with the model registered under {@code modelId}.
     *
     * @return the prediction, rendered as a number or, for pipelines with a
     *     {@link StringIndexerModel}, as the indexed label
     * @throws NumberFormatException         if a feature is not a number
     * @throws UnsupportedOperationException if the model type cannot be used for prediction
     */
    public static String predict(String modelId, String feature) {
        MLWritable model = load(modelId);
        return predictOne(model, feature);
    }

    /**
     * Scores several comma-separated feature vectors with the model registered under
     * {@code modelId}; results are returned in input order.
     *
     * @throws NumberFormatException         if a feature is not a number
     * @throws UnsupportedOperationException if the model type cannot be used for prediction
     */
    public static String[] predict(String modelId, String[] features) {
        MLWritable model = load(modelId);
        return predictAll(model, parseVectors(features));
    }

    /**
     * Scores one comma-separated feature vector with the model of class {@code modelClass}
     * stored at {@code modelPath}.
     *
     * @throws NumberFormatException         if a feature is not a number
     * @throws UnsupportedOperationException if the model type cannot be used for prediction
     */
    public static String predict(String modelClass, String modelPath, String feature) {
        MLWritable model = loadModel(modelClass, modelPath);
        return predictOne(model, feature);
    }

    /**
     * Scores several comma-separated feature vectors with the model of class {@code modelClass}
     * stored at {@code modelPath}; results are returned in input order.
     *
     * @throws NumberFormatException         if a feature is not a number
     * @throws UnsupportedOperationException if the model type cannot be used for prediction
     */
    public static String[] predict(String modelClass, String modelPath, String[] features) {
        MLWritable model = loadModel(modelClass, modelPath);
        return predictAll(model, parseVectors(features));
    }

    /** Reflectively calls {@code modelClass.load(modelPath)}. */
    private static MLWritable loadModel(String modelClass, String modelPath) {
        return (MLWritable) ReflectUtil.invokeStaticMethod(modelClass, "load",
                new String[]{String.class.getTypeName()}, modelPath);
    }

    /** Writes {@code writable} to {@code modelPath}, overwriting existing output when {@code replace} is set. */
    private static void write(MLWritable writable, String modelPath, boolean replace) throws IOException {
        if (replace) {
            writable.write().overwrite().save(modelPath);
        } else {
            writable.write().save(modelPath);
        }
    }

    /** Parses a comma-separated list of numbers into a dense vector. */
    private static Vector parseVector(String csv) {
        double[] values = Arrays.stream(csv.split(",")).mapToDouble(Double::parseDouble).toArray();
        return Vectors.dense(values);
    }

    /** Parses each comma-separated entry of {@code features} into a dense vector, keeping order. */
    private static List<Vector> parseVectors(String[] features) {
        return Arrays.stream(features).map(ModelHandler::parseVector).collect(Collectors.toList());
    }

    /** Scores a single comma-separated feature vector and returns its prediction. */
    private static String predictOne(MLWritable model, String feature) {
        return predictAll(model, Arrays.asList(parseVector(feature)))[0];
    }

    /**
     * Scores {@code vectors} with {@code model}; only {@link PredictionModel} and
     * {@link PipelineModel} instances are supported.
     */
    private static String[] predictAll(MLWritable model, List<Vector> vectors) {
        if (model instanceof PredictionModel) {
            return predictEach((PredictionModel<?, ?>) model, vectors);
        }
        if (model instanceof PipelineModel) {
            return predictWithPipeline((PipelineModel) model, vectors);
        }
        throw unsupported(model);
    }

    /** Applies {@code predictionModel.predict} to every vector and renders each result with {@code String.valueOf}. */
    @SuppressWarnings({"rawtypes", "unchecked"})
    private static String[] predictEach(PredictionModel predictionModel, List<Vector> vectors) {
        return vectors.stream()
                .map(vector -> String.valueOf(predictionModel.predict(vector)))
                .toArray(String[]::new);
    }

    /**
     * Scores {@code vectors} with a fitted pipeline.
     *
     * <p>Stages are scanned from last to first to find the prediction column and the features
     * column. The scan also traces the column that feeds the features column back through stages
     * that have both an input and an output column, and picks up the labels of a
     * {@link StringIndexerModel}. If no input column can be traced, the last {@link PredictionModel}
     * stage is applied directly to the vectors. Otherwise the vectors are wrapped in a one-column
     * DataFrame named after the traced input column and run through the whole pipeline.
     */
    private static String[] predictWithPipeline(PipelineModel pipelineModel, List<Vector> vectors) {
        Transformer[] stages = pipelineModel.stages();
        String predictionCol = null;
        String featuresCol = null;
        String inputCol = null;
        String[] labels = null;
        for (int i = stages.length - 1; i >= 0; --i) {
            Transformer stage = stages[i];
            if (stage instanceof StringIndexerModel) {
                // Each indexer met while scanning backwards overwrites the previous one, so the
                // labels kept belong to the earliest indexer visited before the loop stops.
                // NOTE(review): assumes that indexer encoded the column the predictions refer to.
                labels = ((StringIndexerModel) stage).labels();
            }
            if (stage instanceof PredictionModel && Strings.isNullOrEmpty(predictionCol)) {
                predictionCol = ((PredictionModel<?, ?>) stage).getPredictionCol();
            }
            if (stage instanceof HasFeaturesCol && Strings.isNullOrEmpty(featuresCol)) {
                featuresCol = ((HasFeaturesCol) stage).getFeaturesCol();
            }
            if (stage instanceof HasOutputCol && stage instanceof HasInputCol) {
                String outputCol = ((HasOutputCol) stage).getOutputCol();
                // Follow the column chain backwards. First look for the stage producing the
                // features column, then for the stage producing the column traced so far.
                String target = Strings.isNullOrEmpty(inputCol) ? featuresCol : inputCol;
                if (outputCol.equals(target)) {
                    inputCol = ((HasInputCol) stage).getInputCol();
                }
            }
            if (!Strings.isNullOrEmpty(featuresCol) && !Strings.isNullOrEmpty(inputCol) && labels != null) {
                break;
            }
        }
        if (Strings.isNullOrEmpty(inputCol)) {
            for (int i = stages.length - 1; i >= 0; --i) {
                if (stages[i] instanceof PredictionModel) {
                    return predictEach((PredictionModel<?, ?>) stages[i], vectors);
                }
            }
            // Without a traceable input column the vectors cannot be fed into the pipeline.
            throw unsupported(pipelineModel);
        }
        if (Strings.isNullOrEmpty(predictionCol)) {
            // Nothing to read predictions from; fail before running the pipeline.
            throw unsupported(pipelineModel);
        }
        if (vectors.isEmpty()) {
            return new String[0];
        }
        Dataset<?> result = pipelineModel.transform(toFeatureFrame(inputCol, vectors));
        double[] predictions = result.select(predictionCol).collectAsList().stream()
                .mapToDouble(row -> row.getDouble(0))
                .toArray();
        String[] output = new String[predictions.length];
        for (int i = 0; i < predictions.length; ++i) {
            output[i] = labels == null ? String.valueOf(predictions[i]) : labels[(int) predictions[i]];
        }
        return output;
    }

    /**
     * Builds a one-column DataFrame named {@code inputCol} holding {@code vectors}.
     *
     * <p>The column carries ML attribute metadata of numeric attributes {@code f1..fN}, where N is
     * the size of the first vector. {@code vectors} must be non-empty.
     */
    private static Dataset<?> toFeatureFrame(String inputCol, List<Vector> vectors) {
        int featureSize = vectors.get(0).size();
        NumericAttribute template = NumericAttribute.defaultAttr();
        Attribute[] attributes = new Attribute[featureSize];
        for (int i = 0; i < featureSize; ++i) {
            attributes[i] = template.withName("f" + (i + 1));
        }
        AttributeGroup group = new AttributeGroup(inputCol, attributes);
        StructType schema = new StructType(new StructField[]{group.toStructField()});
        SparkSession spark = SparkSession.builder().getOrCreate();
        return spark.createDataFrame(
                vectors.stream().map(vector -> RowFactory.create(new Object[]{vector})).collect(Collectors.toList()),
                schema);
    }

    /** Builds the error raised for models this handler cannot predict with. */
    private static UnsupportedOperationException unsupported(MLWritable model) {
        return new UnsupportedOperationException("predict method is no support for Class:" + model.getClass().getTypeName());
    }

    /** Fetches the system settings of {@code userId} from the Insight DAO. */
    @SuppressWarnings("unchecked")
    private static Map<String, String> systemSettings(Long userId) {
        Object result = PersistService.invoke(INSIGHT_DAO, "getSystemSettings",
                new String[]{Long.class.getTypeName()}, new Object[]{userId});
        return (Map<String, String>) result;
    }

    /**
     * Returns the setting stored under {@code key}.
     *
     * @throws IllegalStateException if the settings map or the key is missing
     */
    private static String requireSetting(Map<String, String> settings, String key) {
        String value = settings == null ? null : settings.get(key);
        if (value == null) {
            throw new IllegalStateException("System setting '" + key + "' is not configured");
        }
        return value;
    }

    /**
     * Appends {@code storePath} to {@code storeHome}.
     *
     * <p>{@code hdfs://} URLs (matched case-insensitively) are joined with exactly one slash
     * between the parts; anything else is joined as a local file-system path.
     */
    private static String resolvePath(String storeHome, String storePath) {
        String hdfsScheme = "hdfs://";
        // regionMatches(ignoreCase) avoids the default-locale dependence of toLowerCase().
        if (storeHome.regionMatches(true, 0, hdfsScheme, 0, hdfsScheme.length())) {
            return joinWithSlash(storeHome, storePath);
        }
        return Paths.get(storeHome, storePath).toString();
    }

    /**
     * Joins two path fragments with exactly one {@code '/'} at the seam.
     *
     * <p>A duplicated slash at the seam is dropped and a missing one is inserted.
     */
    private static String joinWithSlash(String base, String child) {
        boolean baseEndsWithSlash = base.endsWith("/");
        boolean childStartsWithSlash = child.startsWith("/");
        if (baseEndsWithSlash && childStartsWithSlash) {
            return base + child.substring(1);
        }
        if (!baseEndsWithSlash && !childStartsWithSlash) {
            return base + "/" + child;
        }
        return base + child;
    }
}

