package com.datastax.data.prepare.spark.dataset;

import com.datastax.insight.core.driver.SparkContextBuilder;
import com.datastax.insight.spec.Operator;
import com.datastax.insight.annonation.InsightComponent;
import com.datastax.insight.annonation.InsightComponentArg;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.api.java.UDF2;
import org.apache.spark.sql.types.DataTypes;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import scala.collection.Seq;

/**
 * Insight operator exposing stock-account similarity computations as components.
 *
 * <p>Both entry points validate their arguments and then delegate the actual
 * computation to {@code StockSimilarity}; the Spark session is obtained from
 * {@code SparkContextBuilder}.
 */
public class StockSimilrityOperator implements Operator {
    private static final Logger logger = LoggerFactory.getLogger(StockSimilrityOperator.class);

    /** Similarity method name that maps to method code 0. */
    private static final String METHOD_MAXIMUM_BEST_MATCH = "maximum_best_match";
    /** Similarity method name that maps to method code 1. */
    private static final String METHOD_AVERAGE_BEST_MATCH = "average_best_match";

    // Legacy parameter names: lmt_trad_dirc  set_sim_mtd  min_sim_flt  min_sim  min_acct_num

    /**
     * Computes the dynamic similarity between shareholder accounts.
     *
     * <p>Validates the arguments and delegates to {@code StockSimilarity.stockSim2}.
     *
     * @param data                input dataset; must not be {@code null}
     * @param shrAcctColumn       name of the shareholder-account column
     * @param secCodeColumn       name of the stock (security code) column
     * @param tradDateColumn      name of the trade-date column
     * @param tradDircColumn      name of the trade-direction column
     * @param limitTradeDirection whether only same-direction trades are considered
     * @param setSimMethodString  {@code "maximum_best_match"} (code 0); any other non-empty value
     *                            is treated as {@code "average_best_match"} (code 1)
     * @param minSimFlt           minimum effective similarity, in the range 0 to 100 inclusive
     * @param tradeDateThreshold  minimum number of days used for the similarity; must not be negative
     * @param mktTypeColumn       name of the market-type column; must not be blank
     * @param mktType             market type ("0" for Shanghai, "1" for Shenzhen); must not be blank
     * @param tempPath            path for intermediate files; must not be blank
     * @param numString           batch size as a decimal integer string
     * @param <T>                 element type the caller wants the result typed as
     * @return the dataset produced by {@code StockSimilarity.stockSim2}
     * @throws IllegalArgumentException if any argument fails validation
     * @throws NumberFormatException    if {@code numString} is blank or not a valid integer
     */
    // The delegate's result is re-typed to the caller-chosen T; generics are erased at runtime,
    // so the cast itself cannot fail here.
    @SuppressWarnings("unchecked")
    @InsightComponent(name = "动态相似度(stock)", description = "相似度计算(stock)")
    public static <T> Dataset<T> computeSim(
            @InsightComponentArg(externalInput = true, name = "数据集", description = "数据集") Dataset<Row> data,
            @InsightComponentArg(name = "股东帐号列", description = "股东帐号") String shrAcctColumn,
            @InsightComponentArg(name = "股票列", description = "股票列") String secCodeColumn,
            @InsightComponentArg(name = "日期列", description = "日期列") String tradDateColumn,
            @InsightComponentArg(name = "交易方向列", description = "交易方向列") String tradDircColumn,
            @InsightComponentArg(name = "是否同向交易", description = "是否同向交易", defaultValue = "true", items = "true;false") boolean limitTradeDirection,
            @InsightComponentArg(name = "相似度方法", description = "相似度方法", defaultValue = "maximum_best_match", items = "maximum_best_match;average_best_match") String setSimMethodString,
            @InsightComponentArg(name = "最小有效相似度", description = "用于过滤掉两股东账户之间相似度小于该值的数据，0到100之间") double minSimFlt,
            @InsightComponentArg(name = "最小天数阈值", description = "用于计算相似度的最小天数阈值") int tradeDateThreshold,
            @InsightComponentArg(name = "市场类型列", description = "市场类型列名") String mktTypeColumn,
            @InsightComponentArg(name = "市场类型", description = "市场类型, 0代表沪市，1代表深市", defaultValue = "1", items = "0;1") String mktType,
            @InsightComponentArg(name = "中间文件路径", description = "中间文件路径", defaultValue = "${MISC_FOLDER}") String tempPath,
            @InsightComponentArg(name = "批处理数量", description = "批处理数量", defaultValue = "5000") String numString) {
        if (data == null) {
            throw new IllegalArgumentException("数据集为空");
        }
        if (setSimMethodString == null || setSimMethodString.length() == 0) {
            throw new IllegalArgumentException("set_similarity_method为空");
        }
        // NaN compares false against every bound, so it has to be rejected explicitly.
        if (Double.isNaN(minSimFlt) || minSimFlt > 100 || minSimFlt < 0) {
            throw new IllegalArgumentException("min_similarity_flt不在0到100之间");
        }
        if (tradeDateThreshold < 0) {
            throw new IllegalArgumentException("trade_date_threshold小于0");
        }
        if (mktTypeColumn == null || mktTypeColumn.trim().isEmpty()) {
            throw new IllegalArgumentException("mkt_type的列名为空");
        }
        if (mktType == null || mktType.trim().isEmpty()) {
            throw new IllegalArgumentException("mkt_type为空");
        }
        if (tempPath == null || tempPath.trim().isEmpty()) {
            throw new IllegalArgumentException("tempPath为空");
        }
        // Parse before touching Spark so a malformed batch size fails fast.
        // NOTE(review): a zero or negative batch size is passed through unchanged — confirm
        // whether StockSimilarity.stockSim2 tolerates it.
        int batchSize = parseBatchSize(numString);

        boolean maximumBestMatch = METHOD_MAXIMUM_BEST_MATCH.equals(setSimMethodString);
        if (!maximumBestMatch && !METHOD_AVERAGE_BEST_MATCH.equals(setSimMethodString)) {
            // The fallback is kept for backward compatibility, but it is no longer silent.
            logger.warn("Unrecognized similarity method '{}'; falling back to {}",
                    setSimMethodString, METHOD_AVERAGE_BEST_MATCH);
        }
        int setSimMethod = maximumBestMatch ? 0 : 1;

        SparkSession spark = SparkContextBuilder.getSession();
        return (Dataset<T>) StockSimilarity.stockSim2(spark, data.toDF(), shrAcctColumn, secCodeColumn, tradDateColumn, tradDircColumn, limitTradeDirection, setSimMethod,
                minSimFlt, tradeDateThreshold, mktTypeColumn, mktType, tempPath, batchSize);
    }

