package com.datastax.data.prepare.spark.dataset;

import com.datastax.insight.spec.Operator;
import com.datastax.insight.annonation.InsightComponent;
import com.datastax.insight.annonation.InsightComponentArg;
import com.datastax.data.prepare.util.Consts;
import com.datastax.data.prepare.util.CustomException;
import com.datastax.data.prepare.util.SharedMethods;
import org.apache.spark.ml.feature.StringIndexer;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.types.DataTypes;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.HashMap;
import java.util.Map;

/**
 * Insight operator that applies Spark's {@link StringIndexer} to one or more columns of a
 * dataset, appending one index column per processed input column.
 */
public class MultiStringIndexerOperator implements Operator {
    private static final Logger logger = LoggerFactory.getLogger(MultiStringIndexerOperator.class);

    /**
     * Indexes one or more columns with Spark's {@link StringIndexer}.
     *
     * <p>{@code column} and {@code indexerColumnName} are both split on {@link Consts#DELIMITER}
     * and paired positionally: the i-th input column is indexed into the i-th output column.
     * Each entry is trimmed. Input entries that are blank, or that name a column not present in
     * the dataset, are logged and skipped rather than treated as errors.
     *
     * @param dataset           the input dataset; when {@code null}, {@code null} is returned
     * @param column            delimiter-separated names of the columns to index
     * @param indexerColumnName delimiter-separated names of the index columns to create; each
     *                          must be non-blank and must not collide with an existing column or
     *                          with an output column created earlier in the same call
     * @return the dataset with the generated index columns appended, or {@code null} if
     *         {@code dataset} is {@code null}
     * @throws NullPointerException if {@code column} is null or empty, or if
     *                              {@code indexerColumnName} is null
     * @throws CustomException      if the two name lists differ in length, an output name is
     *                              blank, or an output name conflicts with an existing column
     */
    @SuppressWarnings("unchecked")
    @InsightComponent(name = "StringIndexer", description = "将字符串转换成索引，和标签数值化转换相同，支持多列转换")
    public static <T> Dataset<T> multiStringIndexer(
            @InsightComponentArg(externalInput = true, name = "数据集", description = "数据集") Dataset<T> dataset,
            @InsightComponentArg(name = "列名", description = "需要转换的列名，多个列名用分号隔开") String column,
            @InsightComponentArg(name = "转换后的列名", description = "转换生成的索引列的列名，不能与现有列名重复") String indexerColumnName) {
        if (dataset == null) {
            logger.info("数据集为空");
            return null;
        }
        // Previously a null indexerColumnName surfaced as a message-less NPE from split();
        // fail with the same descriptive NPE used for the input-column parameter instead.
        if (column == null || column.isEmpty() || indexerColumnName == null) {
            throw new NullPointerException("StringIndexer组件的参数为空");
        }

        // Column name -> schema info, populated by SharedMethods.recordSchema. Only the keys
        // (column names) are consulted here, for existence/conflict checks.
        Map<String, Object[]> knownColumns = new HashMap<>();
        SharedMethods.recordSchema(dataset.schema().fields(), knownColumns);

        String[] inputColumns = column.split(Consts.DELIMITER);
        String[] outputColumns = indexerColumnName.split(Consts.DELIMITER);
        if (inputColumns.length != outputColumns.length) {
            throw new CustomException("StringIndexer组件的列名和转换后的列名的数量不等");
        }

        Dataset<Row> data = dataset.toDF();
        for (int i = 0; i < inputColumns.length; i++) {
            String input = inputColumns[i].trim();
            String output = outputColumns[i].trim();
            if (input.isEmpty()) {
                logger.info("列名参数的第{}个参数去掉前后空格后为空，跳过", i + 1);
                continue;
            }
            if (!knownColumns.containsKey(input)) {
                logger.info("数据集中找不到{}列，跳过", input);
                continue;
            }
            if (output.isEmpty()) {
                throw new CustomException("转换后的列名参数的第" + (i + 1) + "个参数去掉前后空格后为空");
            }
            if (knownColumns.containsKey(output)) {
                throw new CustomException("转换后生成的列名" + output + "和现有列名冲突");
            }
            data = new StringIndexer()
                    .setInputCol(input)
                    .setOutputCol(output)
                    .fit(data)
                    .transform(data);
            // Register the new column so later outputs in this call cannot reuse its name.
            // StringIndexer emits a double-typed index column, so record DoubleType.
            knownColumns.put(output, new Object[]{knownColumns.size() + 1, DataTypes.DoubleType});
        }
        // The operator framework passes Dataset<Row> through the generic signature; the cast
        // is unchecked because toDF()/transform() always yield Dataset<Row>.
        return (Dataset<T>) data;
    }
}
