package com.datastax.data.prepare.spark.dataset;

import com.alibaba.fastjson.JSONArray;
import com.datastax.insight.core.driver.SparkContextBuilder;
import com.datastax.insight.spec.Operator;
import com.datastax.insight.annonation.InsightComponent;
import com.datastax.insight.annonation.InsightComponentArg;
import com.fasterxml.jackson.databind.ObjectMapper;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.types.StructType;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

/**
 * Operator that draws a sample from a Spark {@link Dataset}.
 *
 * <p>Three sample types are supported, selected by the {@code sampleType} argument:
 * <ul>
 *   <li>{@code absolute} - take a fixed number of rows;</li>
 *   <li>{@code relative} - take a number of rows computed as a ratio of the row count;</li>
 *   <li>{@code probability} - keep each row with the given probability.</li>
 * </ul>
 *
 * <p>The configuration is a JSON array of objects with the keys {@code selector} and
 * {@code selectorValue}. When the {@code selector} of the first entry is {@code "balance"},
 * every following entry describes one group: its {@code selector} is passed to
 * {@link Dataset#filter(String)} and its {@code selectorValue} is the amount to sample from the
 * rows matching that filter. Otherwise only the {@code selectorValue} of the first entry is used
 * and the whole dataset is sampled.
 *
 * <p>Invalid amounts are logged and the input dataset is returned unchanged.
 */
public class DataSampleOperator implements Operator {
    private static final Logger logger = LoggerFactory.getLogger(DataSampleOperator.class);

    /** Selector of the first configuration entry that turns on per-group sampling. */
    private static final String BALANCE_SELECTOR = "balance";

    /** Marker returned by {@link #absoluteSampleSize(Object)} when the configured value is unusable. */
    private static final int INVALID_SAMPLE_SIZE = -1;

    /** Shared JSON mapper, reused so that a new mapper is not built on every call. */
    private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper();

    /**
     * Samples {@code dataset} according to {@code sampleType} and the JSON configuration.
     *
     * @param dataset         dataset to sample; when {@code null}, {@code null} is returned
     * @param withReplacement whether rows may be drawn more than once
     * @param sampleType      {@code absolute}, {@code relative} or {@code probability}; any other
     *                        value (including {@code null}) returns {@code dataset} unchanged
     * @param json            JSON array of {@code {"selector": ..., "selectorValue": ...}} objects
     * @return the sampled dataset, or {@code dataset} itself when the configuration is empty,
     *         the sample type is unknown or a configured amount is invalid
     * @throws IOException           if {@code json} cannot be parsed
     * @throws NumberFormatException if a {@code selectorValue} is not numeric
     */
    @InsightComponent(name = "数据采样", description = "数据采样", order = 500304)
    public static <T> Dataset<T> dataSample(
            @InsightComponentArg(externalInput = true, name = "data", description = "数据集") Dataset<T> dataset,
            @InsightComponentArg(name = "withReplacement", description = "是否放回抽样") boolean withReplacement,
            @InsightComponentArg(name = "sampleType", description = "Sample类型", items = "absolute,relative,probability", defaultValue = "absolute") String sampleType,
            @InsightComponentArg(name = "editList", description = "配置列表") String json) throws IOException {
        if (dataset == null) {
            logger.info("dataset is empty!");
            return null;
        }

        List<List<Object>> propertyList = parseJsonArray(json);
        if (propertyList.isEmpty()) {
            // Every sampler reads the first entry, so there is nothing to do without one.
            logger.info("sample configuration is empty, dataset is returned unchanged!");
            return dataset;
        }
        if (sampleType == null) {
            // switch on a null String would throw; treat it like an unknown sample type.
            logger.info("sampleType is not set, dataset is returned unchanged!");
            return dataset;
        }

        switch (sampleType) {
            case "absolute":
                return absoluteSample(dataset, withReplacement, dataset.schema(), propertyList);
            case "relative":
                return relativeSample(dataset, withReplacement, dataset.schema(), propertyList);
            case "probability":
                return probabilitySample(dataset, withReplacement, propertyList);
            default:
                return dataset;
        }
    }

