package com.datastax.data.prepare.spark.dataset;

import com.alibaba.fastjson.JSONArray;
import com.alibaba.fastjson.JSONObject;
import com.datastax.insight.spec.Operator;
import com.datastax.insight.annonation.InsightComponent;
import com.datastax.insight.annonation.InsightComponentArg;
import com.datastax.data.prepare.spark.dataset.params.FilterSection;
import com.datastax.data.prepare.util.Consts;
import com.datastax.data.prepare.util.CustomException;
import com.datastax.data.prepare.util.SharedMethods;
import org.apache.spark.api.java.function.FilterFunction;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.sql.Date;
import java.sql.Timestamp;
import java.text.ParseException;
import java.text.SimpleDateFormat;
import java.util.ArrayList;
import java.util.List;
import java.util.regex.Pattern;


/**
 * Dataset component that keeps only the rows satisfying every configured filter condition.
 *
 * <p>Each condition names a column, an expression (compared against the {@code Consts}
 * expression constants in {@link #judge}) and an optional value. A value may combine several
 * operands with {@code ||} (any operand may match) or {@code &&} (all operands must match);
 * {@code &&} is checked first, so a value containing both is split on {@code &&} only.
 */
public class FilterOperator implements Operator {
    private static final Logger logger = LoggerFactory.getLogger(FilterOperator.class);

    /** Shape each range must have when several ranges are combined with {@code &&} or {@code ||}. */
    private static final Pattern RANGE_PATTERN =
            Pattern.compile("(\\[|\\()[^\\[\\(\\]\\)]+, ?[^\\[\\(\\]\\)]+(\\]|\\))");

    /**
     * Builds the filter conditions from the component's JSON parameters and applies them.
     *
     * <p>Every array element is read through the keys {@code selector} (column name),
     * {@code selectorValue} (expression) and {@code method} (value). Elements whose column is
     * missing or blank are skipped. The sequence {@code %5c} in a value is replaced by a backslash.
     *
     * @param data  dataset to filter; when {@code null}, {@code null} is returned
     * @param array condition descriptors; when {@code null} or empty, {@code data} is returned as is
     * @param <T>   element type of the dataset
     * @return the filtered dataset, or {@code data} itself when no usable condition was supplied
     */
    @InsightComponent(name = "Filter", description = "过滤数据集")
    public static <T> Dataset<T> filter(
            @InsightComponentArg(externalInput = true, name = "数据集", description = "数据集") Dataset<T> data,
            @InsightComponentArg(name = "参数", description = "Filter组件参数") JSONArray array) {
        // A null array is treated like an empty one instead of failing with a NullPointerException.
        if (array == null || array.isEmpty()) {
            logger.info("Filter组件的参数为空,返回原数据集");
            return data;
        }
        if (data == null) {
            logger.info("Filter组件参数中的数据集为空，返回null");
            return null;
        }

        List<FilterSection> filterSections = new ArrayList<>();
        for (int i = 0; i < array.size(); i++) {
            JSONObject object = array.getJSONObject(i);
            String column = trimToEmpty(object.getString("selector"));
            if (column.length() == 0) {
                logger.info("Filter组件参数中的column为空, 相应的expression为： " +
                        object.getString("selectorValue") + "value为: " +
                        object.getString("method") + ", 跳过该行条件");
                continue;
            }

            String rawExpression = object.getString("selectorValue");
            if (rawExpression == null) {
                // An absent expression is handled like an empty one: judge() recognises no
                // expression, so this condition rejects every row.
                logger.info("FilterOperator--column " + column + " has no expression (selectorValue); no row will match it");
            }

            FilterSection filterSection = new FilterSection();
            filterSection.setColumn(column);
            filterSection.setExpression(trimToEmpty(rawExpression));
            // The value stays null when absent: judge() accepts a null value for the
            // is-missing / is-not-missing expressions. "%5c" is the escaped form of a backslash.
            String rawValue = object.getString("method");
            filterSection.setValue(rawValue == null ? null : rawValue.trim().replace("%5c","\\"));
            filterSections.add(filterSection);
        }
        if (filterSections.isEmpty()) {
            logger.info("FilterOperator--参数为空，返回原数据集");
            return data;
        }

        return filter(data, filterSections);
    }

