package org.deeplearning4j.spark.util;

import java.io.DataInputStream;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.FlatMapFunction;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.Function2;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.input.PortableDataStream;
import org.apache.spark.mllib.linalg.Matrices;
import org.apache.spark.mllib.linalg.Matrix;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.canova.api.records.reader.RecordReader;
import org.canova.api.split.InputStreamInputSplit;
import org.canova.api.writable.Writable;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.util.FeatureUtil;
import scala.Tuple2;

/**
 * Static conversion helpers between Spark MLlib types ({@link Vector}, {@link Matrix},
 * {@link LabeledPoint}) and ND4J / DL4J types ({@link INDArray}, {@link DataSet}).
 */
public class MLLibUtil {

    /**
     * Returns the index of the largest entry of {@code vector}, widened to a double.
     *
     * <p>Ties resolve to the first maximal index. An empty vector yields 0. NaN entries are
     * never selected, because {@code NaN > max} is always false.
     *
     * @param vector per-class scores
     * @return the argmax index, as a double
     */
    public static double toClassifierPrediction(Vector vector) {
        double best = Double.NEGATIVE_INFINITY;
        int bestIndex = 0;
        for (int idx = 0; idx < vector.size(); idx++) {
            double value = vector.apply(idx);
            if (value > best) {
                bestIndex = idx;
                best = value;
            }
        }
        return bestIndex;
    }

    /**
     * Copies an MLlib matrix into a new {@code numRows x numCols} ND4J array.
     *
     * <p>Elements are copied one by one through explicit (row, column) indices. This keeps the
     * result correct regardless of ND4J's default ordering. {@link Matrix#toArray()} is
     * column-major, so feeding it straight into a shape-only {@code Nd4j.create} would scramble
     * any non-trivial matrix.
     *
     * @param matrix the source MLlib matrix
     * @return a new ND4J matrix holding the same values
     */
    public static INDArray toMatrix(Matrix matrix) {
        int rows = matrix.numRows();
        int cols = matrix.numCols();
        INDArray result = Nd4j.create(rows, cols);
        for (int r = 0; r < rows; r++) {
            for (int c = 0; c < cols; c++) {
                result.putScalar(new int[]{r, c}, matrix.apply(r, c));
            }
        }
        return result;
    }

    /**
     * Copies an MLlib vector into a new ND4J array backed by a buffer of its values.
     *
     * @param vector the source MLlib vector
     * @return an ND4J array containing {@code vector.toArray()}
     */
    public static INDArray toVector(Vector vector) {
        return Nd4j.create(Nd4j.createBuffer(vector.toArray()));
    }

    /**
     * Copies a 2-D ND4J array into a dense MLlib matrix.
     *
     * <p>{@link Matrices#dense} expects its values in column-major order. The values are
     * gathered through {@code getDouble(row, col)} instead of dumping the raw backing buffer.
     * The raw buffer is row-major for 'c'-ordered arrays, and for views it also contains data
     * outside the view.
     *
     * @param array a matrix-shaped ND4J array
     * @return a dense MLlib matrix with the same shape and values
     * @throws IllegalArgumentException if {@code array} is not a matrix
     */
    public static Matrix toMatrix(INDArray array) {
        if (!array.isMatrix()) {
            throw new IllegalArgumentException("passed in array must be a matrix");
        }
        int rows = array.rows();
        int cols = array.columns();
        double[] columnMajor = new double[rows * cols];
        for (int c = 0; c < cols; c++) {
            for (int r = 0; r < rows; r++) {
                columnMajor[c * rows + r] = array.getDouble(r, c);
            }
        }
        return Matrices.dense(rows, cols, columnMajor);
    }

    /**
     * Copies an ND4J vector into a dense MLlib vector, element by element.
     *
     * @param array a vector-shaped ND4J array
     * @return a dense MLlib vector with the same values
     * @throws IllegalArgumentException if {@code array} is not a vector
     */
    public static Vector toVector(INDArray array) {
        if (!array.isVector()) {
            throw new IllegalArgumentException("passed in array must be a vector");
        }
        double[] values = new double[array.length()];
        for (int idx = 0; idx < values.length; idx++) {
            values[idx] = array.getDouble(idx);
        }
        return Vectors.dense(values);
    }

