package com.datastax.insight.ml.spark.ml.model;

import com.datastax.insight.core.entity.Cache;
import com.datastax.insight.core.entity.Model;
import com.datastax.insight.spec.DataSetOperator;
import com.datastax.insight.core.service.PersistService;
import com.datastax.util.lang.ReflectUtil;
import com.google.common.base.Strings;
import org.apache.spark.ml.PipelineModel;
import org.apache.spark.ml.PredictionModel;
import org.apache.spark.ml.Transformer;
import org.apache.spark.ml.attribute.Attribute;
import org.apache.spark.ml.attribute.AttributeGroup;
import org.apache.spark.ml.attribute.NumericAttribute;
import org.apache.spark.ml.feature.StringIndexerModel;
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.ml.linalg.Vectors;
import org.apache.spark.ml.param.shared.HasFeaturesCol;
import org.apache.spark.ml.param.shared.HasInputCol;
import org.apache.spark.ml.param.shared.HasOutputCol;
import org.apache.spark.ml.util.MLWritable;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

import java.io.IOException;
import java.nio.file.Paths;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.regex.Pattern;
import java.util.stream.Collectors;

/**
 * Model management: saving, loading and running predictions with persisted Spark ML models.
 */
public class ModelHandler implements DataSetOperator {

    /** DAO class that {@link PersistService} invokes for model and settings persistence. */
    private static final String INSIGHT_DAO = "com.datastax.insight.agent.dao.InsightDAO";

    /**
     * Matches a URI scheme prefix such as {@code hdfs://}, {@code s3a://} or {@code file://}.
     * At least two scheme characters are required so a Windows drive letter is never mistaken
     * for a scheme.
     */
    private static final Pattern URI_SCHEME = Pattern.compile("^[A-Za-z][A-Za-z0-9+.-]+://");

    /**
     * Loads a persisted model by id.
     *
     * <p>The model record (class name and storage path) is fetched through the DAO, then the
     * model is loaded by reflectively calling the static {@code load(String)} method of that
     * class.
     *
     * @param modelId id of the stored model record
     * @return the loaded model
     * @throws IllegalArgumentException if the DAO returns no model for {@code modelId}
     */
    public static MLWritable load(String modelId) {
        Model model = (Model) PersistService.invoke(INSIGHT_DAO,
                "getModel",
                new String[]{String.class.getTypeName()},
                new Object[]{modelId});

        if (model == null) {
            // Fail with a meaningful message instead of an NPE on model.getModelClass().
            throw new IllegalArgumentException("No model found for id: " + modelId);
        }

        return innerLoad(model.getModelClass(), model.getPath());
    }

    /**
     * Reflectively invokes {@code modelClass.load(modelPath)}.
     * Assumes {@code modelClass} declares a static {@code load(String)} returning an
     * {@link MLWritable} (as Spark ML model classes such as {@link PipelineModel} do).
     */
    private static MLWritable innerLoad(String modelClass, String modelPath) {
        return (MLWritable) ReflectUtil.invokeStaticMethod(modelClass, "load",
                new String[]{String.class.getTypeName()},
                new Object[]{modelPath});
    }

    /**
     * Saves a model under an explicit base path and registers it with the DAO.
     *
     * @param writable  model to persist
     * @param modelName model name, also used as the last path segment
     * @param path      base path the model is written under
     * @param replace   whether to overwrite existing output at the target path
     * @throws IOException if writing the model fails
     * @deprecated use {@link #save(MLWritable, String, boolean)}, which derives the store
     *     location from the current user's system settings
     */
    @Deprecated
    public static void save(MLWritable writable, String modelName, String path, boolean replace) throws IOException {
        String modelPath = joinUriPath(path, modelName);

        writeModel(writable, modelPath, replace);

        PersistService.invoke(INSIGHT_DAO,
                "saveModel",
                new String[]{Long.class.getTypeName(),
                        String.class.getTypeName(),
                        String.class.getTypeName(),
                        String.class.getTypeName()},
                new Object[]{PersistService.getFlowId(), modelName, writable.getClass().getName(), modelPath});
    }