    /**
     * Applies the given conditions to {@code data}; a row is kept only if it passes all of them.
     *
     * <p>Conditions whose column is not part of the schema are logged and ignored.
     *
     * @param data           dataset to filter
     * @param filterSections conditions to apply
     * @param <T>            element type of the dataset
     * @return the filtered dataset (a {@code Dataset<Row>} cast back to {@code Dataset<T>})
     */
    protected static <T> Dataset<T> filter(Dataset<T> data, List<FilterSection> filterSections) {
        final StructField[] fields = data.schema().fields();
        // positions[i] is the schema index of the i-th condition's column, or -1 when unknown.
        final int[] positions = new int[filterSections.size()];
        for (int i = 0; i < positions.length; i++) {
            String column = filterSections.get(i).getColumn();
            positions[i] = indexOfField(fields, column);
            if (positions[i] == -1) {
                logger.info("FilterOperator--" + column + "列不存在");
            }
        }

        // Deliberately an anonymous class rather than a lambda: a lambda may fail here,
        // see the header comment of the SharedUDFs class.
        Dataset<Row> result = data.toDF().filter(new FilterFunction<Row>() {
            @Override
            public boolean call(Row r) throws Exception {
                for (int i = 0; i < positions.length; i++) {
                    int position = positions[i];
                    if (position != -1 && !judge(r.get(position), fields[position], filterSections.get(i))) {
                        return false;
                    }
                }
                return true;
            }
        });

        @SuppressWarnings("unchecked") // same cast the component has always performed on the filtered frame
        Dataset<T> typed = (Dataset<T>) result;
        return typed;
    }

    /**
     * Evaluates one condition against one cell.
     *
     * <p>When both the cell and the condition value are present, the expression selects the
     * comparison; the contains / start-with / end-with / matches family only applies to string
     * columns. When the cell is {@code null} or the value is null/empty, only the is-missing and
     * is-not-missing expressions can succeed. Anything else yields {@code false}.
     *
     * @param obj           cell value, possibly {@code null}
     * @param field         schema field of the cell's column
     * @param filterSection condition to evaluate
     * @return {@code true} if the cell satisfies the condition
     */
    protected static boolean judge(Object obj, StructField field, FilterSection filterSection) {
        String expression = filterSection.getExpression();
        String value = filterSection.getValue();
        boolean cellMissing = obj == null;
        boolean valueMissing = value == null || value.length() == 0;

        if (cellMissing || valueMissing) {
            if (Consts.IS_MISSING.equals(expression)) {
                return cellMissing;
            }
            if (Consts.IS_NOT_MISSING.equals(expression)) {
                return !cellMissing;
            }
            if (valueMissing) {
                logger.info("FilterOperator--" + field.name() + "列, expression为" + expression + "的value参数为空");
            }
            return false;
        }

        if (Consts.EQUALS.equals(expression)) {
            return testOperands(obj, filterSection, StringTest.EQUALS);
        }
        if (Consts.DOES_NOT_EQUALS.equals(expression)) {
            return !testOperands(obj, filterSection, StringTest.EQUALS);
        }
        if (Consts.IS_IN.equals(expression)) {
            return isIn(obj, field, filterSection);
        }
        if (Consts.IS_NOT_IN.equals(expression)) {
            return !isIn(obj, field, filterSection);
        }
        if (Consts.CONTAINS.equals(expression) && isStringColumn(field)) {
            return testOperands(obj, filterSection, StringTest.CONTAINS);
        }
        if (Consts.DOES_NOT_CONTAINS.equals(expression) && isStringColumn(field)) {
            return !testOperands(obj, filterSection, StringTest.CONTAINS);
        }
        if (Consts.START_WITH.equals(expression) && isStringColumn(field)) {
            return testOperands(obj, filterSection, StringTest.STARTS_WITH);
        }
        if (Consts.END_WITH.equals(expression) && isStringColumn(field)) {
            return testOperands(obj, filterSection, StringTest.ENDS_WITH);
        }
        if (Consts.MATCHES.equals(expression) && isStringColumn(field)) {
            return testOperands(obj, filterSection, StringTest.MATCHES);
        }
        return false;
    }

    /** String comparisons shared by the text expressions; a plain enum so no lambda is needed. */
    private enum StringTest {
        EQUALS, CONTAINS, STARTS_WITH, ENDS_WITH, MATCHES;

