package com.datastax.insight.ml.spark.ml.feature.transformer;

import com.datastax.insight.spec.DataSetOperator;
import org.apache.spark.ml.feature.Bucketizer;
import org.apache.spark.ml.feature.QuantileDiscretizer;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;

/**
 * Static helpers around Spark ML's {@link QuantileDiscretizer}: building the
 * estimator, fitting it into a {@link Bucketizer}, and applying the fitted
 * model to a dataset.
 */
public class QuantileDiscretizerWrapper implements DataSetOperator {

    /**
     * Creates a {@link QuantileDiscretizer} configured with the supplied settings.
     *
     * @param inputCol      value passed to {@code setInputCol}
     * @param outputCol     value passed to {@code setOutputCol}
     * @param numBuckets    value passed to {@code setNumBuckets}
     * @param relativeError value passed to {@code setRelativeError}
     * @return the configured, not yet fitted, discretizer
     */
    public static QuantileDiscretizer getOperator(String inputCol, String outputCol, int numBuckets,double relativeError){
        QuantileDiscretizer configured = new QuantileDiscretizer();
        configured.setInputCol(inputCol);
        configured.setOutputCol(outputCol);
        configured.setNumBuckets(numBuckets);
        configured.setRelativeError(relativeError);
        return configured;
    }

    /**
     * Builds a discretizer from the supplied settings and fits it on {@code data}.
     *
     * @param data          dataset the discretizer is fitted on
     * @param inputCol      value passed to {@code setInputCol}
     * @param outputCol     value passed to {@code setOutputCol}
     * @param numBuckets    value passed to {@code setNumBuckets}
     * @param relativeError value passed to {@code setRelativeError}
     * @return the {@link Bucketizer} produced by fitting
     */
    public static Bucketizer fit(Dataset<Row> data, String inputCol, String outputCol, int numBuckets,double relativeError){
        return fit(getOperator(inputCol, outputCol, numBuckets, relativeError), data);
    }

    /**
     * Fits an already configured discretizer on {@code data}.
     *
     * @param discretizer the estimator to fit
     * @param data        dataset the discretizer is fitted on
     * @return the {@link Bucketizer} produced by fitting
     */
    public static Bucketizer fit(QuantileDiscretizer discretizer,Dataset<Row> data){
        return discretizer.fit(data);
    }

    /**
     * Fits a discretizer on {@code data} and then applies the resulting
     * {@link Bucketizer} to that same dataset.
     *
     * @param data          dataset used both for fitting and for the transformation
     * @param inputCol      value passed to {@code setInputCol}
     * @param outputCol     value passed to {@code setOutputCol}
     * @param numBuckets    value passed to {@code setNumBuckets}
     * @param relativeError value passed to {@code setRelativeError}
     * @return the dataset returned by the fitted model's {@code transform}
     */
    public static Dataset<Row> transform(Dataset<Row> data, String inputCol, String outputCol,
                                         int numBuckets,double relativeError){
        Bucketizer fitted = fit(data, inputCol, outputCol, numBuckets, relativeError);
        return fitted.transform(data);
    }
}
