package com.datastax.insight.ml.spark.ml.model;

import com.alibaba.fastjson.JSONArray;
import com.alibaba.fastjson.JSONObject;
import com.datastax.insight.core.service.PersistService;
import com.datastax.insight.spec.DataSetOperator;
import org.apache.parquet.Strings;
import org.apache.spark.ml.PipelineModel;
import org.apache.spark.ml.Transformer;
import org.apache.spark.ml.util.MLWritable;

import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Optional;


/**
 * Turns the textual dump produced by a Spark tree model's {@code toDebugString()} into a JSON
 * tree and stores it for the current flow.
 *
 * <p>JSON shape produced by {@code analyseAllTrees}:
 * <pre>
 * { "title": "&lt;first line of the dump&gt;", "num": &lt;number of trees&gt;,
 *   "trees": [ { "title": "Tree 0 (weight 1.0):", "value": "feature 552 &lt;= 0.0",
 *                "left": { ... }, "right": { ... } }, ... ] }
 * </pre>
 * Leaf nodes only carry a {@code value} such as {@code "Predict: 1.0"}. {@code left} is the
 * child listed directly under an {@code If} line; {@code right} is the child listed under the
 * matching {@code Else} line. The text of {@code Else} lines themselves is not stored.
 */
public class ModelExplanation implements DataSetOperator {

    /** Manual demo: parses a sample random-forest debug string and prints the JSON. */
    public static void main(String[] args) {

        String str =
                "RandomForestClassificationModel (uid=rfc_ab163b1bdcde) with 10 trees\n" +
                "  Tree 0 (weight 1.0):\n" +
                "    If (feature 552 <= 0.0)\n" +
                "     If (feature 456 <= 0.0)\n" +
                "      Predict: 0.0\n" +
                "     Else (feature 456 > 0.0)\n" +
                "      Predict: 1.0\n" +
                "    Else (feature 552 > 0.0)\n" +
                "     If (feature 121 in {2.0})\n" +
                "      Predict: 0.0\n" +
                "     Else (feature 121 not in {2.0})\n" +
                "      If (feature 544 <= 227.0)\n" +
                "       Predict: 1.0\n" +
                "      Else (feature 544 > 227.0)\n" +
                "       Predict: 0.0\n" +
                "  Tree 1 (weight 1.0):\n" +
                "    If (feature 463 <= 0.0)\n" +
                "     If (feature 631 <= 0.0)\n" +
                "      Predict: 0.0\n" +
                "     Else (feature 631 > 0.0)\n" +
                "      If (feature 544 <= 156.0)\n" +
                "       Predict: 1.0\n" +
                "      Else (feature 544 > 156.0)\n" +
                "       Predict: 0.0\n" +
                "    Else (feature 463 > 0.0)\n" +
                "     Predict: 0.0\n" +
                "  Tree 2 (weight 1.0):\n" +
                "    If (feature 385 <= 0.0)\n" +
                "     If (feature 350 <= 0.0)\n" +
                "      Predict: 1.0\n" +
                "     Else (feature 350 > 0.0)\n" +
                "      Predict: 0.0\n" +
                "    Else (feature 385 > 0.0)\n" +
                "     Predict: 1.0\n" +
                "  Tree 3 (weight 1.0):\n" +
                "    If (feature 328 <= 0.0)\n" +
                "     If (feature 261 <= 0.0)\n" +
                "      Predict: 0.0\n" +
                "     Else (feature 261 > 0.0)\n" +
                "      Predict: 1.0\n" +
                "    Else (feature 328 > 0.0)\n" +
                "     Predict: 1.0\n" +
                "  Tree 4 (weight 1.0):\n" +
                "    If (feature 429 <= 0.0)\n" +
                "     If (feature 358 <= 0.0)\n" +
                "      Predict: 0.0\n" +
                "     Else (feature 358 > 0.0)\n" +
                "      Predict: 1.0\n" +
                "    Else (feature 429 > 0.0)\n" +
                "     Predict: 1.0\n" +
                "  Tree 5 (weight 1.0):\n" +
                "    If (feature 462 <= 0.0)\n" +
                "     Predict: 1.0\n" +
                "    Else (feature 462 > 0.0)\n" +
                "     Predict: 0.0\n" +
                "  Tree 6 (weight 1.0):\n" +
                "    If (feature 512 <= 0.0)\n" +
                "     If (feature 413 <= 0.0)\n" +
                "      Predict: 0.0\n" +
                "     Else (feature 413 > 0.0)\n" +
                "      Predict: 1.0\n" +
                "    Else (feature 512 > 0.0)\n" +
                "     Predict: 1.0\n" +
                "  Tree 7 (weight 1.0):\n" +
                "    If (feature 512 <= 0.0)\n" +
                "     If (feature 288 <= 0.0)\n" +
                "      Predict: 0.0\n" +
                "     Else (feature 288 > 0.0)\n" +
                "      Predict: 1.0\n" +
                "    Else (feature 512 > 0.0)\n" +
                "     Predict: 1.0\n" +
                "  Tree 8 (weight 1.0):\n" +
                "    If (feature 462 <= 0.0)\n" +
                "     Predict: 1.0\n" +
                "    Else (feature 462 > 0.0)\n" +
                "     Predict: 0.0\n" +
                "  Tree 9 (weight 1.0):\n" +
                "    If (feature 385 <= 0.0)\n" +
                "     If (feature 578 <= 35.0)\n" +
                "      Predict: 0.0\n" +
                "     Else (feature 578 > 35.0)\n" +
                "      If (feature 605 <= 233.0)\n" +
                "       Predict: 0.0\n" +
                "      Else (feature 605 > 233.0)\n" +
                "       Predict: 1.0\n" +
                "    Else (feature 385 > 0.0)\n" +
                "     Predict: 1.0\n";

        JSONObject jsonObject = analyseAllTrees(str);
        System.out.println(jsonObject);
    }

