package com.datastax.insight.ml.spark.mllib.feature;

import com.datastax.insight.spec.RDDOperator;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.mllib.feature.ChiSqSelector;
import org.apache.spark.mllib.feature.ChiSqSelectorModel;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.regression.LabeledPoint;

/**
 * Chi-squared feature selection for RDDs of {@link LabeledPoint}.
 *
 * <p>Features are first discretized (each value is divided by a fixed width and floored) and
 * a {@link ChiSqSelector} is then fitted on, and applied to, the discretized data.
 */
public class ChiSqFeatureSelector implements RDDOperator {
    /**
     * Discretizes every feature of {@code data} and keeps only the features chosen by a
     * {@link ChiSqSelector} fitted on that discretized data.
     *
     * <p>Note that the selector model is applied to the <em>discretized</em> points, so the
     * feature values in the result are the floored quotients, not the original values.
     *
     * @param data        labeled points to select features from
     * @param bins        divisor used for discretization: each feature value {@code x} becomes
     *                    {@code Math.floor(x / bins)}. Despite its name this acts as the bin
     *                    width rather than as a bin count. Must be positive.
     * @param numFeatures number of features handed to the {@link ChiSqSelector} constructor
     * @return a new RDD with the same labels and the selected, discretized features
     * @throws IllegalArgumentException if {@code bins} is not positive (a zero divisor would
     *                                  silently turn every feature into Infinity/NaN)
     */
    public static JavaRDD<LabeledPoint> fit(JavaRDD<LabeledPoint> data,int bins,int numFeatures){
        if (bins <= 0) {
            throw new IllegalArgumentException("bins must be positive, got " + bins);
        }

        // toArray() is read once per point instead of calling apply(i) for every index.
        JavaRDD<LabeledPoint> discretizedData = data.map(
                lp -> new LabeledPoint(
                        lp.label(),
                        Vectors.dense(discretize(lp.features().toArray(), bins))));

        // Fit the selector model on the discretized points.
        final ChiSqSelectorModel transformer =
                new ChiSqSelector(numFeatures).fit(discretizedData.rdd());

        // Keep only the selected features of each (discretized) feature vector.
        // NOTE(review): discretizedData is consumed both by fit(...) above and by this map and
        // is not persisted here, so the discretization is presumably evaluated twice - consider
        // caching at the call site if that matters.
        return discretizedData.map(
                lp -> new LabeledPoint(lp.label(), transformer.transform(lp.features())));
    }

    /**
     * Returns a new array holding {@code Math.floor(v / binWidth)} for every input value.
     * The input array is never modified.
     */
    private static double[] discretize(double[] values, int binWidth) {
        final double[] discretized = new double[values.length];
        for (int i = 0; i < values.length; ++i) {
            discretized[i] = Math.floor(values[i] / binWidth);
        }
        return discretized;
    }
}
