package com.datastax.insight.ml.spark.ml.feature.transformer;

import com.datastax.insight.spec.DataSetOperator;
import com.datastax.insight.core.Consts;
import com.google.common.base.Strings;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.spark.ml.feature.Bucketizer;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;

/**
 * Convenience wrapper around Spark ML's {@link Bucketizer} (binning of a numeric column).
 *
 * <p>Split points are supplied as a single delimited string; the outer
 * {@code -Infinity} / {@code +Infinity} bounds are added automatically.
 */
public class BucketizerWrapper implements DataSetOperator {
    /**
     * Builds a configured {@link Bucketizer}.
     *
     * <p>Each of {@code inputCol}, {@code outputCol} and {@code handleInvalid} is applied only
     * when it is neither {@code null} nor empty; otherwise the corresponding Bucketizer
     * setting is left untouched.
     *
     * @param inputCol      name of the input column, or {@code null}/empty to leave unset
     * @param outputCol     name of the output column, or {@code null}/empty to leave unset
     * @param sp            split points separated by {@link Consts#DELIMITER}; {@code null} or
     *                      empty means no splits are set on the returned Bucketizer
     * @param handleInvalid value forwarded to {@link Bucketizer#setHandleInvalid(String)}, or
     *                      {@code null}/empty to leave unset
     * @return the configured Bucketizer
     * @throws NumberFormatException if a token of {@code sp} is not a parsable double
     */
    public static Bucketizer getOperator(String inputCol, String outputCol, String sp, String handleInvalid){
        Bucketizer bucketizer = new Bucketizer();

        if (!Strings.isNullOrEmpty(inputCol)) {
            bucketizer.setInputCol(inputCol);
        }

        if (!Strings.isNullOrEmpty(outputCol)) {
            bucketizer.setOutputCol(outputCol);
        }

        if (!Strings.isNullOrEmpty(handleInvalid)) {
            bucketizer.setHandleInvalid(handleInvalid);
        }

        if (!Strings.isNullOrEmpty(sp)) {
            bucketizer.setSplits(withInfiniteBounds(parseSplits(sp)));
        }

        return bucketizer;
    }

    /**
     * Parses the delimited split string into an array of doubles.
     *
     * @throws NumberFormatException naming the offending token and its position
     */
    private static double[] parseSplits(String sp) {
        // NOTE(review): String.split treats Consts.DELIMITER as a regular expression;
        // presumably it contains no regex metacharacters - confirm against Consts.
        String[] tokens = sp.split(Consts.DELIMITER);
        double[] values = new double[tokens.length];
        for (int i = 0; i < tokens.length; i++) {
            try {
                values[i] = Double.parseDouble(tokens[i]);
            } catch (NumberFormatException e) {
                // Same exception type as before, but with enough context to locate the bad token.
                NumberFormatException detailed = new NumberFormatException(
                        "Invalid split value '" + tokens[i] + "' at position " + i
                                + " in \"" + sp + "\"");
                detailed.initCause(e);
                throw detailed;
            }
        }
        return values;
    }

    /**
     * Surrounds the splits with {@code -Infinity} and {@code +Infinity}, adding each bound only
     * when it is not already present so the result never contains a duplicated infinite bound.
     */
    private static double[] withInfiniteBounds(double[] splits) {
        double[] bounded = splits;
        if (bounded.length == 0 || bounded[0] != Double.NEGATIVE_INFINITY) {
            bounded = ArrayUtils.addAll(new double[] {Double.NEGATIVE_INFINITY}, bounded);
        }
        // bounded is guaranteed non-empty at this point.
        if (bounded[bounded.length - 1] != Double.POSITIVE_INFINITY) {
            bounded = ArrayUtils.add(bounded, Double.POSITIVE_INFINITY);
        }
        return bounded;
    }

    /**
     * Builds a {@link Bucketizer} from the given settings and applies it to {@code data}.
     *
     * @param data          dataset to transform
     * @param inputCol      name of the input column, or {@code null}/empty to leave unset
     * @param outputCol     name of the output column, or {@code null}/empty to leave unset
     * @param sp            split points separated by {@link Consts#DELIMITER}
     * @param handleInvalid value forwarded to {@link Bucketizer#setHandleInvalid(String)}
     * @return the result of {@link Bucketizer#transform} on {@code data}
     * @throws NumberFormatException if a token of {@code sp} is not a parsable double
     */
    public static Dataset<Row> transform(Dataset<Row> data, String inputCol, String outputCol, String sp, String handleInvalid){
        return getOperator(inputCol, outputCol, sp, handleInvalid).transform(data);
    }

    /**
     * Applies an already configured {@link Bucketizer} to {@code data}.
     *
     * @param bucketizer the bucketizer to apply
     * @param data       dataset to transform
     * @return the result of {@link Bucketizer#transform} on {@code data}
     */
    public static Dataset<Row> transform(Bucketizer bucketizer,Dataset<Row> data){
        return bucketizer.transform(data);
    }
}