    /**
     * Reads one record per binary file with {@code recordReader} and converts it to a
     * {@link LabeledPoint} via {@link #pointOf(Collection)}.
     *
     * <p>For each (path, stream) pair, the reader is re-initialised on a freshly opened stream,
     * and only its first record ({@code next()}) is used. The stream is closed once that record
     * has been read, so file handles are not leaked.
     *
     * @param javaPairRDD  (path, stream) pairs, e.g. from {@code JavaSparkContext.binaryFiles}
     * @param recordReader reader used to parse each file; shipped to executors with the closure
     * @return one labeled point per input file
     */
    public static JavaRDD<LabeledPoint> fromBinary(JavaPairRDD<String, PortableDataStream> javaPairRDD,
                                                   final RecordReader recordReader) {
        return javaPairRDD.map(new Function<Tuple2<String, PortableDataStream>, Collection<Writable>>() {
            @Override
            public Collection<Writable> call(Tuple2<String, PortableDataStream> pathAndStream) throws Exception {
                // The caller of PortableDataStream.open() owns the returned stream.
                try (DataInputStream in = pathAndStream._2().open()) {
                    recordReader.initialize(new InputStreamInputSplit(in, pathAndStream._1()));
                    return recordReader.next();
                }
            }
        }).map(new Function<Collection<Writable>, LabeledPoint>() {
            @Override
            public LabeledPoint call(Collection<Writable> record) throws Exception {
                return MLLibUtil.pointOf(record);
            }
        });
    }

    /**
     * Converts a record to a {@link LabeledPoint}. The last value is the label; all preceding
     * values are features.
     *
     * <p>Each value is parsed from {@code Writable.toString()} as a double.
     *
     * @param collection the record; must contain at least one value
     * @return a labeled point with {@code collection.size() - 1} dense features
     * @throws IllegalArgumentException if the record is empty
     * @throws IllegalStateException    if the label is negative or NaN
     * @throws NumberFormatException    if a value is not numeric
     */
    public static LabeledPoint pointOf(Collection<Writable> collection) {
        int size = collection.size();
        if (size == 0) {
            throw new IllegalArgumentException("Record must contain at least one value (the label)");
        }
        double[] features = new double[size - 1];
        double label = 0.0d;
        int position = 0;
        for (Writable writable : collection) {
            double value = Double.parseDouble(writable.toString());
            if (position < features.length) {
                features[position] = value;
            } else {
                label = value;
            }
            position++;
        }
        // Written as !(>= 0) so that a NaN label is rejected as well.
        if (!(label >= 0.0d)) {
            throw new IllegalStateException("Target must be >= 0");
        }
        return new LabeledPoint(label, Vectors.dense(features));
    }

    /**
     * Converts labeled points to single-example {@link DataSet}s with one-hot labels, laid out
     * across {@code max(1, count / batchSize)} partitions.
     *
     * <p>Each point is keyed by its {@code zipWithIndex} position. Those keys are unique, so the
     * {@code reduceByKey} merge function never combines two points; the step effectively only
     * repartitions. The trailing {@code flatMap} iterates each {@code DataSet} as an
     * {@code Iterable<DataSet>}. NOTE(review): presumably that yields per-example DataSets;
     * confirm against the ND4J version in use.
     *
     * @param javaRDD           the labeled points
     * @param numPossibleLabels number of classes used for the one-hot label vector
     * @param batchSize         target examples per partition; must be positive
     * @return the converted data sets
     * @throws IllegalArgumentException if {@code batchSize <= 0}
     */
    public static JavaRDD<DataSet> fromLabeledPoint(JavaRDD<LabeledPoint> javaRDD, final int numPossibleLabels,
                                                    int batchSize) {
        if (batchSize <= 0) {
            throw new IllegalArgumentException("batchSize must be > 0, got " + batchSize);
        }
        JavaPairRDD<Long, DataSet> indexed = javaRDD.zipWithIndex().mapToPair(
                new PairFunction<Tuple2<LabeledPoint, Long>, Long, DataSet>() {
                    @Override
                    public Tuple2<Long, DataSet> call(Tuple2<LabeledPoint, Long> pointAndIndex) throws Exception {
                        return new Tuple2<>(pointAndIndex._2(),
                                MLLibUtil.fromLabeledPoint(pointAndIndex._1(), numPossibleLabels));
                    }
                });
        // Spark rejects 0 partitions, which count / batchSize yields when count < batchSize.
        int numPartitions = (int) Math.max(1L, indexed.count() / batchSize);
        return indexed.reduceByKey(new Function2<DataSet, DataSet, DataSet>() {
            @Override
            public DataSet call(DataSet first, DataSet second) throws Exception {
                return new DataSet(
                        Nd4j.vstack(new INDArray[]{first.getFeatureMatrix(), second.getFeatureMatrix()}),
                        Nd4j.vstack(new INDArray[]{first.getLabels(), second.getLabels()}));
            }
        }, numPartitions).flatMap(new FlatMapFunction<Tuple2<Long, DataSet>, DataSet>() {
            @Override
            public Iterable<DataSet> call(Tuple2<Long, DataSet> keyed) throws Exception {
                return keyed._2();
            }
        });
    }