    /**
     * Parses the JSON configuration into a list of {@code [selector, selectorValue]} pairs.
     *
     * @param json JSON array of objects; {@code null} or a JSON {@code null} yields an empty list
     * @return one two-element list per JSON object, in input order
     * @throws IOException if {@code json} is not valid JSON
     */
    @SuppressWarnings("unchecked")
    private static List<List<Object>> parseJsonArray(String json) throws IOException {
        List<List<Object>> propertyList = new ArrayList<>();
        if (json == null) {
            return propertyList;
        }
        List<Map<String, Object>> entries = OBJECT_MAPPER.readValue(json, List.class);
        if (entries == null) {
            return propertyList;
        }
        for (Map<String, Object> entry : entries) {
            List<Object> property = new ArrayList<>(2);
            property.add(entry.get("selector"));
            property.add(entry.get("selectorValue"));
            propertyList.add(property);
        }
        return propertyList;
    }

    /**
     * Takes a fixed number of rows, either from the whole dataset or from every configured group.
     *
     * <p>NOTE(review): the result is built with {@code createDataFrame(rows, schema)}, so this
     * presumably expects {@code T} to be {@link Row} - confirm against callers.
     *
     * @return the sampled rows as a new dataset, or {@code dataset} when an amount is invalid
     */
    private static <T> Dataset<T> absoluteSample(Dataset<T> dataset, boolean withReplacement, StructType schema, List<List<Object>> propertyList) {
        if (!isBalanced(propertyList)) {
            int num = absoluteSampleSize(propertyList.get(0).get(1));
            if (num == INVALID_SAMPLE_SIZE) {
                return dataset;
            }
            return toDataset(takeRows(dataset, withReplacement, num), schema);
        }

        List<Row> sampledRows = new ArrayList<>();
        for (List<Object> group : groupsOf(propertyList)) {
            int num = absoluteSampleSize(group.get(1));
            if (num == INVALID_SAMPLE_SIZE) {
                return dataset;
            }
            Dataset<T> groupRows = dataset.filter(group.get(0).toString());
            sampledRows.addAll(takeRows(groupRows, withReplacement, num));
        }
        return toDataset(sampledRows, schema);
    }

    /**
     * Takes {@code ratio * rowCount} rows (truncated to an int), either from the whole dataset or
     * from every configured group.
     *
     * <p>NOTE(review): the result is built with {@code createDataFrame(rows, schema)}, so this
     * presumably expects {@code T} to be {@link Row} - confirm against callers.
     *
     * @return the sampled rows as a new dataset, or {@code dataset} when a computed size is not
     *         greater than 0
     */
    private static <T> Dataset<T> relativeSample(Dataset<T> dataset, boolean withReplacement, StructType schema, List<List<Object>> propertyList) {
        if (!isBalanced(propertyList)) {
            // count() instead of collectAsList().size(): only the size is needed, not the rows.
            int num = (int) (toDouble(propertyList.get(0).get(1)) * dataset.count());
            if (num <= 0) {
                logger.info("relativeSample computed sample size must be greater than 0, but was {}!", num);
                return dataset;
            }
            return toDataset(takeRows(dataset, withReplacement, num), schema);
        }

        List<Row> sampledRows = new ArrayList<>();
        for (List<Object> group : groupsOf(propertyList)) {
            Dataset<T> groupRows = dataset.filter(group.get(0).toString());
            int num = (int) (toDouble(group.get(1)) * groupRows.count());
            if (num <= 0) {
                logger.info("relativeSample computed sample size must be greater than 0, but was {}!", num);
                return dataset;
            }
            sampledRows.addAll(takeRows(groupRows, withReplacement, num));
        }
        return toDataset(sampledRows, schema);
    }

    /**
     * Samples with {@link Dataset#sample(boolean, double)}, either on the whole dataset or on every
     * configured group, in which case the per-group samples are combined with
     * {@link Dataset#union(Dataset)}.
     *
     * @return the sampled dataset, an empty dataset when balance mode has no groups, or
     *         {@code dataset} when a probability is not strictly between 0 and 1
     */
    private static <T> Dataset<T> probabilitySample(Dataset<T> dataset, boolean withReplacement, List<List<Object>> propertyList) {
        if (!isBalanced(propertyList)) {
            double probability = toDouble(propertyList.get(0).get(1));
            if (!isValidProbability(probability)) {
                logger.info("probabilitySample probability is out of range!");
                return dataset;
            }
            return dataset.sample(withReplacement, probability);
        }

        Dataset<T> combined = null;
        for (List<Object> group : groupsOf(propertyList)) {
            double probability = toDouble(group.get(1));
            if (!isValidProbability(probability)) {
                logger.info("probabilitySample probability is out of range!");
                return dataset;
            }
            Dataset<T> groupSample = dataset.filter(group.get(0).toString()).sample(withReplacement, probability);
            combined = combined == null ? groupSample : combined.union(groupSample);
        }
        if (combined == null) {
            // No groups configured: return an empty sample rather than null, which is what the
            // absolute and relative samplers produce for the same configuration.
            logger.info("probabilitySample has no balance groups, an empty dataset is returned!");
            return dataset.limit(0);
        }
        return combined;
    }