    /**
     * Saves a model into the current user's model store and registers it with the DAO.
     *
     * <p>The target is {@code <store.path>/<store.model>/<modelName>}. Both settings are read
     * from the system settings of the user whose id is cached under {@code "userId"}.
     *
     * @param writable  model to persist
     * @param modelName model name, also used as the last path segment
     * @param replace   whether to overwrite existing output at the target path
     * @throws IOException           if writing the model fails
     * @throws IllegalStateException if the cached user id, the user's settings, or the
     *                               {@code store.path} / {@code store.model} setting is missing
     */
    public static void save(MLWritable writable, String modelName, boolean replace) throws IOException {
        Object cachedUserId = Cache.getCache("userId");
        if (cachedUserId == null) {
            throw new IllegalStateException("No userId in cache; cannot resolve the model store path");
        }

        Long userId = Long.parseLong(cachedUserId.toString());
        Map<String, String> settings = getSettings(userId);
        if (settings == null) {
            throw new IllegalStateException("No system settings found for user " + userId);
        }

        String storeHome = requireSetting(settings, "store.path");
        String modelHome = requireSetting(settings, "store.model");

        String modelPath = getStorePath(getStorePath(storeHome, modelHome), modelName);

        writeModel(writable, modelPath, replace);

        PersistService.invoke(INSIGHT_DAO,
                "saveModel",
                new String[]{Long.class.getTypeName(),
                        Long.class.getTypeName(),
                        Long.class.getTypeName(),
                        String.class.getTypeName(),
                        String.class.getTypeName(),
                        String.class.getTypeName()},
                new Object[]{PersistService.getFlowId(),
                        PersistService.getBatchId(),
                        PersistService.getFlowVersionId(),
                        modelName,
                        writable.getClass().getName(),
                        modelPath});
    }

    /**
     * Predicts one sample with the model stored under {@code modelId}.
     *
     * @param modelId id of the stored model
     * @param feature comma-separated numeric feature values, e.g. {@code "1.0,2.5,3"}
     * @return the prediction rendered as a string
     */
    public static String predict(String modelId, String feature) {
        MLWritable model = load(modelId);
        return predict(model, feature);
    }

    /**
     * Predicts several samples with the model stored under {@code modelId}.
     *
     * @param modelId  id of the stored model
     * @param features one comma-separated numeric feature string per sample
     * @return one prediction per sample, in input order
     */
    public static String[] predict(String modelId, String[] features) {
        MLWritable model = load(modelId);
        return predict(model, toVectors(features));
    }

    /**
     * Predicts one sample with a model loaded directly from {@code modelPath}.
     *
     * @param modelClass fully qualified class name whose static {@code load(String)} loads the model
     * @param modelPath  location the model was saved to
     * @param feature    comma-separated numeric feature values
     * @return the prediction rendered as a string
     */
    public static String predict(String modelClass, String modelPath, String feature) {
        MLWritable model = innerLoad(modelClass, modelPath);
        return predict(model, feature);
    }

    /**
     * Predicts several samples with a model loaded directly from {@code modelPath}.
     *
     * @param modelClass fully qualified class name whose static {@code load(String)} loads the model
     * @param modelPath  location the model was saved to
     * @param features   one comma-separated numeric feature string per sample
     * @return one prediction per sample, in input order
     */
    public static String[] predict(String modelClass, String modelPath, String[] features) {
        MLWritable model = innerLoad(modelClass, modelPath);
        return predict(model, toVectors(features));
    }

    /** Parses a single comma-separated feature string and predicts it with {@code model}. */
    private static String predict(MLWritable model, String feature) {
        return predict(model, Arrays.asList(toVector(feature)))[0];
    }

