package com.datastax.insight.ml.spark.ml.feature.extractor;

import com.datastax.insight.spec.DataSetOperator;
import org.apache.spark.ml.feature.Word2Vec;
import org.apache.spark.ml.feature.Word2VecModel;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;

/**
 * Static convenience wrappers around Spark ML's {@link Word2Vec} estimator and
 * the {@link Word2VecModel} it produces.
 *
 * <p>Every method here simply forwards its arguments to the corresponding Spark
 * call; no validation or defaulting is performed in this class.
 */
public class Word2Vector implements DataSetOperator {

    /**
     * Creates a new {@link Word2Vec} estimator and applies each argument through
     * the matching Spark setter.
     *
     * @param inputCol          value passed to {@code setInputCol}
     * @param outputCol         value passed to {@code setOutputCol}
     * @param vectorSize        value passed to {@code setVectorSize}
     * @param windowSize        value passed to {@code setWindowSize}
     * @param stepSize          value passed to {@code setStepSize}
     * @param numPartitions     value passed to {@code setNumPartitions}
     * @param maxIterations     value passed to {@code setMaxIter}
     * @param seed              value passed to {@code setSeed}
     * @param minCount          value passed to {@code setMinCount}
     * @param maxSentenceLength value passed to {@code setMaxSentenceLength}
     * @return the configured, not yet fitted, estimator
     */
    public static Word2Vec getOperator(String inputCol, String outputCol, int vectorSize, int windowSize,
                                       double stepSize, int numPartitions, int maxIterations, long seed,
                                       int minCount, int maxSentenceLength) {
        Word2Vec estimator = new Word2Vec();
        // Setters are applied in the same order as before; each one mutates and
        // returns the same estimator instance, so chaining is not required.
        estimator.setInputCol(inputCol);
        estimator.setOutputCol(outputCol);
        estimator.setVectorSize(vectorSize);
        estimator.setWindowSize(windowSize);
        estimator.setStepSize(stepSize);
        estimator.setNumPartitions(numPartitions);
        estimator.setMaxIter(maxIterations);
        estimator.setSeed(seed);
        estimator.setMinCount(minCount);
        estimator.setMaxSentenceLength(maxSentenceLength);
        return estimator;
    }

    /**
     * Builds an estimator via
     * {@link #getOperator(String, String, int, int, double, int, int, long, int, int)}
     * and fits it to {@code data}.
     *
     * @param data              dataset handed to {@code Word2Vec.fit}
     * @param inputCol          forwarded to {@code getOperator}
     * @param outputCol         forwarded to {@code getOperator}
     * @param vectorSize        forwarded to {@code getOperator}
     * @param windowSize        forwarded to {@code getOperator}
     * @param stepSize          forwarded to {@code getOperator}
     * @param numPartitions     forwarded to {@code getOperator}
     * @param maxIterations     forwarded to {@code getOperator}
     * @param seed              forwarded to {@code getOperator}
     * @param minCount          forwarded to {@code getOperator}
     * @param maxSentenceLength forwarded to {@code getOperator}
     * @return the model returned by {@code Word2Vec.fit}
     */
    public static Word2VecModel fit(Dataset<Row> data, String inputCol, String outputCol, int vectorSize,
                                    int windowSize, double stepSize, int numPartitions, int maxIterations,
                                    long seed, int minCount, int maxSentenceLength) {
        Word2Vec estimator = getOperator(inputCol, outputCol, vectorSize, windowSize, stepSize,
                numPartitions, maxIterations, seed, minCount, maxSentenceLength);
        return fit(estimator, data);
    }

    /**
     * Fits an already configured estimator to {@code data}.
     *
     * @param word2Vec estimator to fit
     * @param data     dataset handed to {@code Word2Vec.fit}
     * @return the model returned by {@code Word2Vec.fit}
     */
    public static Word2VecModel fit(Word2Vec word2Vec, Dataset<Row> data) {
        return word2Vec.fit(data);
    }

    /**
     * Applies a fitted model to {@code data}.
     *
     * @param model fitted model whose {@code transform} is invoked
     * @param data  dataset to transform
     * @return the dataset returned by {@code Word2VecModel.transform}
     */
    public static Dataset<Row> transform(Word2VecModel model, Dataset<Row> data) {
        return model.transform(data);
    }

}
