package com.datastax.insight.ml.spark.ml.feature.extractor;

import com.datastax.insight.spec.DataSetOperator;
import org.apache.spark.ml.feature.CountVectorizer;
import org.apache.spark.ml.feature.CountVectorizerModel;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;

/**
 * Static convenience helpers around Spark ML's {@link CountVectorizer}:
 * building a configured estimator, fitting it to a dataset, and applying
 * the resulting {@link CountVectorizerModel} to a dataset.
 *
 * <p>All methods are static; this class declares no instance state.
 */
public class CountVectorizerWrapper implements DataSetOperator {

    /**
     * Creates a {@link CountVectorizer} with the supplied settings applied.
     *
     * @param inputCol  value forwarded to {@code setInputCol}
     * @param outputCol value forwarded to {@code setOutputCol}
     * @param vocabSize value forwarded to {@code setVocabSize}
     * @param minDF     value forwarded to {@code setMinDF}
     * @param minTF     value forwarded to {@code setMinTF}
     * @param binary    value forwarded to {@code setBinary}
     * @return the configured, not yet fitted, estimator
     */
    public static CountVectorizer getOperator(String inputCol, String outputCol, int vocabSize,
                                              double minDF, double minTF, boolean binary) {
        return new CountVectorizer()
                .setInputCol(inputCol)
                .setOutputCol(outputCol)
                .setVocabSize(vocabSize)
                .setMinDF(minDF)
                .setMinTF(minTF)
                .setBinary(binary);
    }

    /**
     * Fits an already configured estimator to the given dataset.
     *
     * @param countVectorizer the estimator to fit
     * @param data            the dataset handed to {@code countVectorizer.fit}
     * @return the model produced by the fit
     */
    public static CountVectorizerModel fit(CountVectorizer countVectorizer, Dataset<Row> data) {
        return countVectorizer.fit(data);
    }

    /**
     * Builds an estimator via
     * {@link #getOperator(String, String, int, double, double, boolean)}
     * and fits it to the given dataset in a single call.
     *
     * @param data      the dataset to fit on
     * @param inputCol  value forwarded to {@code setInputCol}
     * @param outputCol value forwarded to {@code setOutputCol}
     * @param vocabSize value forwarded to {@code setVocabSize}
     * @param minDF     value forwarded to {@code setMinDF}
     * @param minTF     value forwarded to {@code setMinTF}
     * @param binary    value forwarded to {@code setBinary}
     * @return the model produced by the fit
     */
    public static CountVectorizerModel fit(Dataset<Row> data, String inputCol, String outputCol, int vocabSize,
                                           double minDF, double minTF, boolean binary) {
        final CountVectorizer estimator =
                getOperator(inputCol, outputCol, vocabSize, minDF, minTF, binary);
        return fit(estimator, data);
    }

    /**
     * Applies a fitted model to a dataset.
     *
     * @param model the fitted model
     * @param data  the dataset handed to {@code model.transform}
     * @return the dataset returned by {@code model.transform}
     */
    public static Dataset<Row> transform(CountVectorizerModel model, Dataset<Row> data) {
        return model.transform(data);
    }
}