    /**
     * Scores feature vectors with either a bare {@link PredictionModel} or a {@link PipelineModel}.
     *
     * <p>For a pipeline, the stages are scanned from last to first to find:
     * <ul>
     *   <li>the prediction column,</li>
     *   <li>the column the raw feature vector must be supplied in,</li>
     *   <li>optionally, the labels of a {@link StringIndexerModel}, used to map numeric
     *       predictions back to strings.</li>
     * </ul>
     *
     * @throws UnsupportedOperationException if the model type is not supported or no prediction
     *                                       column could be determined for the pipeline
     */
    private static String[] predict(MLWritable model, List<Vector> features) {

        if (model instanceof PredictionModel) {
            return predictDirectly((PredictionModel) model, features);
        } else if (model instanceof PipelineModel) {

            PipelineModel pipelineModel = (PipelineModel) model;

            if (features.isEmpty()) {
                // Nothing to score; also guards features.get(0), which sizes the input schema.
                return new String[0];
            }

            Transformer[] stages = pipelineModel.stages();

            String predictionCol = null;
            String featuresCol = null;
            String inputCol = null;
            String[] labels = null;

            for (int i = stages.length - 1; i >= 0; i--) {

                Transformer transformer = stages[i];

                // Overwritten on every StringIndexerModel visited, so the one closest to the start
                // of the pipeline (among those seen before the early exit) wins.
                if (transformer instanceof StringIndexerModel) {
                    labels = ((StringIndexerModel) transformer).labels();
                }

                if (transformer instanceof PredictionModel && Strings.isNullOrEmpty(predictionCol)) {
                    predictionCol = ((PredictionModel) transformer).getPredictionCol();
                }

                if (transformer instanceof HasFeaturesCol && Strings.isNullOrEmpty(featuresCol)) {
                    featuresCol = ((HasFeaturesCol) transformer).getFeaturesCol();
                }

                // Walk the chain of single-input/single-output stages backwards. Start from the
                // stage producing the features column, then follow each stage whose output is the
                // currently traced input column.
                if (transformer instanceof HasOutputCol) {
                    String outputCol = ((HasOutputCol) transformer).getOutputCol();

                    if (transformer instanceof HasInputCol) {
                        if (Strings.isNullOrEmpty(inputCol)) {
                            if (outputCol.equals(featuresCol)) {
                                inputCol = ((HasInputCol) transformer).getInputCol();
                            }
                        } else {
                            if (outputCol.equals(inputCol)) {
                                inputCol = ((HasInputCol) transformer).getInputCol();
                            }
                        }
                    }
                }

                if (!Strings.isNullOrEmpty(featuresCol) && !Strings.isNullOrEmpty(inputCol) && labels != null) {
                    break;
                }
            }

            // TODO(andy): workaround kept for a demo because of a pipeline issue. When no input column
            // could be traced through the stages, bypass the pipeline and score with the last
            // PredictionModel stage directly (no label mapping is applied in this case).
            if (Strings.isNullOrEmpty(inputCol)) {
                for (int i = stages.length - 1; i >= 0; i--) {
                    if (stages[i] instanceof PredictionModel) {
                        return predictDirectly((PredictionModel) stages[i], features);
                    }
                }
            }

            Dataset<Row> scored = pipelineModel.transform(toFeatureFrame(inputCol, features));

            if (!Strings.isNullOrEmpty(predictionCol)) {
                return collectPredictions(scored, predictionCol, labels);
            }
        }

        throw new UnsupportedOperationException("predict method is no support for Class:" + model.getClass().getTypeName());
    }

    /** Scores each vector with {@code predictionModel.predict} and renders the doubles as strings. */
    @SuppressWarnings({"rawtypes", "unchecked"})
    private static String[] predictDirectly(PredictionModel predictionModel, List<Vector> features) {
        return features.stream()
                .map(f -> String.valueOf(predictionModel.predict(f)))
                .toArray(String[]::new);
    }

    /**
     * Builds a single-column DataFrame named {@code inputCol} holding the given vectors.
     * The column carries ML attribute metadata with numeric attributes named {@code f1..fN}.
     * N is taken from the size of the first vector; {@code features} must be non-empty.
     */
    private static Dataset<Row> toFeatureFrame(String inputCol, List<Vector> features) {
        SparkSession spark = SparkSession
                .builder()
                .getOrCreate();

        NumericAttribute defaultAttribute = NumericAttribute.defaultAttr();
        int featureSize = features.get(0).size();
        Attribute[] attributes = new Attribute[featureSize];

        for (int i = 0; i < featureSize; i++) {
            attributes[i] = defaultAttribute.withName("f" + (i + 1));
        }

        AttributeGroup group = new AttributeGroup(inputCol, attributes);
        StructType structType = new StructType(new StructField[]{group.toStructField()});

        List<Row> rows = features.stream().map(f -> RowFactory.create(f)).collect(Collectors.toList());
        return spark.createDataFrame(rows, structType);
    }

