/*
 * Decompiled with CFR 0.152.
 */
package org.apache.spark.mllib.classification;

import java.util.Arrays;
import java.util.List;
import org.apache.spark.SharedSparkSession;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.mllib.classification.NaiveBayes;
import org.apache.spark.mllib.classification.NaiveBayesModel;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.rdd.RDD;
import org.junit.Assert;
import org.junit.Test;

public class JavaNaiveBayesSuite
extends SharedSparkSession {
    private static final List<LabeledPoint> POINTS = Arrays.asList(new LabeledPoint(0.0, Vectors.dense((double)1.0, (double[])new double[]{0.0, 0.0})), new LabeledPoint(0.0, Vectors.dense((double)2.0, (double[])new double[]{0.0, 0.0})), new LabeledPoint(1.0, Vectors.dense((double)0.0, (double[])new double[]{1.0, 0.0})), new LabeledPoint(1.0, Vectors.dense((double)0.0, (double[])new double[]{2.0, 0.0})), new LabeledPoint(2.0, Vectors.dense((double)0.0, (double[])new double[]{0.0, 1.0})), new LabeledPoint(2.0, Vectors.dense((double)0.0, (double[])new double[]{0.0, 2.0})));

    private int validatePrediction(List<LabeledPoint> points, NaiveBayesModel model) {
        int correct = 0;
        for (LabeledPoint p : points) {
            if (model.predict(p.features()) != p.label()) continue;
            ++correct;
        }
        return correct;
    }

    @Test
    public void runUsingConstructor() {
        JavaRDD testRDD = this.jsc.parallelize(POINTS, 2).cache();
        NaiveBayes nb = new NaiveBayes().setLambda(1.0);
        NaiveBayesModel model = nb.run(testRDD.rdd());
        int numAccurate = this.validatePrediction(POINTS, model);
        Assert.assertEquals((long)POINTS.size(), (long)numAccurate);
    }

    @Test
    public void runUsingStaticMethods() {
        JavaRDD testRDD = this.jsc.parallelize(POINTS, 2).cache();
        NaiveBayesModel model1 = NaiveBayes.train((RDD)testRDD.rdd());
        int numAccurate1 = this.validatePrediction(POINTS, model1);
        Assert.assertEquals((long)POINTS.size(), (long)numAccurate1);
        NaiveBayesModel model2 = NaiveBayes.train((RDD)testRDD.rdd(), (double)0.5);
        int numAccurate2 = this.validatePrediction(POINTS, model2);
        Assert.assertEquals((long)POINTS.size(), (long)numAccurate2);
    }

    @Test
    public void testPredictJavaRDD() {
        JavaRDD examples = this.jsc.parallelize(POINTS, 2).cache();
        NaiveBayesModel model = NaiveBayes.train((RDD)examples.rdd());
        JavaRDD vectors = examples.map((Function)new Function<LabeledPoint, Vector>(){

            public Vector call(LabeledPoint v) throws Exception {
                return v.features();
            }
        });
        JavaRDD predictions = model.predict(vectors);
        predictions.first();
    }

    @Test
    public void testModelTypeSetters() {
        NaiveBayes nb = new NaiveBayes().setModelType("bernoulli").setModelType("multinomial");
    }
}