    /**
     * Converts labeled points to data sets on the driver and re-distributes them.
     *
     * <p>Note: this {@code collect()}s the entire RDD to the driver.
     *
     * @param javaSparkContext  context used to parallelize the result
     * @param javaRDD           the labeled points
     * @param numPossibleLabels number of classes for the one-hot label vector
     * @return one data set per labeled point
     */
    public static JavaRDD<DataSet> fromLabeledPoint(JavaSparkContext javaSparkContext, JavaRDD<LabeledPoint> javaRDD,
                                                    int numPossibleLabels) {
        return javaSparkContext.parallelize(fromLabeledPoint(javaRDD.collect(), numPossibleLabels));
    }

    /**
     * Converts data sets to labeled points on the driver and re-distributes them.
     *
     * <p>Note: this {@code collect()}s the entire RDD to the driver. Each data set must have a
     * vector-shaped feature matrix.
     *
     * @param javaSparkContext context used to parallelize the result
     * @param javaRDD          the data sets
     * @return one labeled point per data set
     */
    public static JavaRDD<LabeledPoint> fromDataSet(JavaSparkContext javaSparkContext, JavaRDD<DataSet> javaRDD) {
        return javaSparkContext.parallelize(toLabeledPoint(javaRDD.collect()));
    }

    /** Converts each data set in {@code list} with {@link #toLabeledPoint(DataSet)}, preserving order. */
    private static List<LabeledPoint> toLabeledPoint(List<DataSet> list) {
        List<LabeledPoint> points = new ArrayList<>(list.size());
        for (DataSet dataSet : list) {
            points.add(toLabeledPoint(dataSet));
        }
        return points;
    }

    /**
     * Converts a single-example data set to a labeled point. The label is the index returned
     * by BLAS {@code iamax} over the labels, i.e. the index of the entry with the largest
     * absolute value.
     *
     * @throws IllegalArgumentException if the feature matrix is not a vector
     */
    private static LabeledPoint toLabeledPoint(DataSet dataSet) {
        if (!dataSet.getFeatureMatrix().isVector()) {
            throw new IllegalArgumentException("Feature matrix must be a vector");
        }
        return new LabeledPoint(Nd4j.getBlasWrapper().iamax(dataSet.getLabels()),
                toVector(dataSet.getFeatureMatrix().dup()));
    }

    /** Converts each labeled point in {@code list} with {@link #fromLabeledPoint(LabeledPoint, int)}, preserving order. */
    private static List<DataSet> fromLabeledPoint(List<LabeledPoint> list, int numPossibleLabels) {
        List<DataSet> dataSets = new ArrayList<>(list.size());
        for (LabeledPoint point : list) {
            dataSets.add(fromLabeledPoint(point, numPossibleLabels));
        }
        return dataSets;
    }

    /**
     * Converts a labeled point to a single-example data set with a one-hot label of width
     * {@code numPossibleLabels}. The double label is truncated to an int class index.
     *
     * <p>Kept public because the anonymous Spark functions above call it. It was apparently
     * widened from private during decompilation.
     */
    public static DataSet fromLabeledPoint(LabeledPoint labeledPoint, int numPossibleLabels) {
        return new DataSet(Nd4j.create(labeledPoint.features().toArray()),
                FeatureUtil.toOutcomeVector((int) labeledPoint.label(), numPossibleLabels));
    }
}
