/*
 * Decompiled with CFR 0.152.
 */
package org.apache.spark.mllib.recommendation;

import java.util.ArrayList;
import java.util.List;
import org.apache.spark.SharedSparkSession;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.mllib.recommendation.ALS;
import org.apache.spark.mllib.recommendation.ALSSuite;
import org.apache.spark.mllib.recommendation.MatrixFactorizationModel;
import org.apache.spark.mllib.recommendation.Rating;
import org.apache.spark.rdd.RDD;
import org.junit.Assert;
import org.junit.Test;
import scala.Tuple2;
import scala.Tuple3;

/**
 * Java-API tests for {@link ALS} and {@link MatrixFactorizationModel}.
 *
 * <p>Each test generates synthetic ratings through
 * {@code ALSSuite.generateRatingsAsJava}, trains a model using either the static
 * {@code ALS.train}/{@code ALS.trainImplicit} helpers or the builder-style setters,
 * and then checks the model's predictions against the generated ground truth.
 */
public class JavaALSSuite
extends SharedSparkSession {
    /**
     * Predicts a rating for every (user, product) pair and compares the result with
     * the expected values.
     *
     * <p>Both {@code trueRatings} and {@code truePrefs} are read at the flattened
     * index {@code product * users + user}.
     *
     * @param model          trained model under test
     * @param users          number of users; user ids {@code 0..users-1} are queried
     * @param products       number of products; product ids {@code 0..products-1} are queried
     * @param trueRatings    expected ratings, flattened as described above
     * @param matchThreshold exclusive upper bound on the per-prediction absolute error
     *                       (explicit case) or on the confidence-weighted RMSE (implicit case)
     * @param implicitPrefs  {@code true} to validate with the confidence-weighted RMSE,
     *                       {@code false} to validate each prediction individually
     * @param truePrefs      expected preferences, flattened as described above; only read
     *                       when {@code implicitPrefs} is {@code true}
     */
    private void validatePrediction(MatrixFactorizationModel model, int users, int products, double[] trueRatings, double matchThreshold, boolean implicitPrefs, double[] truePrefs) {
        List<Tuple2<Integer, Integer>> localUsersProducts = new ArrayList<Tuple2<Integer, Integer>>(users * products);
        for (int u = 0; u < users; ++u) {
            for (int p = 0; p < products; ++p) {
                localUsersProducts.add(new Tuple2<Integer, Integer>(u, p));
            }
        }
        JavaPairRDD<Integer, Integer> usersProducts = this.jsc.parallelizePairs(localUsersProducts);
        List<Rating> predictedRatings = model.predict(usersProducts).collect();
        // Exactly one prediction is expected per queried (user, product) pair.
        Assert.assertEquals(users * products, predictedRatings.size());
        if (!implicitPrefs) {
            // Explicit feedback: every single prediction must be close to the true rating.
            for (Rating r : predictedRatings) {
                double prediction = r.rating();
                double correct = trueRatings[r.product() * users + r.user()];
                Assert.assertTrue(String.format("Prediction=%2.4f not below match threshold of %2.2f", prediction, matchThreshold), Math.abs(prediction - correct) < matchThreshold);
            }
        } else {
            // Implicit feedback: use an RMSE over the preferences in which each squared
            // error is weighted by confidence = 1 + |true rating|.
            double sqErr = 0.0;
            double denom = 0.0;
            for (Rating r : predictedRatings) {
                int index = r.product() * users + r.user();
                double prediction = r.rating();
                double truePref = truePrefs[index];
                double confidence = 1.0 + Math.abs(trueRatings[index]);
                double err = confidence * (truePref - prediction) * (truePref - prediction);
                sqErr += err;
                denom += confidence;
            }
            double rmse = Math.sqrt(sqErr / denom);
            Assert.assertTrue(String.format("Confidence-weighted RMSE=%2.4f above threshold of %2.2f", rmse, matchThreshold), rmse < matchThreshold);
        }
    }

    /** Trains an explicit-feedback model through the static {@code ALS.train} helper. */
    @Test
    public void runALSUsingStaticMethods() {
        int features = 1;
        int iterations = 15;
        int users = 50;
        int products = 100;
        // _1 = training ratings, _2 = values passed on as trueRatings, _3 = values passed on as truePrefs.
        Tuple3<List<Rating>, double[], double[]> testData = ALSSuite.generateRatingsAsJava(users, products, features, 0.7, false, false);
        JavaRDD<Rating> data = this.jsc.parallelize(testData._1());
        MatrixFactorizationModel model = ALS.train(data.rdd(), features, iterations);
        this.validatePrediction(model, users, products, testData._2(), 0.3, false, testData._3());
    }

    /** Trains an explicit-feedback model through the {@code ALS} setters, passing a {@code JavaRDD}. */
    @Test
    public void runALSUsingConstructor() {
        int features = 2;
        int iterations = 15;
        int users = 100;
        int products = 200;
        Tuple3<List<Rating>, double[], double[]> testData = ALSSuite.generateRatingsAsJava(users, products, features, 0.7, false, false);
        JavaRDD<Rating> data = this.jsc.parallelize(testData._1());
        MatrixFactorizationModel model = new ALS().setRank(features).setIterations(iterations).run(data);
        this.validatePrediction(model, users, products, testData._2(), 0.3, false, testData._3());
    }

    /** Trains an implicit-feedback model through the static {@code ALS.trainImplicit} helper. */
    @Test
    public void runImplicitALSUsingStaticMethods() {
        int features = 1;
        int iterations = 15;
        int users = 80;
        int products = 160;
        Tuple3<List<Rating>, double[], double[]> testData = ALSSuite.generateRatingsAsJava(users, products, features, 0.7, true, false);
        JavaRDD<Rating> data = this.jsc.parallelize(testData._1());
        MatrixFactorizationModel model = ALS.trainImplicit(data.rdd(), features, iterations);
        this.validatePrediction(model, users, products, testData._2(), 0.4, true, testData._3());
    }

    /** Trains an implicit-feedback model through the {@code ALS} setters. */
    @Test
    public void runImplicitALSUsingConstructor() {
        int features = 2;
        int iterations = 15;
        int users = 100;
        int products = 200;
        Tuple3<List<Rating>, double[], double[]> testData = ALSSuite.generateRatingsAsJava(users, products, features, 0.7, true, false);
        JavaRDD<Rating> data = this.jsc.parallelize(testData._1());
        MatrixFactorizationModel model = new ALS().setRank(features).setIterations(iterations).setImplicitPrefs(true).run(data.rdd());
        this.validatePrediction(model, users, products, testData._2(), 0.4, true, testData._3());
    }

    /**
     * Trains an implicit-feedback model with a fixed seed on data generated with the
     * last generator flag enabled (presumably "negative weights", judging by the test
     * name; confirm against {@code ALSSuite.generateRatingsAsJava}).
     */
    @Test
    public void runImplicitALSWithNegativeWeight() {
        int features = 2;
        int iterations = 15;
        int users = 80;
        int products = 160;
        Tuple3<List<Rating>, double[], double[]> testData = ALSSuite.generateRatingsAsJava(users, products, features, 0.7, true, true);
        JavaRDD<Rating> data = this.jsc.parallelize(testData._1());
        MatrixFactorizationModel model = new ALS().setRank(features).setIterations(iterations).setImplicitPrefs(true).setSeed(8675309L).run(data.rdd());
        this.validatePrediction(model, users, products, testData._2(), 0.4, true, testData._3());
    }

    /**
     * Trains an implicit-feedback model with a fixed seed and validates the output of
     * {@code recommendProducts} and {@code recommendUsers}.
     */
    @Test
    public void runRecommend() {
        int features = 5;
        int iterations = 10;
        int users = 200;
        int products = 50;
        List<Rating> testData = ALSSuite.generateRatingsAsJava(users, products, features, 0.7, true, false)._1();
        JavaRDD<Rating> data = this.jsc.parallelize(testData);
        MatrixFactorizationModel model = new ALS().setRank(features).setIterations(iterations).setImplicitPrefs(true).setSeed(8675309L).run(data.rdd());
        JavaALSSuite.validateRecommendations(model.recommendProducts(1, 10), 10);
        JavaALSSuite.validateRecommendations(model.recommendUsers(1, 20), 20);
    }

    /**
     * Checks that exactly {@code howMany} recommendations were returned, that they are
     * ordered by non-increasing rating, and that the first rating exceeds 0.7.
     *
     * @param recommendations recommendations returned by the model
     * @param howMany         expected number of recommendations
     */
    private static void validateRecommendations(Rating[] recommendations, int howMany) {
        Assert.assertEquals(howMany, recommendations.length);
        for (int i = 1; i < recommendations.length; ++i) {
            Assert.assertTrue(recommendations[i - 1].rating() >= recommendations[i].rating());
        }
        Assert.assertTrue(recommendations[0].rating() > 0.7);
    }
}

