package com.datastax.insight.ml.spark.ml.feature.extractor;

import com.datastax.insight.spec.DataSetOperator;
import com.google.common.base.Strings;
import org.apache.spark.ml.feature.HashingTF;
import org.apache.spark.ml.feature.IDF;
import org.apache.spark.ml.feature.IDFModel;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;

/**
 * Static helpers for building and applying the Spark ML stages that make up a TF-IDF pipeline:
 * {@link HashingTF} for hashed term frequencies, and {@link IDF} / {@link IDFModel} for
 * inverse-document-frequency weighting.
 *
 * <p>The {@code get*} factories only set a column-name parameter when a non-null, non-empty value
 * is supplied. An omitted name therefore leaves the Spark stage's own setting untouched instead of
 * overwriting it with an unusable value.
 */
public class TFIDF implements DataSetOperator {

    /**
     * Builds a {@link HashingTF} transformer, configuring only the parameters that are supplied.
     *
     * @param inputCol    name of the input column; skipped when null or empty
     * @param outputCol   name of the output column; skipped when null or empty
     * @param numFeatures number of hash buckets; skipped when null (Spark validates the value)
     * @param binary      whether non-zero term counts are clamped to 1; skipped when null
     * @return the configured transformer
     */
    public static HashingTF getHashingTF(String inputCol, String outputCol, Integer numFeatures, Boolean binary) {
        HashingTF hashingTF = new HashingTF();

        if (!Strings.isNullOrEmpty(inputCol)) {
            hashingTF.setInputCol(inputCol);
        }

        if (!Strings.isNullOrEmpty(outputCol)) {
            hashingTF.setOutputCol(outputCol);
        }

        if (numFeatures != null) {
            hashingTF.setNumFeatures(numFeatures);
        }

        if (binary != null) {
            hashingTF.setBinary(binary);
        }

        return hashingTF;
    }

    /**
     * Builds a {@link HashingTF} via {@link #getHashingTF} and applies it to {@code data}.
     *
     * <p>{@code numFeatures} and {@code binary} are primitives, so both are always set on the
     * transformer.
     *
     * @param data        the dataset to transform
     * @param inputCol    name of the input column; skipped when null or empty
     * @param outputCol   name of the output column; skipped when null or empty
     * @param numFeatures number of hash buckets
     * @param binary      whether non-zero term counts are clamped to 1
     * @return the transformed dataset
     */
    public static Dataset<Row> hashingTF(Dataset<Row> data, String inputCol, String outputCol, int numFeatures, boolean binary) {
        HashingTF hashingTF = getHashingTF(inputCol, outputCol, numFeatures, binary);
        return hashingTF.transform(data);
    }

    /**
     * Applies an already-configured {@link HashingTF} to {@code data}.
     *
     * @param hashingTF the transformer to apply
     * @param data      the dataset to transform
     * @return the transformed dataset
     */
    public static Dataset<Row> hashingTF(HashingTF hashingTF, Dataset<Row> data) {
        return hashingTF.transform(data);
    }

    /**
     * Builds an {@link IDF} estimator.
     *
     * <p>The column names get the same null/empty guards as {@link #getHashingTF}. Previously they
     * were passed straight to Spark, so a missing name overwrote the stage's own output-column
     * setting, and the error only surfaced later at fit/transform time.
     *
     * @param inputCol   name of the term-frequency input column; skipped when null or empty
     * @param outputCol  name of the output column; skipped when null or empty
     * @param minDocFreq minimum number of documents a term must appear in; always set
     *                   (Spark rejects negative values)
     * @return the configured, unfitted estimator
     */
    public static IDF getIDF(String inputCol, String outputCol, int minDocFreq) {
        IDF idf = new IDF();

        if (!Strings.isNullOrEmpty(inputCol)) {
            idf.setInputCol(inputCol);
        }

        if (!Strings.isNullOrEmpty(outputCol)) {
            idf.setOutputCol(outputCol);
        }

        idf.setMinDocFreq(minDocFreq);
        return idf;
    }

    /**
     * Builds an {@link IDF} via {@link #getIDF} and fits it on {@code data}.
     *
     * @param data       the dataset to fit on
     * @param inputCol   name of the term-frequency input column; skipped when null or empty
     * @param outputCol  name of the output column; skipped when null or empty
     * @param minDocFreq minimum number of documents a term must appear in
     * @return the fitted model
     */
    public static IDFModel idfFit(Dataset<Row> data, String inputCol, String outputCol, int minDocFreq) {
        IDF idf = getIDF(inputCol, outputCol, minDocFreq);
        return idf.fit(data);
    }

    /**
     * Fits an already-configured {@link IDF} on {@code data}.
     *
     * @param idf  the estimator to fit
     * @param data the dataset to fit on
     * @return the fitted model
     */
    public static IDFModel idfFit(IDF idf, Dataset<Row> data) {
        return idf.fit(data);
    }

    /**
     * Applies a fitted {@link IDFModel} to {@code data}.
     *
     * @param model the fitted model
     * @param data  the dataset to transform
     * @return the transformed dataset
     */
    public static Dataset<Row> idfTransform(IDFModel model, Dataset<Row> data) {
        return model.transform(data);
    }
}