    /**
     * Computes the static similarity between securities accounts.
     *
     * <p>Registers the {@code staticSimUDF} user-defined function (backed by
     * {@code StockSimilarity.staticUDF}) on the session and delegates to
     * {@code StockSimilarity.stockStaticSim}.
     *
     * @param data                  input dataset; must not be {@code null}
     * @param sec_acct_code         name of the securities-account-code column
     * @param mob_nbr               name of the mobile-number column
     * @param fix_or_memo_cntct_tel name of the fixed or backup contact-phone column
     * @param cntct_addr            name of the contact-address column
     * @param email_box             name of the e-mail column
     * @param identifi_file_nbr     name of the identity-document-number column
     * @param open_agt_code         name of the account-opening agency-code column
     * @param open_agt_net_code     name of the account-opening agency-branch-code column
     * @param open_date             name of the account-opening-date column
     * @param minSim                minimum similarity used to filter the result
     * @param <T>                   element type of the input and of the returned dataset
     * @return the dataset produced by {@code StockSimilarity.stockStaticSim}
     * @throws IllegalArgumentException if {@code data} is {@code null}
     */
    // Same erasure-based re-typing of the delegate's result as in computeSim.
    @SuppressWarnings("unchecked")
    @InsightComponent(name = "静态相似度(stock)", description = "计算静态的相似度")
    public static <T> Dataset<T> computeStaticSim(
            @InsightComponentArg(externalInput = true, name = "数据集", description = "数据集") Dataset<T> data,
            @InsightComponentArg(name = "证券账户代码列", description = "证券账户代码列") String sec_acct_code,
            @InsightComponentArg(name = "手机号码列", description = "手机号码列") String mob_nbr,
            @InsightComponentArg(name = "固定或备用联系电话列", description = "固定或备用联系电话列") String fix_or_memo_cntct_tel,
            @InsightComponentArg(name = "联系地址列", description = "联系地址列") String cntct_addr,
            @InsightComponentArg(name = "电子邮箱列", description = "电子邮箱列") String email_box,
            @InsightComponentArg(name = "身份证明文件号码列", description = "身份证明文件号码列") String identifi_file_nbr,
            @InsightComponentArg(name = "开户代理机构代码列", description = "开户代理机构代码列") String open_agt_code,
            @InsightComponentArg(name = "开户代理网点代码列", description = "开户代理网点代码列") String open_agt_net_code,
            @InsightComponentArg(name = "开户日期列", description = "开户日期列") String open_date,
            @InsightComponentArg(name = "最小相似度", description = "用于过滤掉数据集中相似度小于该值的数据") double minSim
    ) {
        // Same guard and message as computeSim, instead of an NPE from data.toDF().
        if (data == null) {
            throw new IllegalArgumentException("数据集为空");
        }

        SparkSession spark = SparkContextBuilder.getSession();
        spark.udf().register("staticSimUDF", new UDF2<Seq<String>, Seq<String>, Double>() {
            @Override
            public Double call(Seq<String> stringSeq, Seq<String> stringSeq2) throws Exception {
                return StockSimilarity.staticUDF(stringSeq, stringSeq2);
            }
        }, DataTypes.DoubleType);

        return (Dataset<T>) StockSimilarity.stockStaticSim(spark, data.toDF(), sec_acct_code, mob_nbr, fix_or_memo_cntct_tel, cntct_addr, email_box,
                identifi_file_nbr, open_agt_code, open_agt_net_code, open_date, minSim);
    }

    /**
     * Parses the batch-size argument, tolerating surrounding whitespace.
     *
     * @throws NumberFormatException if the value is null, blank, or not a valid integer
     */
    private static int parseBatchSize(String numString) {
        if (numString == null || numString.trim().isEmpty()) {
            throw new NumberFormatException("批处理数量为空");
        }
        try {
            return Integer.parseInt(numString.trim());
        } catch (NumberFormatException e) {
            // Same exception type as the bare Integer.parseInt call, with the argument named.
            NumberFormatException detailed = new NumberFormatException("批处理数量不是有效的整数: " + numString);
            detailed.initCause(e);
            throw detailed;
        }
    }
}
