package com.datastax.data.prepare.spark.dataset;

import com.datastax.insight.spec.Operator;
import com.datastax.insight.annonation.InsightComponent;
import com.datastax.insight.annonation.InsightComponentArg;
import com.datastax.data.prepare.util.Consts;
import com.datastax.data.prepare.util.CustomException;
import org.apache.spark.sql.Dataset;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Insight components that run outlier detection on a Spark {@link Dataset}.
 *
 * <p>Each component validates its arguments, resolves the requested columns through
 * {@code OutlierDetection.getNeedStructFields}, and delegates the actual computation to
 * {@code OutlierDetection}. The input dataset is returned unchanged when none of the
 * selected columns is usable, or when it has {@link #MAX_ROWS} rows or more.
 */
public class OutlierDetectionOperator implements Operator {
    private static final Logger logger = LoggerFactory.getLogger(OutlierDetectionOperator.class);

    /** Datasets with at least this many rows are returned unchanged instead of being analysed. */
    private static final long MAX_ROWS = 10000L;

    /**
     * Detects outliers by k-nearest-neighbour distance.
     *
     * <p>Per the component description, the average k-neighbour distance of each point is
     * computed and the top {@code outliers} points are reported as outliers.
     *
     * @param data      input dataset
     * @param columns   column names split on {@code Consts.DELIMITER} (the UI describes them as
     *                  semicolon-separated)
     * @param neighbors k, the number of neighbours used for the k-neighbour distance; must be &gt; 0
     * @param outliers  n, the number of outliers to report; must be &gt; 0
     * @param type      distance method name, passed through to {@code OutlierDetection}
     * @return the detector's result, or {@code data} unchanged when no selected column is usable
     *         or the dataset has at least {@link #MAX_ROWS} rows
     * @throws CustomException if {@code data} is null, {@code neighbors} or {@code outliers} is
     *         not positive, or {@code columns} is null/empty
     */
    @SuppressWarnings("unchecked")
    @InsightComponent(name = "离群点检测(距离)", type = "com.datastax.insight.dataprprocess.detectOutlier.distance", description = "通过第k邻近距离判断数据点，取k邻近距离的所有点的平均距离的前n个点作为离群点")
    public static <T> Dataset<T> distanceOutlier (
            @InsightComponentArg(externalInput = true, name = "数据集", description = "数据集") Dataset<T> data,
            @InsightComponentArg(name = "列名", description = "选择离群点的列名，用分号隔开") String columns,
            @InsightComponentArg(name = "近邻", description = "通过设置近邻点的数量来得到k邻近距离") int neighbors,
            @InsightComponentArg(name = "离群点数量", description = "设置离群点数量, 取距离前n个点") int outliers,
            @InsightComponentArg(name = "距离方法", description = "距离计算方法", defaultValue = "欧式距离", items = "欧式距离;平方距离;余弦距离;反余弦距离") String type) {
        if(data == null) {
            throw new CustomException("distances的离群点检测--数据集为空");
        }
        if(neighbors <= 0) {
            throw new CustomException("distances的离群点检测--距离某点的第k点距离的k小于或者等于0");
        }
        if(outliers <= 0) {
            throw new CustomException("distances的离群点检测--离群点数量n小于或者等于0");
        }
        if(columns == null || columns.length() == 0) {
            throw new CustomException("distances的离群点检测--选择的列名为空");
        }
        String[] cols = columns.split(Consts.DELIMITER);
        String[] result = OutlierDetection.getNeedStructFields(data.schema(), cols);
        if(result.length == 0) {
            logger.info("distances的离群点检测--选中的列中没有可用于离群点计算, 返回原数据集");
            return data;
        }
        if(exceedsRowLimit(data.count())) {
            return data;
        }
        // NOTE(review): the detector works on the untyped DataFrame view; the result is cast back
        // to Dataset<T> unchecked — presumably callers always use T = Row; confirm.
        return (Dataset<T>) OutlierDetection.distanceOutlier(data.toDF(), result, neighbors, outliers, type);
    }

    /**
     * Detects outliers by local density.
     *
     * <p>Per the component description, a point's density is the fraction of all points lying
     * within {@code distance} of it; points whose density is below {@code proportion} are
     * reported as outliers.
     *
     * @param data       input dataset
     * @param columns    column names split on {@code Consts.DELIMITER} (the UI describes them as
     *                   semicolon-separated)
     * @param distance   neighbourhood radius; must be &gt; 0
     * @param proportion density threshold; must lie strictly between 0 and 1
     * @param type       distance method name, passed through to {@code OutlierDetection}
     * @return the detector's result, or {@code data} unchanged when no selected column is usable
     *         or the dataset has at least {@link #MAX_ROWS} rows
     * @throws CustomException if {@code data} is null, {@code distance} is not positive,
     *         {@code proportion} is outside (0, 1), or {@code columns} is null/empty
     */
    @SuppressWarnings("unchecked")
    @InsightComponent(name = "离群点检测(密度)", type = "com.datastax.insight.dataprprocess.detectOutlier.densities", description = "通过某一点在距离d范围内的点和所有点的比例得出点的密度, 再和设置的概率相比得出离群点")
    public static <T> Dataset<T> densitiesOutlier (
            @InsightComponentArg(externalInput = true, name = "数据集", description = "数据集") Dataset<T> data,
            @InsightComponentArg(name = "列名", description = "选择离群点的列名，用分号隔开") String columns,
            @InsightComponentArg(name = "距离", description = "距离") double distance,
            @InsightComponentArg(name = "比例", description = "小于该比例的点将设为离群点") double proportion,
            @InsightComponentArg(name = "距离方法", description = "距离计算方法", defaultValue = "欧式距离", items = "欧式距离;平方距离;余弦距离;反余弦距离") String type) {
        if(data == null) {
            throw new CustomException("densities离群点检测--数据集为空");
        }
        if(distance <= 0) {
            throw new CustomException("densities离群点检测--distance值小于等于0");
        }
        if(proportion <= 0 || proportion >= 1) {
            // The check is a range check, so report it as one (the value is not "empty").
            throw new CustomException("densities离群点检测--proportion值不在(0,1)范围内");
        }
        if(columns == null || columns.length() == 0) {
            throw new CustomException("densities离群点检测--选择的列名为空");
        }
        String[] cols = columns.split(Consts.DELIMITER);
        String[] result = OutlierDetection.getNeedStructFields(data.schema(), cols);
        if(result.length == 0) {
            logger.info("densities离群点检测--选中的列中没有可用于离群点计算, 返回原数据集");
            return data;
        }
        if(exceedsRowLimit(data.count())) {
            return data;
        }
        // NOTE(review): unchecked cast from the untyped DataFrame result — presumably T = Row; confirm.
        return (Dataset<T>) OutlierDetection.densitiesOutlier(data.toDF(), result, distance, proportion, type);
    }

    /**
     * Detects outliers with the LOF method over a range of neighbour counts k.
     *
     * <p>The bounds may be given in either order; they are swapped if {@code upper < lower}.
     * {@code upper} is capped at the row count, because k cannot exceed the number of points.
     *
     * @param data    input dataset
     * @param columns column names split on {@code Consts.DELIMITER} (the UI describes them as
     *                semicolon-separated)
     * @param lower   lower bound of k; must be &gt; 0 and not exceed the row count
     * @param upper   upper bound of k; must be &gt; 0
     * @param type    distance method name, passed through to {@code OutlierDetection}
     * @return the detector's result, or {@code data} unchanged when no selected column is usable
     *         or the dataset has at least {@link #MAX_ROWS} rows
     * @throws CustomException if {@code data} is null, a bound is not positive, {@code columns}
     *         is null/empty, or the (ordered) lower bound exceeds the row count
     */
    @SuppressWarnings("unchecked")
    @InsightComponent(name = "离群点检测(LOF)", type = "com.datastax.insight.dataprprocess.detectOutlier.lof", description = "通过LOF判断离群点")
    public static <T> Dataset<T> LOFOutlier (
            @InsightComponentArg(externalInput = true, name = "数据集", description = "数据集") Dataset<T> data,
            @InsightComponentArg(name = "列名", description = "选择离群点的列名，用分号隔开") String columns,
            @InsightComponentArg(name = "下限", description = "第k邻近点的k的下限") int lower,
            @InsightComponentArg(name = "上限", description = "第k邻近点的k的上限") int upper,
            @InsightComponentArg(name = "距离方法", description = "距离计算方法", defaultValue = "欧式距离", items = "欧式距离;平方距离;余弦距离;反余弦距离") String type) {
        if(data == null) {
            throw new CustomException("LOF离群点检测--数据集为空");
        }
        if(lower <= 0 || upper <= 0) {
            throw new CustomException("LOF离群点检测--上限或者下限小于或等于0");
        }
        if(columns == null || columns.length() == 0) {
            throw new CustomException("LOF离群点检测--选择的列名为空");
        }
        String[] cols = columns.split(Consts.DELIMITER);
        String[] result = OutlierDetection.getNeedStructFields(data.schema(), cols);
        if(result.length == 0) {
            logger.info("LOF离群点检测--选中的列中没有可用于离群点计算, 返回原数据集");
            return data;
        }
        // Accept the bounds in either order.
        int kLow = Math.min(lower, upper);
        int kHigh = Math.max(lower, upper);
        // count() is a Spark action; compute it once and reuse it below.
        long count = data.count();
        if(count < kLow) {
            throw new CustomException("LOF离群点检测--下限大于数据集的行数");
        }
        // k cannot exceed the number of rows, so cap the upper bound at the row count.
        if(kHigh > count) {
            kHigh = (int) count;
        }
        if(exceedsRowLimit(count)) {
            return data;
        }
        // NOTE(review): unchecked cast from the untyped DataFrame result — presumably T = Row; confirm.
        return (Dataset<T>) OutlierDetection.LOFOutlier(data.toDF(), result, kLow, kHigh, type);
    }

    /**
     * Placeholder for a COF-based detector (presumably connectivity-based outlier factor);
     * not implemented yet, so it returns {@code data} unchanged.
     */
    protected static <T> Dataset<T> COFOutlier (Dataset<T> data, String columns, int neighbors, int classOutliers) {
        return data;
    }

    /**
     * Returns {@code true}, after logging why, when the dataset is too large to analyse.
     *
     * @param rowCount number of rows in the dataset
     * @return whether {@code rowCount} is at least {@link #MAX_ROWS}
     */
    private static boolean exceedsRowLimit(long rowCount) {
        if(rowCount >= MAX_ROWS) {
            logger.info("数据集大于{}行，暂时不支持于离群点检测，返回原数据集", MAX_ROWS);
            return true;
        }
        return false;
    }

}
