package com.datastax.insight.ml.spark.ml.feature.transformer;

import com.datastax.insight.spec.DataSetOperator;
import com.datastax.insight.core.Consts;
import org.apache.spark.ml.feature.ElementwiseProduct;
import org.apache.spark.ml.linalg.Vectors;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;

/**
 * Hadamard (element-wise) product feature transformer.
 *
 * <p>Thin wrapper around Spark ML's {@link ElementwiseProduct}, exposing static helpers to build
 * the transformer from column names and a textual scaling vector, and to apply it to a dataset.
 */
public class HadamardProduct implements DataSetOperator {

    /**
     * Builds an {@link ElementwiseProduct} transformer for the Hadamard product.
     *
     * @param inputCol   name of the input column
     * @param outputCol  name of the output column
     * @param vectorText scaling vector as numbers separated by {@code Consts.DELIMITER}
     *                   (passed to {@link String#split(String)}); when {@code null}, empty, or
     *                   whitespace-only, no scaling vector is set on the returned transformer
     * @return the configured transformer
     * @throws NumberFormatException if any element of {@code vectorText} is not a valid double
     */
    public static ElementwiseProduct getOperator(String inputCol, String outputCol, String vectorText) {
        // Parse first so malformed input fails before any transformer is created.
        double[] scalingVec = parseScalingVector(vectorText);

        ElementwiseProduct transformer = new ElementwiseProduct()
                .setInputCol(inputCol)
                .setOutputCol(outputCol);

        if (scalingVec != null) {
            transformer.setScalingVec(Vectors.dense(scalingVec));
        }

        return transformer;
    }

    /**
     * Parses a delimiter-separated list of doubles.
     *
     * @param vectorText text to parse; may be {@code null}
     * @return the parsed values, or {@code null} when {@code vectorText} is null or blank
     * @throws NumberFormatException if a token is not a valid double; the message names the
     *                               offending index and token, and the original exception is
     *                               kept as the cause
     */
    private static double[] parseScalingVector(String vectorText) {
        // Blank text is treated like empty text (absent vector) instead of failing on an
        // empty token.
        if (vectorText == null || vectorText.trim().isEmpty()) {
            return null;
        }

        String[] tokens = vectorText.split(Consts.DELIMITER);
        double[] values = new double[tokens.length];
        for (int i = 0; i < tokens.length; i++) {
            try {
                // Double.parseDouble ignores leading/trailing whitespace around each token.
                values[i] = Double.parseDouble(tokens[i]);
            } catch (NumberFormatException e) {
                NumberFormatException detailed = new NumberFormatException(
                        "Invalid scaling vector element at index " + i
                                + ": '" + tokens[i] + "' in '" + vectorText + "'");
                detailed.initCause(e);
                throw detailed;
            }
        }
        return values;
    }

    /**
     * Builds a Hadamard product transformer and applies it to {@code data}.
     *
     * @param data       input dataset
     * @param inputCol   name of the input column
     * @param outputCol  name of the output column
     * @param vectorText scaling vector text; see {@link #getOperator(String, String, String)}
     * @return the transformed dataset
     * @throws NumberFormatException if {@code vectorText} contains an invalid number
     */
    public static Dataset<Row> transform(Dataset<Row> data, String inputCol, String outputCol, String vectorText) {
        return getOperator(inputCol, outputCol, vectorText).transform(data);
    }

    /**
     * Applies an already-configured Hadamard product transformer to {@code data}.
     *
     * @param transformer the transformer to apply
     * @param data        input dataset
     * @return the transformed dataset
     */
    public static Dataset<Row> transform(ElementwiseProduct transformer, Dataset<Row> data) {
        return transformer.transform(data);
    }
}