    /**
     * Explains a fitted model by parsing its {@code toDebugString()} output into JSON.
     *
     * <p>The JSON is saved through {@code PersistService.invoke} (DAO
     * {@code com.datastax.insight.agent.dao.InsightDAO#saveModelDebugJson}, keyed by
     * {@code PersistService.getFlowId()}) before being returned.
     *
     * @param writable the model to explain; if it is a {@link PipelineModel}, the first stage that
     *                 exposes a public no-arg {@code toDebugString()} is explained instead
     * @return the JSON string, or {@code "invoke error"} if the model has no
     *         {@code toDebugString()}, the output cannot be parsed, or persisting fails
     */
    public static String explain(MLWritable writable){
        try {
            // The reflective call only needs an Object, so no cast to MLWritable is required
            // (the old cast could throw ClassCastException for non-writable stages).
            Object target = writable;
            if (writable instanceof PipelineModel) {
                Optional<Transformer> treeStage = Arrays.stream(((PipelineModel) writable).stages())
                        .filter(ModelExplanation::hasDebugString)
                        .findFirst();
                if (treeStage.isPresent()) {
                    target = treeStage.get();
                }
            }

            // Class.getMethod never returns null; it throws NoSuchMethodException instead.
            Method method = target.getClass().getMethod("toDebugString");
            String debugStr = method.invoke(target).toString();
            String debugJson = analyseAllTrees(debugStr).toJSONString();
            // Persist the parsed JSON for the current flow.
            PersistService.invoke("com.datastax.insight.agent.dao.InsightDAO",
                    "saveModelDebugJson",
                    new String[]{Long.class.getTypeName(), String.class.getTypeName()},
                    new Object[]{PersistService.getFlowId(), debugJson});
            return debugJson;
        } catch (NoSuchMethodException e) {
            System.out.println("not a tree model");
        } catch (Exception e) {
            // Parsing / reflection / persistence failures: keep the cause visible in the log.
            System.out.println("explain failed: " + e);
        }
        System.out.println("invoke error");
        return "invoke error";
    }

    /** Returns true when {@code candidate}'s class has a public no-arg {@code toDebugString()}. */
    private static boolean hasDebugString(Object candidate) {
        try {
            candidate.getClass().getMethod("toDebugString");
            return true;
        } catch (NoSuchMethodException e) {
            return false;
        }
    }

    /**
     * Parses a full debug dump into the JSON structure documented on the class.
     *
     * <p>The first line is used as the model title. Lines whose trimmed text starts with
     * {@code "Tree"} open a new tree, and their trimmed text becomes that tree's title. Single
     * decision-tree dumps have no such headers; their one tree then uses the model title. Blank
     * lines are ignored. A {@code null} or empty dump yields {@code num == 0}.
     *
     * @throws IllegalArgumentException if an {@code If} line has no matching {@code Else} line
     */
    private static JSONObject analyseAllTrees(String debugStr){
        String[] lines = debugStr == null ? new String[0] : debugStr.split("\n");
        String modelTitle = lines.length > 0 ? lines[0].trim() : "";

        JSONObject result = new JSONObject();
        JSONArray jsonArray = new JSONArray();
        result.put("title", modelTitle);

        List<String> treeLines = new ArrayList<>();
        String treeTitle = null;
        for (int i = 1; i < lines.length; i++) {
            String trimmed = lines[i].trim();
            if (trimmed.isEmpty()) {
                continue;
            }
            if (trimmed.startsWith("Tree")) {
                // A new "Tree N" header closes the tree collected so far, if any.
                if (!treeLines.isEmpty()) {
                    jsonArray.add(analyseTree(treeTitle != null ? treeTitle : modelTitle, treeLines));
                    treeLines.clear();
                }
                treeTitle = trimmed;
            } else {
                // Keep the raw line: its leading spaces encode the nesting depth.
                treeLines.add(lines[i]);
            }
        }
        // Flush the last (or only) tree.
        if (!treeLines.isEmpty()) {
            jsonArray.add(analyseTree(treeTitle != null ? treeTitle : modelTitle, treeLines));
        }

        result.put("num", jsonArray.size());
        result.put("trees", jsonArray);
        return result;
    }

