package com.datastax.insight.ml.spark.ml.feature.transformer;

import com.datastax.insight.spec.DataSetOperator;
import org.apache.spark.ml.feature.StandardScaler;
import org.apache.spark.ml.feature.StandardScalerModel;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;

/**
 * Static helpers around Spark ML's {@link StandardScaler} and
 * {@link StandardScalerModel}: build a configured scaler, fit it on a
 * dataset, and apply the fitted model.
 */
public class StandardScalerWrapper implements DataSetOperator {

    /**
     * Creates a {@link StandardScaler} configured with the given columns and flags.
     *
     * @param inputCol  value passed to {@code setInputCol}
     * @param outputCol value passed to {@code setOutputCol}
     * @param mean      value passed to {@code setWithMean}
     * @param std       value passed to {@code setWithStd}
     * @return the configured, not yet fitted, scaler
     */
    public static StandardScaler getOperator(String inputCol, String outputCol, boolean mean, boolean std) {
        StandardScaler configured = new StandardScaler();
        configured.setInputCol(inputCol);
        configured.setOutputCol(outputCol);
        configured.setWithMean(mean);
        configured.setWithStd(std);
        return configured;
    }

    /**
     * Builds a scaler via {@link #getOperator(String, String, boolean, boolean)}
     * and fits it on {@code data}.
     *
     * @param data      dataset the scaler is fitted on
     * @param inputCol  value passed to {@code setInputCol}
     * @param outputCol value passed to {@code setOutputCol}
     * @param mean      value passed to {@code setWithMean}
     * @param std       value passed to {@code setWithStd}
     * @return the fitted model
     */
    public static StandardScalerModel fit(Dataset<Row> data, String inputCol, String outputCol, boolean mean, boolean std) {
        return fit(getOperator(inputCol, outputCol, mean, std), data);
    }

    /**
     * Fits an already configured scaler on {@code data}.
     *
     * @param indexer the scaler to fit
     * @param data    dataset the scaler is fitted on
     * @return the fitted model
     */
    public static StandardScalerModel fit(StandardScaler indexer, Dataset<Row> data) {
        return indexer.fit(data);
    }

    /**
     * Applies a fitted model to {@code data}.
     *
     * @param model the fitted scaler model
     * @param data  dataset to transform
     * @return the dataset returned by {@code model.transform(data)}
     */
    public static Dataset<Row> transform(StandardScalerModel model, Dataset<Row> data) {
        return model.transform(data);
    }

    /**
     * Fits a new scaler on {@code data} and then transforms that same dataset
     * with the resulting model.
     *
     * @param data      dataset used for both fitting and transforming
     * @param inputCol  value passed to {@code setInputCol}
     * @param outputCol value passed to {@code setOutputCol}
     * @param mean      value passed to {@code setWithMean}
     * @param std       value passed to {@code setWithStd}
     * @return the transformed dataset
     */
    public static Dataset<Row> transform(Dataset<Row> data, String inputCol, String outputCol, boolean mean, boolean std) {
        StandardScalerModel fitted = fit(data, inputCol, outputCol, mean, std);
        return transform(fitted, data);
    }
}