        /** Applies this comparison to the cell text and one operand. */
        boolean apply(String actual, String operand) {
            switch (this) {
                case EQUALS:
                    return actual.equals(operand);
                case CONTAINS:
                    return actual.contains(operand);
                case STARTS_WITH:
                    return actual.startsWith(operand);
                case ENDS_WITH:
                    return actual.endsWith(operand);
                case MATCHES:
                    return actual.matches(operand);
                default:
                    throw new IllegalStateException("Unknown string test: " + this);
            }
        }
    }

    /**
     * Runs {@code check} on the cell's string form against the condition value.
     *
     * <p>A value containing {@code &&} must match on every operand, a value containing
     * {@code ||} on at least one; in both cases each operand is trimmed and its first and last
     * character (the wrapping delimiters) are removed. A value with neither separator is
     * compared verbatim.
     */
    private static boolean testOperands(Object obj, FilterSection filterSection, StringTest check) {
        String actual = obj.toString();
        String value = filterSection.getValue();
        if (value.contains("&&")) {
            for (String part : value.split("&&")) {
                if (!check.apply(actual, unwrapOperand(part, filterSection))) {
                    return false;
                }
            }
            return true;
        }
        if (value.contains("||")) {
            for (String part : value.split("\\|\\|")) {
                if (check.apply(actual, unwrapOperand(part, filterSection))) {
                    return true;
                }
            }
            return false;
        }
        return check.apply(actual, value);
    }

    /**
     * Trims one operand of a combined value and strips its first and last character.
     *
     * @throws CustomException if the trimmed operand is too short to carry both delimiters
     */
    private static String unwrapOperand(String part, FilterSection filterSection) {
        String operand = part.trim();
        if (operand.length() < 2) {
            // Previously this surfaced as a bare StringIndexOutOfBoundsException.
            throw new CustomException("FilterOperator--column " + filterSection.getColumn()
                    + ", expression " + filterSection.getExpression()
                    + ": operand '" + operand + "' in value '" + filterSection.getValue()
                    + "' must be wrapped in delimiters");
        }
        return operand.substring(1, operand.length() - 1);
    }

    /**
     * Tests whether the cell lies inside the range(s) described by the condition value.
     *
     * <p>Ranges combined with {@code &&} must all contain the cell; ranges combined with
     * {@code ||} need only one hit. Each combined range has to match {@link #RANGE_PATTERN}.
     *
     * @throws CustomException if a combined range does not have the expected shape
     */
    private static boolean isIn(Object obj, StructField field, FilterSection filterSection) {
        String value = filterSection.getValue();
        if (value.contains("&&")) {
            for (String part : value.split("&&")) {
                String range = part.trim();
                if (!RANGE_PATTERN.matcher(range).matches()) {
                    throw malformedRange(filterSection);
                }
                if (!judgeRange(obj, range, field, filterSection)) {
                    return false;
                }
            }
            return true;
        }
        if (value.contains("||")) {
            for (String part : value.split("\\|\\|")) {
                String range = part.trim();
                if (!RANGE_PATTERN.matcher(range).matches()) {
                    throw malformedRange(filterSection);
                }
                if (judgeRange(obj, range, field, filterSection)) {
                    return true;
                }
            }
            return false;
        }
        return judgeRange(obj, value, field, filterSection);
    }

    /** Builds the exception raised for a combined range that does not match the expected shape. */
    private static CustomException malformedRange(FilterSection filterSection) {
        return new CustomException("Filter组件参数中的column为" + filterSection.getColumn() + ", 相应的expression为： " +
                filterSection.getExpression() + "value为: " +
                filterSection.getValue() + ", value不符合格式, 跳过该行条件");
    }

