/*
 * Decompiled with CFR 0.152.
 */
package org.apache.spark.mllib.regression;

import java.util.List;
import org.apache.spark.SharedSparkSession;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.regression.LinearRegressionModel;
import org.apache.spark.mllib.regression.LinearRegressionWithSGD;
import org.apache.spark.mllib.util.LinearDataGenerator;
import org.apache.spark.rdd.RDD;
import org.junit.Assert;
import org.junit.Test;

/**
 * Java API tests for {@link LinearRegressionWithSGD} and {@link LinearRegressionModel}.
 *
 * <p>Training and validation sets are produced by
 * {@link LinearDataGenerator#generateLinearInputAsList} with different seeds
 * ({@value #TRAINING_SEED} for training, {@value #VALIDATION_SEED} for validation).
 */
public class JavaLinearRegressionSuite extends SharedSparkSession {

    /** Seed used to generate the training data. */
    private static final int TRAINING_SEED = 42;

    /** Seed used to generate the held-out validation data. */
    private static final int VALIDATION_SEED = 17;

    /** Noise parameter passed to the data generator for every generated set. */
    private static final double EPS = 0.1;

    /** Number of partitions for the parallelized training RDD. */
    private static final int NUM_PARTITIONS = 2;

    /** Maximum absolute error for a prediction to count as accurate. */
    private static final double ACCURACY_TOLERANCE = 0.5;

    /** Fraction of validation points that must be predicted accurately (strictly exceeded). */
    private static final double REQUIRED_ACCURACY = 4.0 / 5.0;

    /**
     * Counts how many validation points the model predicts within {@link #ACCURACY_TOLERANCE}.
     *
     * @param validationData labeled points to score the model against
     * @param model trained model whose predictions are checked
     * @return number of points whose absolute prediction error is at most the tolerance; a NaN
     *     prediction fails the comparison and is therefore not counted
     */
    int validatePrediction(List<LabeledPoint> validationData, LinearRegressionModel model) {
        int numAccurate = 0;
        for (LabeledPoint point : validationData) {
            double prediction = model.predict(point.features());
            if (Math.abs(prediction - point.label()) <= ACCURACY_TOLERANCE) {
                numAccurate++;
            }
        }
        return numAccurate;
    }

    /** Trains through the instance API with an intercept enabled and checks validation accuracy. */
    @Test
    public void runLinearRegressionUsingConstructor() {
        int nPoints = 100;
        double intercept = 3.0;
        double[] weights = {10.0, 10.0};
        JavaRDD<LabeledPoint> testRDD = trainingRDD(intercept, weights, nPoints);
        List<LabeledPoint> validationData = validationData(intercept, weights, nPoints);

        LinearRegressionWithSGD linSGDImpl = new LinearRegressionWithSGD();
        linSGDImpl.setIntercept(true);
        LinearRegressionModel model = linSGDImpl.run(testRDD.rdd());

        assertMostlyAccurate(validatePrediction(validationData, model), nPoints);
    }

    /** Trains through the static {@code train} helper (100 iterations) and checks accuracy. */
    @Test
    public void runLinearRegressionUsingStaticMethods() {
        int nPoints = 100;
        double intercept = 0.0;
        double[] weights = {10.0, 10.0};
        JavaRDD<LabeledPoint> testRDD = trainingRDD(intercept, weights, nPoints);
        List<LabeledPoint> validationData = validationData(intercept, weights, nPoints);

        LinearRegressionModel model = LinearRegressionWithSGD.train(testRDD.rdd(), 100);

        assertMostlyAccurate(validatePrediction(validationData, model), nPoints);
    }

    /** Checks that {@code predict} on a {@code JavaRDD<Vector>} yields one prediction per vector. */
    @Test
    public void testPredictJavaRDD() {
        int nPoints = 100;
        double intercept = 0.0;
        double[] weights = {10.0, 10.0};
        JavaRDD<LabeledPoint> testRDD = trainingRDD(intercept, weights, nPoints);

        LinearRegressionWithSGD linSGDImpl = new LinearRegressionWithSGD();
        LinearRegressionModel model = linSGDImpl.run(testRDD.rdd());

        // A method reference avoids capturing the enclosing test instance in the shipped closure.
        JavaRDD<Vector> vectors = testRDD.map(LabeledPoint::features);
        JavaRDD<Double> predictions = model.predict(vectors);
        predictions.first();
        Assert.assertEquals(nPoints, predictions.count());
    }

    /** Builds and caches the training RDD from generated linear data. */
    private JavaRDD<LabeledPoint> trainingRDD(double intercept, double[] weights, int nPoints) {
        List<LabeledPoint> points = LinearDataGenerator.generateLinearInputAsList(
            intercept, weights, nPoints, TRAINING_SEED, EPS);
        return jsc.parallelize(points, NUM_PARTITIONS).cache();
    }

    /** Generates the local validation set with a seed distinct from the training data. */
    private static List<LabeledPoint> validationData(double intercept, double[] weights, int nPoints) {
        return LinearDataGenerator.generateLinearInputAsList(
            intercept, weights, nPoints, VALIDATION_SEED, EPS);
    }

    /** Asserts that strictly more than {@link #REQUIRED_ACCURACY} of the points were accurate. */
    private static void assertMostlyAccurate(int numAccurate, int nPoints) {
        Assert.assertTrue(
            "only " + numAccurate + " of " + nPoints + " predictions were within tolerance",
            numAccurate > nPoints * REQUIRED_ACCURACY);
    }
}