    /** Tells whether the first configuration entry requests per-group ("balance") sampling. */
    private static boolean isBalanced(List<List<Object>> propertyList) {
        // Constant first, so that a missing selector does not cause a NullPointerException.
        return BALANCE_SELECTOR.equals(propertyList.get(0).get(0));
    }

    /** Returns the group entries, i.e. every entry after the leading "balance" marker. */
    private static List<List<Object>> groupsOf(List<List<Object>> propertyList) {
        return propertyList.subList(1, propertyList.size());
    }

    /** Tells whether {@code probability} lies strictly between 0 and 1. */
    private static boolean isValidProbability(double probability) {
        return probability > 0 && probability < 1;
    }

    /**
     * Converts a configured value to a double.
     *
     * @throws NumberFormatException if the value is not a number and its text form is not numeric
     */
    private static double toDouble(Object value) {
        if (value instanceof Number) {
            return ((Number) value).doubleValue();
        }
        return Double.parseDouble(String.valueOf(value));
    }

    /**
     * Converts a configured row count to a positive int.
     *
     * <p>The value is converted from its double form instead of being re-parsed with
     * {@link Integer#parseInt(String)}, so that whole numbers written with a fraction part
     * (for example {@code 10.0}) are accepted.
     *
     * @return the row count, or {@link #INVALID_SAMPLE_SIZE} (after logging) when the value is not
     *         a whole number or is outside {@code 1..Integer.MAX_VALUE}
     */
    private static int absoluteSampleSize(Object value) {
        double requested = toDouble(value);
        if (requested % 1 != 0) {
            logger.info("absoluteSample num must be int type!");
            return INVALID_SAMPLE_SIZE;
        }
        if (requested <= 0 || requested > Integer.MAX_VALUE) {
            logger.info("absoluteSample num must be between 1 and {}, but was {}!", Integer.MAX_VALUE, value);
            return INVALID_SAMPLE_SIZE;
        }
        return (int) requested;
    }

    /** Draws {@code num} elements from {@code source} with {@code takeSample} and views them as rows. */
    @SuppressWarnings("unchecked")
    private static <T> List<Row> takeRows(Dataset<T> source, boolean withReplacement, int num) {
        return (List<Row>) source.javaRDD().takeSample(withReplacement, num);
    }

    /** Builds a dataset from {@code rows} and {@code schema} using the session of {@link SparkContextBuilder}. */
    @SuppressWarnings("unchecked")
    private static <T> Dataset<T> toDataset(List<Row> rows, StructType schema) {
        return (Dataset<T>) SparkContextBuilder.getSession().createDataFrame(rows, schema);
    }

    // TODO(andy): why are there two nearly identical entry points? They differ only by a toString().
    /**
     * Variant of {@link #dataSample(Dataset, boolean, String, String)} that takes the
     * configuration as a {@link JSONArray}; it is converted to its JSON text and delegated.
     *
     * @param jsonArray configuration entries; {@code null} is treated as an empty configuration
     * @throws IOException if the JSON text of {@code jsonArray} cannot be parsed
     */
    @InsightComponent(name = "数据采样", description = "数据采样", order = 500304)
    public static <T> Dataset<T> dataSample(
            @InsightComponentArg(externalInput = true, name = "data", description = "数据集") Dataset<T> dataset,
            @InsightComponentArg(name = "withReplacement", description = "是否放回抽样") boolean withReplacement,
            @InsightComponentArg(name = "sampleType", description = "Sample类型", items = "absolute,relative,probability", defaultValue = "absolute") String sampleType,
            @InsightComponentArg(name = "propertylist", description = "参数列表") JSONArray jsonArray) throws IOException {
        return dataSample(dataset, withReplacement, sampleType, jsonArray == null ? null : jsonArray.toString());
    }
    // TODO(andy): the type conversion here is written in an odd way.
}