    /**
     * Collects {@code predictionCol} from the scored frame to the driver.
     * When {@code labels} is non-null, each numeric prediction is used as an index into it;
     * otherwise the double value is rendered as a string.
     */
    private static String[] collectPredictions(Dataset<Row> scored, String predictionCol, String[] labels) {
        return scored.select(predictionCol)
                .collectAsList().stream()
                .map(r -> {
                    double prediction = r.getDouble(r.fieldIndex(predictionCol));
                    return labels == null ? String.valueOf(prediction) : labels[(int) prediction];
                })
                .toArray(String[]::new);
    }

    /** Parses every comma-separated feature string into a dense vector, preserving order. */
    private static List<Vector> toVectors(String[] features) {
        return Arrays.stream(features).map(ModelHandler::toVector).collect(Collectors.toList());
    }

    /**
     * Parses one comma-separated feature string into a dense vector.
     *
     * @throws NumberFormatException if any value is not a valid double
     */
    private static Vector toVector(String feature) {
        double[] values = Arrays.stream(feature.split(",")).mapToDouble(Double::parseDouble).toArray();
        return Vectors.dense(values);
    }

    /**
     * Fetches a user's system settings through the DAO.
     *
     * @return the map returned by the DAO, cast unchecked to {@code Map<String, String>};
     *     may be {@code null} (presumably when the user has no settings; TODO confirm against the DAO)
     */
    @SuppressWarnings("unchecked")
    private static Map<String, String> getSettings(Long userId) {
        Object result = PersistService.invoke(INSIGHT_DAO,
                "getSystemSettings",
                new String[]{Long.class.getTypeName()},
                new Object[]{userId});

        return (Map<String, String>) result;
    }

    /**
     * Returns the setting stored under {@code key}.
     *
     * @throws IllegalStateException if the setting is absent
     */
    private static String requireSetting(Map<String, String> settings, String key) {
        String value = settings.get(key);
        if (value == null) {
            throw new IllegalStateException("Missing system setting: " + key);
        }
        return value;
    }

    /**
     * Joins a store home and a store path.
     *
     * <p>URI-style homes ({@code hdfs://}, {@code s3a://}, {@code file://}, ...) are joined as
     * plain strings, because {@link Paths} would mangle the {@code //} after the scheme. On Unix
     * the slashes are collapsed and the authority is corrupted; on Windows the {@code :} is
     * rejected. Plain local paths are joined with {@link Paths#get(String, String...)}.
     */
    private static String getStorePath(String storeHome, String storePath) {
        if (URI_SCHEME.matcher(storeHome).lookingAt()) {
            return joinUriPath(storeHome, storePath);
        }
        return Paths.get(storeHome, storePath).toString();
    }

    /**
     * Concatenates {@code base} and {@code child} with a {@code /} separator.
     * The separator is dropped when {@code base} already ends with {@code /} or {@code child}
     * starts with one. If both do, one leading {@code /} of {@code child} is removed.
     */
    private static String joinUriPath(String base, String child) {
        boolean baseEndsWithSlash = base.endsWith("/");
        boolean childStartsWithSlash = child.startsWith("/");

        if (baseEndsWithSlash && childStartsWithSlash) {
            return base + child.substring(1);
        }
        if (!baseEndsWithSlash && !childStartsWithSlash) {
            return base + "/" + child;
        }
        return base + child;
    }

    /**
     * Writes {@code writable} to {@code modelPath}, overwriting existing output when
     * {@code replace} is set.
     */
    private static void writeModel(MLWritable writable, String modelPath, boolean replace) throws IOException {
        if (replace) {
            writable.write().overwrite().save(modelPath);
        } else {
            writable.write().save(modelPath);
        }
    }
}