    /** Builds the JSON for one tree whose node lines (raw, indented) are {@code lines}. */
    private static JSONObject analyseTree(String title, List<String> lines) {
        JSONObject tree = new JSONObject();
        tree.put("title", title);
        analyseNode(tree, lines, 0, lines.size());
        return tree;
    }

    /**
     * Fills {@code node} from the subtree starting at {@code index}, bounded by {@code end}
     * (exclusive). A {@code Predict} line becomes a leaf. Any other line is treated as a
     * condition whose {@code Else} is the next line with the same indentation.
     */
    private static void analyseNode(JSONObject node, List<String> list, int index, int end) {
        String line = list.get(index);
        if (line.trim().startsWith("Predict")) {
            // Leaf, e.g. a depth-0 tree: there is no Else clause to look for.
            node.put("value", cleanNodeText(line));
            return;
        }
        int indent = getFirstCharPosFromString(line);
        int elsePos = getElseClauseIndex(list, indent, index + 1, end);
        if (elsePos < 0) {
            throw new IllegalArgumentException("No matching Else clause for: " + line.trim());
        }
        analyseDebugString(node, list, index, elsePos, end);
    }

    /**
     * Fills an internal node: {@code lIndex} is its {@code If} line, {@code rIndex} the matching
     * {@code Else} line, and {@code end} (exclusive) the end of its right subtree.
     */
    private static void analyseDebugString(JSONObject jsonObject, List<String> list, int lIndex, int rIndex, int end){
        assert lIndex < rIndex && rIndex < end ;
        jsonObject.put("value", cleanNodeText(list.get(lIndex)));
        // Left child sits right after the If line; its subtree ends at the Else line.
        putChild(jsonObject, "left", list, lIndex + 1, rIndex);
        // Right child sits right after the Else line; its subtree ends at this node's end.
        putChild(jsonObject, "right", list, rIndex + 1, end);
    }

    /**
     * Adds child {@code key} to {@code parent} when the line at {@code index} is a
     * {@code Predict} leaf or an {@code If} subtree. Other lines, and an empty branch
     * ({@code index >= end}), are skipped.
     */
    private static void putChild(JSONObject parent, String key, List<String> list, int index, int end) {
        if (index >= end) {
            return;
        }
        String trimmed = list.get(index).trim();
        if (trimmed.startsWith("Predict") || trimmed.startsWith("If")) {
            JSONObject child = new JSONObject();
            parent.put(key, child);
            analyseNode(child, list, index, end);
        }
    }

    /** Strips "If" and parentheses, e.g. {@code "If (feature 3 <= 0.5)"} -> {@code "feature 3 <= 0.5"}. */
    private static String cleanNodeText(String line) {
        return line.trim().replace("If", "").replace("(", "").replace(")", "").trim();
    }

    /**
     * Returns the index of the first non-space character of {@code line}. Returns
     * {@code line.length()} for an all-space line and -1 for a null or empty line.
     */
    private static int getFirstCharPosFromString(String line){
        int index = -1;
        if(!Strings.isNullOrEmpty(line)){
            int i=0;
            for (;i<line.length();i++){
                if(!(line.charAt(i) == ' '))
                    break;
            }
            index = i;
        }
        return index;
    }

    /**
     * Returns the first index in {@code [start, end)} whose line is indented exactly
     * {@code condition} spaces (the {@code Else} matching an {@code If}), or -1 if none.
     */
    private static int getElseClauseIndex(List<String> list,int condition,int start,int end){
        int pos = -1;
        for(int i=start;i<end;i++){
            if(getFirstCharPosFromString(list.get(i)) == condition){
                pos = i;
                break;
            }
        }
        return pos;
    }

}


