package com.datastax.insight.ml.spark.ml.feature.transformer;

import com.datastax.insight.spec.DataSetOperator;
import com.google.common.base.Strings;
import org.apache.spark.ml.feature.StringIndexer;
import org.apache.spark.ml.feature.StringIndexerModel;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;

/**
 * 标签数值化
 */
/**
 * Label indexing: static helpers around Spark ML's {@link StringIndexer}.
 *
 * <p>The helpers build the estimator, fit it on a dataset, and apply the resulting
 * {@link StringIndexerModel} to a dataset. All members are static; the class holds no state.
 */
public class StringIndexerWrapper implements DataSetOperator {

    /**
     * Creates an unfitted label-indexing estimator.
     *
     * @param inputCol      name of the column the estimator reads
     * @param outputCol     name of the column the estimator writes
     * @param handleInvalid value forwarded to {@link StringIndexer#setHandleInvalid(String)};
     *                      when {@code null} or empty, nothing is set and Spark's own default applies
     * @return the configured {@link StringIndexer}
     */
    public static StringIndexer getOperator(String inputCol, String outputCol, String handleInvalid) {
        StringIndexer estimator = new StringIndexer();
        estimator = estimator.setInputCol(inputCol);
        estimator = estimator.setOutputCol(outputCol);

        // Leave the parameter untouched unless the caller actually supplied a strategy.
        if (Strings.isNullOrEmpty(handleInvalid)) {
            return estimator;
        }
        return estimator.setHandleInvalid(handleInvalid);
    }

    /**
     * Builds a label-indexing estimator from the given settings and fits it on {@code data}.
     *
     * @param data          dataset to fit on
     * @param inputCol      name of the column the estimator reads
     * @param outputCol     name of the column the estimator writes
     * @param handleInvalid optional value for {@link StringIndexer#setHandleInvalid(String)};
     *                      ignored when {@code null} or empty
     * @return the fitted {@link StringIndexerModel}
     */
    public static StringIndexerModel fit(Dataset<Row> data, String inputCol, String outputCol, String handleInvalid) {
        return fit(getOperator(inputCol, outputCol, handleInvalid), data);
    }

    /**
     * Fits an already-configured label-indexing estimator on {@code data}.
     *
     * @param indexer estimator to fit
     * @param data    dataset to fit on
     * @return the fitted {@link StringIndexerModel}
     */
    public static StringIndexerModel fit(StringIndexer indexer, Dataset<Row> data) {
        return indexer.fit(data);
    }

    /**
     * Applies a fitted label-indexing model to {@code data}.
     *
     * @param model fitted model to apply
     * @param data  dataset to transform
     * @return the dataset produced by {@link StringIndexerModel#transform}
     */
    public static Dataset<Row> transform(StringIndexerModel model, Dataset<Row> data) {
        return model.transform(data);
    }
}
