package com.datastax.insight.ml.spark.ml.feature.transformer;

import com.datastax.insight.core.Consts;
import com.datastax.insight.spec.DataSetOperator;
import org.apache.spark.ml.feature.VectorAssembler;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;

/**
 * Static helpers around Spark ML's {@link VectorAssembler}.
 *
 * <p>Input column lists are passed as a single string of names joined by
 * {@code Consts.DELIMITER}. The string is split with {@link String#split(String)}, so the
 * delimiter is interpreted as a regular expression and trailing empty names are dropped.
 */
public class VectorAssemblerWrapper implements DataSetOperator {

    /**
     * Builds a {@link VectorAssembler} that combines the given input columns into one vector column.
     *
     * @param inputCol  input column names joined by {@code Consts.DELIMITER}
     * @param outputCol name of the vector column the assembler will produce
     * @return a configured assembler; it has not been applied to any dataset yet
     */
    public static VectorAssembler getOperator(String inputCol, String outputCol){
        return new VectorAssembler()
                .setInputCols(splitColumns(inputCol))
                .setOutputCol(outputCol);
    }

    /**
     * Casts every listed input column to {@code double}, then assembles them into a vector column.
     *
     * @param data      dataset containing the input columns
     * @param inputCol  input column names joined by {@code Consts.DELIMITER}
     * @param outputCol name of the vector column to add
     * @return {@code data} with the input columns cast to double and {@code outputCol} appended
     * @throws IllegalArgumentException if {@code inputCol} yields no column names after splitting
     */
    public static Dataset<Row> transform(Dataset<Row> data, String inputCol, String outputCol){
        String[] columns = splitColumns(inputCol);
        if (columns.length == 0) {
            // Previously this case passed a null dataset to the assembler and surfaced as an NPE.
            throw new IllegalArgumentException("No input columns found in '" + inputCol + "'");
        }

        Dataset<Row> casted = castToDouble(data, columns);
        return getOperator(inputCol, outputCol).transform(casted);
    }

    /**
     * Applies an already-configured assembler to a dataset.
     *
     * @param assembler assembler to apply
     * @param data      dataset to transform
     * @return the dataset produced by {@code assembler.transform(data)}
     */
    public static Dataset<Row> transform(VectorAssembler assembler,Dataset<Row> data){
        return assembler.transform(data);
    }

    /** Splits a delimiter-joined column list into individual column names. */
    private static String[] splitColumns(String inputCol) {
        return inputCol.split(Consts.DELIMITER);
    }

    /** Replaces each listed column with its cast to {@code double}, keeping all other columns. */
    private static Dataset<Row> castToDouble(Dataset<Row> data, String[] columns) {
        // TODO: switch to a more efficient approach later. Each withColumn call adds its own
        // projection to the logical plan, so very long column lists produce large plans.
        Dataset<Row> result = data;
        for (String c : columns) {
            // Resolve the column against the current `result`, not the original `data`.
            // Once a column has been replaced, its original attribute is gone from `result`,
            // so referencing it via `data` fails analysis when a name is listed twice.
            result = result.withColumn(c, result.col(c).cast("double"));
        }
        return result;
    }
}
