package com.github.log0ymxm.mapper.mnist;

import breeze.linalg.DenseVector;
import com.github.log0ymxm.mapper.Mapper$;
import org.apache.log4j.Level;
import org.apache.log4j.LogManager;
import org.apache.spark.SparkConf;
import org.apache.spark.SparkContext;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.distributed.CoordinateMatrix;
import org.apache.spark.mllib.linalg.distributed.IndexedRow;
import org.apache.spark.mllib.linalg.distributed.IndexedRowMatrix;
import org.apache.spark.mllib.linalg.distributed.MatrixEntry;
import org.apache.spark.rdd.RDD;
import scala.MatchError;
import scala.Predef$;
import scala.StringContext;
import scala.collection.immutable.Nil$;
import scala.reflect.ClassTag$;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;

/* compiled from: Driver.scala */
/* loaded from: input_file:com/github/log0ymxm/mapper/mnist/MNISTDriver$.class */
/**
 * Singleton driver (Java view of the Scala {@code object MNISTDriver}) that runs the
 * Mapper algorithm over the MNIST training set on Spark and writes the resulting
 * graph as JSON.
 *
 * <p>Command-line flags recognised by {@link #main(String[])}:
 * <ul>
 *   <li>{@code --local}: run against a {@code local[6]} master and read the data via
 *       {@code MNISTData$.fetchMnist} from a fixed local directory instead of S3;</li>
 *   <li>{@code --sample}: work on a 1% sample (drawn with replacement) of the data.</li>
 * </ul>
 *
 * <p>NOTE(review): input and output locations are hard-coded below; consider making
 * them arguments.
 */
public final class MNISTDriver$ {
    /** The single instance, following the Scala module convention. */
    public static final MNISTDriver$ MODULE$ = new MNISTDriver$();

    /** Number of partitions the training RDD is repartitioned into. */
    private static final int NUM_PARTITIONS = 500;

    /** Fraction of rows kept when {@code --sample} is given. */
    private static final double SAMPLE_FRACTION = 0.01d;

    /**
     * Entry point: configures Spark, runs the pipeline and always stops the context.
     *
     * @param strArr command-line arguments; may contain {@code --local} and/or
     *               {@code --sample}
     */
    public void main(String[] strArr) {
        LogManager.getLogger("org").setLevel(Level.WARN);

        boolean local = hasFlag(strArr, "--local");
        boolean sampled = hasFlag(strArr, "--sample");

        SparkConf conf = new SparkConf().setAppName("Mnist Mapper");
        if (local) {
            conf.setMaster("local[6]");
        }

        SparkContext sparkContext = new SparkContext(conf);
        try {
            runPipeline(sparkContext, local, sampled);
        } finally {
            // Release the context even when a stage fails, so the application
            // does not hang on to cluster resources after an error.
            sparkContext.stop();
        }
    }

    /**
     * Reports whether {@code flag} occurs among {@code args}.
     *
     * @param args argument array; {@code null} is treated as empty
     * @param flag exact flag text to look for
     * @return {@code true} if some element of {@code args} equals {@code flag}
     */
    private static boolean hasFlag(String[] args, String flag) {
        if (args == null) {
            return false;
        }
        for (String arg : args) {
            if (flag.equals(arg)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Loads the labelled MNIST training vectors and repartitions them.
     *
     * @param sparkContext active Spark context
     * @param local        when {@code true}, read through {@code MNISTData$.fetchMnist};
     *                     otherwise parse the text file stored on S3
     * @return the training RDD spread over {@link #NUM_PARTITIONS} partitions
     */
    private static RDD loadTrainingData(SparkContext sparkContext, boolean local) {
        if (local) {
            RDD<DenseVector<Object>> fetched =
                    MNISTData$.MODULE$.fetchMnist(sparkContext, "/Users/penglish", "train");
            return fetched.repartition(NUM_PARTITIONS, fetched.repartition$default$2(NUM_PARTITIONS));
        }
        // Each text line is turned into a DenseVector by the compiled closure anonfun$1.
        RDD parsed = sparkContext
                .textFile("s3n://frst-nyc-data/train-mnist-dense-with-labels.data", sparkContext.textFile$default$2())
                .map(new MNISTDriver$$anonfun$1(), ClassTag$.MODULE$.apply(DenseVector.class));
        return parsed.repartition(NUM_PARTITIONS, parsed.repartition$default$2(NUM_PARTITIONS));
    }

    /**
     * Runs the data preparation, distance/filtration computation and Mapper itself,
     * then writes the graph to S3.
     *
     * <p>The per-element transformations live in the compiled closures
     * {@code MNISTDriver$$anonfun$2..6}, which are not visible here; the comments below
     * describe them only as far as the accompanying progress messages do.
     *
     * @param sparkContext active Spark context
     * @param local        whether to read the data from the local source
     * @param sampled      whether to down-sample the data before processing
     */
    private static void runPipeline(SparkContext sparkContext, boolean local, boolean sampled) {
        RDD mnistTrain = loadTrainingData(sparkContext, local);
        Predef$.MODULE$.println("--- mnistTrain num partitions " + mnistTrain.partitions().length);

        Predef$.MODULE$.println("Removing labels");
        RDD unlabelled = mnistTrain.map(new MNISTDriver$$anonfun$2(), ClassTag$.MODULE$.apply(DenseVector.class));

        Predef$.MODULE$.println("Converting to mllib DenseVectors");
        RDD vectors = unlabelled.map(new MNISTDriver$$anonfun$3(), ClassTag$.MODULE$.apply(Vector.class));

        RDD data;
        if (sampled) {
            Predef$.MODULE$.println("Sampling MNIST data");
            // withReplacement = true, default seed (unchanged from the original driver).
            data = vectors.sample(true, SAMPLE_FRACTION, vectors.sample$default$3());
        } else {
            data = vectors;
        }

        Predef$.MODULE$.println("Adding index");
        RDD indexedRows = data.zipWithIndex().map(new MNISTDriver$$anonfun$4(), ClassTag$.MODULE$.apply(IndexedRow.class));

        Predef$.MODULE$.println("Converting to indexed matrix");
        IndexedRowMatrix matrix = new IndexedRowMatrix(indexedRows);

        Predef$.MODULE$.println("Calculating Similarities");
        // columnSimilarities() works on columns, so the matrix is transposed first to
        // obtain similarities between the original rows (data points).
        CoordinateMatrix similarities =
                matrix.toCoordinateMatrix().transpose().toIndexedRowMatrix().columnSimilarities();

        Predef$.MODULE$.println("Converting to distances");
        CoordinateMatrix distances = new CoordinateMatrix(
                similarities.entries().map(new MNISTDriver$$anonfun$5(), ClassTag$.MODULE$.apply(MatrixEntry.class)));

        Predef$.MODULE$.println("Calculating filtration");
        IndexedRowMatrix filtration = new IndexedRowMatrix(
                indexedRows.map(new MNISTDriver$$anonfun$6(), ClassTag$.MODULE$.apply(IndexedRow.class)));

        Predef$.MODULE$.println("MNIST Count: " + data.count());

        Predef$.MODULE$.println("Running Mapper");
        // NOTE(review): the literal 500 is Mapper's fourth argument; its meaning is
        // defined by Mapper$.mapper, not by this file. Confirm before changing.
        Mapper$.MODULE$.writeAsJson(
                Mapper$.MODULE$.mapper(sparkContext, distances, filtration, 500, Mapper$.MODULE$.mapper$default$5()),
                "s3n://frst-nyc-data/graph.json");
    }

    private MNISTDriver$() {
        // Singleton: instantiated only through MODULE$.
    }
}
