package com.datastax.data.prepare.spark.dataset;

import com.datastax.insight.core.driver.SparkContextBuilder;
import com.datastax.insight.spec.Operator;
import com.datastax.insight.annonation.InsightComponent;
import com.datastax.insight.annonation.InsightComponentArg;
import com.google.common.base.Strings;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.UDFRegistration;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.File;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.net.MalformedURLException;
import java.net.URL;
import java.net.URLClassLoader;

/**
 * Operator that applies a user-defined transformation, packaged in an external jar,
 * to a Spark {@link Dataset}.
 */
public class UDDOperator implements Operator {
    private static final Logger logger = LoggerFactory.getLogger(UDDOperator.class);

    /**
     * Loads {@code className} from the jar at {@code jarPath} and invokes its
     * {@code methodName(Dataset, UDFRegistration)} method reflectively, passing {@code data}
     * and the UDF registry of the current Spark session.
     *
     * <p>The entry point is looked up with the exact parameter types
     * {@code (Dataset, UDFRegistration)}. It may be an instance method (the class then needs a
     * no-arg constructor) or a static method (no instance is created).
     * NOTE(review): the annotation description of {@code methodName} mentions SparkSession as the
     * second parameter, but the lookup below requires UDFRegistration - confirm which is intended.
     *
     * <p>This method is best-effort: if an argument is empty, the jar does not exist, or loading
     * or invoking the user code fails with a reflection error, the problem is logged and the
     * original {@code data} is returned unchanged.
     *
     * @param data       the input dataset
     * @param jarPath    path of the jar file on the local file system
     * @param className  fully qualified name of the class to load from the jar
     * @param methodName name of the entry-point method
     * @param <T>        element type of the dataset
     * @return the dataset returned by the user method, or {@code data} itself when the user
     *         method could not be run or did not return a {@link Dataset}
     */
    @InsightComponent(name = "用户自定函数", type = "com.datastax.insight.dataprprocess.udd", description = "用户自定函数")
    public static <T> Dataset<T> udd(
            @InsightComponentArg(externalInput = true, name = "数据集", description = "数据集") Dataset<T> data,
            @InsightComponentArg(name = "jar路径", description = "jar文件的路径", request = true) String jarPath,
            @InsightComponentArg(name = "全限定类名", description = "要调用的类名", request = true) String className,
            @InsightComponentArg(name = "方法名", description = "用户自定函数的入口，包含参数Dataset，SparkSession，返回处理过的Dataset", request = true) String methodName) {
        if (Strings.isNullOrEmpty(jarPath) || Strings.isNullOrEmpty(className) || Strings.isNullOrEmpty(methodName)) {
            logger.info("参数为空，返回原数据集");
            return data;
        }
        File file = new File(jarPath);
        if (!file.exists()) {
            logger.info("jar文件不存在");
            return data;
        }
        SparkSession session = SparkContextBuilder.getSession();

        try {
            // File#toURI escapes spaces and other special characters and makes the path absolute,
            // which a hand-built "file:" + path string does not.
            URL jarUrl = file.toURI().toURL();
            ClassLoader loader = classLoaderFor(jarUrl);

            Class<?> clazz = loader.loadClass(className);
            Method method = clazz.getDeclaredMethod(methodName, Dataset.class, UDFRegistration.class);
            // A static entry point needs no receiver, so the class does not have to be instantiable.
            Object target = java.lang.reflect.Modifier.isStatic(method.getModifiers()) ? null : newTarget(clazz);
            Object result = method.invoke(target, data, session.udf());
            if (result instanceof Dataset) {
                // The element type cannot be verified at runtime because of erasure; the user
                // method is trusted to return a dataset of the same element type.
                @SuppressWarnings("unchecked")
                Dataset<T> transformed = (Dataset<T>) result;
                data = transformed;
            } else {
                logger.info("方法返回类型不等于 Dataset");
            }
        } catch (MalformedURLException e) {
            logger.error("jar路径无法转换为URL: " + jarPath, e);
        } catch (ClassNotFoundException e) {
            logger.error("类不存在", e);
        } catch (NoSuchMethodException e) {
            logger.error("方法不存在", e);
        } catch (InstantiationException | IllegalAccessException e) {
            logger.error("无法实例化类或访问方法", e);
        } catch (InvocationTargetException e) {
            // The user code (constructor or entry point) itself threw; the cause is attached.
            logger.error("用户自定函数执行失败", e);
        }
        return data;
    }

    /**
     * Returns a class loader that can see the classes in {@code jarUrl}.
     *
     * <p>When the system class loader is a {@link URLClassLoader} (Java 8 and earlier), the jar is
     * appended to it through the protected {@code addURL} method, as this operator has always
     * done. On newer runtimes, or if that injection fails, a child {@link URLClassLoader} is
     * created instead.
     */
    private static ClassLoader classLoaderFor(URL jarUrl) {
        ClassLoader system = ClassLoader.getSystemClassLoader();
        if (system instanceof URLClassLoader) {
            try {
                Method addURL = URLClassLoader.class.getDeclaredMethod("addURL", URL.class);
                addURL.setAccessible(true);
                addURL.invoke(system, jarUrl);
                return system;
            } catch (ReflectiveOperationException | RuntimeException e) {
                logger.warn("无法将jar加入系统类加载器，改用独立的类加载器", e);
            }
        }
        ClassLoader parent = Thread.currentThread().getContextClassLoader();
        if (parent == null) {
            parent = UDDOperator.class.getClassLoader();
        }
        // Deliberately not closed: the returned Dataset and any UDFs the user code registers may
        // still need to load classes from the jar after this method returns.
        return new URLClassLoader(new URL[] {jarUrl}, parent);
    }

    /**
     * Creates an instance of {@code clazz} through its no-arg constructor.
     *
     * <p>A missing no-arg constructor is reported as {@link InstantiationException} so that it is
     * not confused with a missing entry-point method.
     */
    private static Object newTarget(Class<?> clazz)
            throws InstantiationException, IllegalAccessException, InvocationTargetException {
        try {
            return clazz.getDeclaredConstructor().newInstance();
        } catch (NoSuchMethodException e) {
            InstantiationException failure = new InstantiationException(clazz.getName() + " has no no-arg constructor");
            failure.initCause(e);
            throw failure;
        }
    }

}