    /**
     * Tests the cell against a single range such as {@code [a, b]} or {@code (a, b)}.
     *
     * <p>A leading {@code [} makes the lower bound inclusive and a trailing {@code ]} makes the
     * upper bound inclusive; any other delimiter is exclusive. Numeric columns are compared as
     * doubles (reversed bounds are swapped). Date and timestamp columns expect bounds formatted
     * as {@code yyyy-MM-dd HH:mm:ss}; an unparsable bound is logged and yields {@code false}.
     * Other column types are logged as unsupported and yield {@code false}.
     *
     * @throws CustomException if the range lacks two non-empty bounds or a number cannot be parsed
     */
    private static boolean judgeRange(Object obj, String value, StructField field, FilterSection filterSection) {
        // Guard the shape first so a malformed single range reports a format error
        // instead of an index-out-of-bounds exception.
        String[] splits = value.length() < 2
                ? new String[0]
                : value.substring(1, value.length() - 1).split(",");
        if (splits.length < 2) {
            throw new CustomException("FilterOperator--" + field.name() + "列, expression为" + filterSection.getExpression() + "的value参数不符合格式");
        }
        boolean lowerInclusive = value.charAt(0) == '[';
        boolean upperInclusive = value.charAt(value.length() - 1) == ']';
        String s1 = splits[0].trim();
        String s2 = splits[1].trim();
        if (s1.length() == 0 || s2.length() == 0) {
            throw new CustomException("FilterOperator--" + field.name() + "列, expression为" + filterSection.getExpression() + "的value参数不符合格式");
        }

        if (SharedMethods.isNumericType(field)) {
            double d1;
            double d2;
            double d;
            try {
                d1 = Consts.NEGATIVE_INFINITY.equals(s1) ? Double.NEGATIVE_INFINITY : Double.parseDouble(s1);
                // NOTE(review): the upper bound is also compared with Consts.NEGATIVE_INFINITY;
                // presumably a positive-infinity token was intended - confirm against Consts.
                d2 = Consts.NEGATIVE_INFINITY.equals(s2) ? Double.POSITIVE_INFINITY : Double.parseDouble(s2);
                if (d1 > d2) {
                    logger.info("FilterOperator--is in的value中,范围" + (lowerInclusive ? "[" : "(") + d1 + ", " + d2 + (upperInclusive ? "]" : ")") +  ", " + d1 + "大于" + d2 + ",两者互换");
                    double t = d1;
                    d1 = d2;
                    d2 = t;
                }
                d = Double.parseDouble(obj.toString());
            } catch (NumberFormatException e) {
                // Log the cause here so it is not lost when the CustomException is raised.
                logger.error("FilterOperator--String转Double失败", e);
                throw new CustomException("FilterOperator--String转Double失败");
            }
            return (lowerInclusive ? d >= d1 : d > d1) && (upperInclusive ? d <= d2 : d < d2);
        }

        boolean dateColumn = DataTypes.DateType.equals(field.dataType());
        boolean timestampColumn = DataTypes.TimestampType.equals(field.dataType());
        if (!dateColumn && !timestampColumn) {
            logger.info(field.name() + "列的类型不支持 is in 操作, 类型为" + field.dataType().typeName());
            return false;
        }

        // "HH" is the 24-hour field; the former "hh" (12-hour) parsed 12:xx:xx as 00:xx:xx.
        // SimpleDateFormat is not thread-safe, so a fresh instance is used per call.
        SimpleDateFormat format = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss");
        long lowerMillis;
        long upperMillis;
        try {
            lowerMillis = format.parse(s1).getTime();
            upperMillis = format.parse(s2).getTime();
        } catch (ParseException e) {
            logger.error("FilterOperator--String转Date失败", e);
            return false;
        }

        int sign1;
        int sign2;
        if (dateColumn) {
            Date date = (Date) obj;
            sign1 = date.compareTo(new Date(lowerMillis));
            sign2 = date.compareTo(new Date(upperMillis));
        } else {
            Timestamp timestamp = (Timestamp) obj;
            sign1 = timestamp.compareTo(new Timestamp(lowerMillis));
            sign2 = timestamp.compareTo(new Timestamp(upperMillis));
        }
        return (lowerInclusive ? sign1 >= 0 : sign1 > 0) && (upperInclusive ? sign2 <= 0 : sign2 < 0);
    }

    /** Returns the index of the field named {@code column}, or -1 when the schema has no such field. */
    private static int indexOfField(StructField[] fields, String column) {
        for (int j = 0; j < fields.length; j++) {
            if (column.equals(fields[j].name())) {
                return j;
            }
        }
        return -1;
    }

    /** Tells whether the field's data type is the string type. */
    private static boolean isStringColumn(StructField field) {
        return DataTypes.StringType.equals(field.dataType());
    }

    /** Trims {@code s}, mapping {@code null} to the empty string. */
    private static String trimToEmpty(String s) {
        return s == null ? "" : s.trim();
    }

}
