package com.datastax.data.prepare.spark.dataset;

import com.alibaba.fastjson.JSONArray;
import com.alibaba.fastjson.JSONObject;
import com.datastax.insight.core.driver.SparkContextBuilder;
import com.datastax.insight.spec.Operator;
import com.datastax.insight.annonation.InsightComponent;
import com.datastax.insight.annonation.InsightComponentArg;
import com.datastax.data.prepare.spark.dataset.params.FilterSection;
import com.datastax.data.prepare.spark.dataset.params.ReplaceAttribute;
import com.datastax.data.prepare.util.Consts;
import com.datastax.data.prepare.util.CustomException;
import com.datastax.data.prepare.util.SharedMethods;
import org.apache.parquet.Strings;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class ReplaceOperator implements Operator {
    private static final Logger logger = LoggerFactory.getLogger(ReplaceOperator.class);

    @InsightComponent(name = "Replace", description = "替换数据集中的值")
    public static <T> Dataset<T> replace(
            @InsightComponentArg(externalInput = true, name = "数据集", description = "数据集") Dataset<T> data,
            @InsightComponentArg(name = "参数", description = "参数") JSONObject object) {
        if(object.isEmpty()) {
            logger.info("Replace组件参数为空, 返回原数据集");
            return data;
        }
        if(data == null) {
            logger.info("Replace组件中的数据集为空, 返回空");
            return null;
        }
        String replaceType = object.getString("selector");
        JSONArray array = object.getJSONArray("selectorValue");

        List<ReplaceAttribute> replaceAttributes = new ArrayList<>();
        for(int i=0; i<array.size(); i++) {
            JSONObject obj = array.getJSONObject(i);
            ReplaceAttribute replaceAttribute = new ReplaceAttribute();
            String attributeSelector = obj.getString("method");
            replaceAttribute.setAttributeSelector(attributeSelector);
            if(Consts.ATTRIBUTE_NAME.equals(attributeSelector)) {
                replaceAttribute.setAttribute(obj.getString("methodValue"));
            }
            if(Consts.REGULAR_EXPRESSION.equals(attributeSelector)) {
                replaceAttribute.setRegularExpression(obj.getString("methodValue").trim());
            }
            if(Consts.VALUE_TYPE.equals(attributeSelector)) {
                replaceAttribute.setValueType(obj.getString("methodValue").trim());
            }
            if(Consts.REPLACE.equals(replaceType)) {
                JSONArray valueReplaceAttributes = obj.getJSONArray("var1");
                for(int j=0; j<valueReplaceAttributes.size(); j++) {
                    JSONObject valueReplaceAttribute = valueReplaceAttributes.getJSONObject(j);
                    boolean regex = valueReplaceAttribute.getBoolean("var2");
                    String oldValue = valueReplaceAttribute.getString("var3");
                    String newValue = valueReplaceAttribute.getString("var4");
//                    if(oldValue == null || oldValue.length() == 0) {
//                        throw new CustomException("Replace组件替换的某一旧值为空");
//                    }
                    if(newValue == null || newValue.length() == 0) {
                        throw new CustomException("Replace组件替换的某一新值为空");
                    }
                    replaceAttribute.addValueReplaceAttributes(regex, oldValue, newValue);
                }

            }
            if(Consts.REPLACE_BY_CONDITION.equals(replaceType)) {
                JSONArray filters = obj.getJSONArray("var1");
                for(int j=0; j<filters.size(); j++) {
                    JSONObject filter = filters.getJSONObject(j);
                    FilterSection filterSection = new FilterSection();
                    String column = filter.getString("column").trim();
                    if(column.length() == 0) {
                        logger.info("Replace组件参数中条件的某一column为空, 相应的expression为： " +
                                filter.getString("expression") + "value为: " +
                                filter.getString("value") + ", 跳过该行条件");
                        continue;
                    }
                    filterSection.setColumn(column);
                    filterSection.setExpression(filter.getString("expression").trim());
                    filterSection.setValue(filter.getString("value").trim());
                    replaceAttribute.addFilterSections(filterSection);
                }
                JSONArray valueReplaceAttributes = obj.getJSONArray("var2");
                for(int j=0; j<valueReplaceAttributes.size(); j++) {
                    JSONObject valueReplaceAttribute = valueReplaceAttributes.getJSONObject(j);
                    boolean regex = valueReplaceAttribute.getBoolean("var3");
                    String oldValue = valueReplaceAttribute.getString("var4");
                    String newValue = valueReplaceAttribute.getString("var5");
//                    if(oldValue == null || oldValue.length() == 0) {
//                        throw new CustomException("Replace组件替换的某一旧值为空");
//                    }
                    if(newValue == null || newValue.length() == 0) {
                        throw new CustomException("Replace组件替换的某一新值为空");
                    }
                    replaceAttribute.addValueReplaceAttributes(regex, oldValue, newValue);
                }
            }
            replaceAttributes.add(replaceAttribute);
        }

        return replace(data, replaceType, replaceAttributes);
    }

    private static <T> Dataset<T> replace(Dataset<T> data, String replaceType, List<ReplaceAttribute> replaceAttributes) {
        Map<StructField, List<ReplaceAttribute>> map = new HashMap<>();
        StructField[] schema = data.schema().fields();
        for(ReplaceAttribute replaceAttribute : replaceAttributes) {
            StructField[] fields = SharedMethods.attributeFilter(data, replaceAttribute.getAttributeSelector(), replaceAttribute.isInvertSelection(),
                    replaceAttribute.getAttribute(), replaceAttribute.getRegularExpression(), replaceAttribute.getValueType());
            if(fields == null) {
                logger.info("Replace组件中选择列的结果为空,跳过该参数");
                continue;
            }
            for(StructField field : fields) {
                if(map.containsKey(field)) {
                    map.get(field).add(replaceAttribute);
                } else {
                    List<ReplaceAttribute> list = new ArrayList<>();
                    list.add(replaceAttribute);
                    map.put(field, list);
                }
            }
        }

        Map<String, Object[]> totalSchema = new HashMap<>();
        SharedMethods.recordSchema(schema, totalSchema);
        JavaRDD<Row> javaRDD = data.toDF().javaRDD().map(new Function<Row, Row>() {
            @Override
            public Row call(Row row) throws Exception {
                String[] strings = new String[row.length()];
                for(int i=0; i<row.length(); i++) {
                    boolean flag = false;
                    for(StructField field : map.keySet()) {
                        if(schema[i] == field) {
                            flag = true;
                            break;
                        }
                    }
                    if(flag) {

                        block: {
                            for (ReplaceAttribute replaceAttribute : map.get(schema[i])) {
                                for (ReplaceAttribute.ValueReplaceAttribute v : replaceAttribute.getValueReplaceAttributes()) {
                                    if (Consts.REPLACE_BY_CONDITION.equals(replaceType)) {
                                        for (FilterSection filterSection : replaceAttribute.getFilterSections()) {
                                            int p = Integer.parseInt(totalSchema.get(filterSection.getColumn())[0].toString());
                                            if (p == -1) {
                                                logger.info("条件中的" + filterSection.getColumn() + "列不存在");
                                                continue;
                                            }
                                            if (!FilterOperator.judge(row.get(p), schema[p], filterSection)) {
                                                strings[i] = row.get(i) == null ? null : row.get(i).toString();
                                                break block;
                                            }
                                        }

                                    }

//                                  以下为普通替换内容
                                    String t = row.get(i) == null ? null : row.get(i).toString();
                                    if((t == null && Strings.isNullOrEmpty(v.getOldValue())) || v.getOldValue().equals(Consts.ASTERRISK) ) {
                                        strings[i] = v.getNewValue();
                                        continue ;
                                    }
                                    if(t == null) {
                                        continue ;
                                    }
                                    strings[i] = v.isRegex() ? t.matches(v.getOldValue()) ? v.getNewValue() : t : v.getOldValue().equals(t) ? v.getNewValue() : t;
                                }
                            }
                        }
                    } else {
                        strings[i] = row.get(i) == null ? null : row.get(i).toString();
                    }

                }

                return RowFactory.create(strings);
            }
        });

//        替换后的数据全部为String类型, 能否维持原类型？！ 很难 _(´ཀ`」 ∠)_
        StructType structType = new StructType();
        for(StructField field : schema) {
            structType = structType.add(field.name(), DataTypes.StringType, true);
        }
        return (Dataset<T>) SparkContextBuilder.getSession().createDataFrame(javaRDD, structType);
    }


}
