package org.deeplearning4j.spark.util;

import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.lang.reflect.Array;
import java.net.URI;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import java.util.Set;
import lombok.NonNull;
import org.apache.commons.io.FilenameUtils;
import org.apache.commons.io.IOUtils;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.LocatedFileStatus;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.fs.RemoteIterator;
import org.apache.spark.HashPartitioner;
import org.apache.spark.SparkContext;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.Function2;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.serializer.SerializerInstance;
import org.deeplearning4j.spark.api.Repartition;
import org.deeplearning4j.spark.api.RepartitionStrategy;
import org.deeplearning4j.spark.data.BatchDataSetsFunction;
import org.deeplearning4j.spark.data.shuffle.SplitDataSetExamplesPairFlatMapFunction;
import org.deeplearning4j.spark.impl.common.CountPartitionsFunction;
import org.deeplearning4j.spark.impl.common.SplitPartitionsFunction;
import org.deeplearning4j.spark.impl.common.SplitPartitionsFunction2;
import org.deeplearning4j.spark.impl.common.repartition.BalancedPartitioner;
import org.deeplearning4j.spark.impl.common.repartition.HashingBalancedPartitioner;
import org.deeplearning4j.spark.impl.common.repartition.MapTupleToPairFlatMap;
import org.deeplearning4j.spark.impl.repartitioner.EqualRepartitioner;
import org.deeplearning4j.util.UIDProvider;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import scala.Tuple2;
import scala.reflect.ClassTag;

/* loaded from: input_file:org/deeplearning4j/spark/util/SparkUtils.class */
public class SparkUtils {
    private static final Logger log = LoggerFactory.getLogger(SparkUtils.class);
    private static final String KRYO_EXCEPTION_MSG = "Kryo serialization detected without an appropriate registrator for ND4J INDArrays.\nWhen using Kryo, An appropriate Kryo registrator must be used to avoid serialization issues (NullPointerException) with off-heap data in INDArrays.\nUse nd4j-kryo_2.10 or _2.11 artifact, with sparkConf.set(\"spark.kryo.registrator\", \"org.nd4j.Nd4jRegistrator\");\nSee https://deeplearning4j.org/docs/latest/deeplearning4j-scaleout-howto#kryo for more details";
    private static String sparkExecutorId;

    /** Private constructor: prevents instantiation from outside this class. */
    private SparkUtils() {
    }

    /**
     * Checks that ND4J arrays can be serialized when Spark is configured to use Kryo.
     *
     * <p>Nothing is tested when {@code spark.serializer} is not the Kryo serializer, or when
     * {@code spark.kryo.registrator} is {@code org.nd4j.Nd4jRegistrator}. Otherwise a small test
     * array is serialized and deserialized with Spark's own serializer instance and the result is
     * compared with the input.
     *
     * @param javaSparkContext context whose configuration and serializer are inspected
     * @param logger unused by this method; kept for signature compatibility
     * @return always {@code true}; every failure is reported by exception
     * @throws RuntimeException if serialization or deserialization fails, yields a {@code null}
     *     buffer, or does not reproduce the test array
     */
    public static boolean checkKryoConfiguration(JavaSparkContext javaSparkContext, Logger logger) {
        String serializerClass = javaSparkContext.getConf().get("spark.serializer", (String) null);
        if (!"org.apache.spark.serializer.KryoSerializer".equals(serializerClass)) {
            return true;
        }
        String registratorClass = javaSparkContext.getConf().get("spark.kryo.registrator", (String) null);
        if ("org.nd4j.Nd4jRegistrator".equals(registratorClass)) {
            return true;
        }

        SerializerInstance serializerInstance;
        ByteBuffer serialized;
        try {
            serializerInstance = javaSparkContext.env().serializer().newInstance();
            serialized = serializerInstance.serialize(Nd4j.linspace(1L, 5L, 5L), (ClassTag) null);
        } catch (Exception e) {
            throw new RuntimeException(KRYO_EXCEPTION_MSG, e);
        }
        // Thrown outside the try blocks so the specific hint is the top-level message
        // instead of being re-wrapped by the generic catch handlers.
        if (serialized == null) {
            throw new RuntimeException(KRYO_EXCEPTION_MSG + "\n(Got: null ByteBuffer from Spark SerializerInstance)");
        }

        boolean roundTripMatches;
        try {
            INDArray deserialized = (INDArray) serializerInstance.deserialize(serialized, (ClassTag) null);
            // The comparison stays inside the try: it operates on the freshly deserialized array.
            roundTripMatches = Nd4j.linspace(1L, 5L, 5L).equals(deserialized);
        } catch (Exception e) {
            throw new RuntimeException(KRYO_EXCEPTION_MSG, e);
        }
        if (!roundTripMatches) {
            throw new RuntimeException(KRYO_EXCEPTION_MSG + "\n(Error during deserialization: test array was not deserialized successfully)");
        }
        return true;
    }

    /**
     * Writes a string to a file, delegating to the {@link SparkContext} overload.
     *
     * @param str destination path
     * @param str2 text to write
     * @param javaSparkContext context whose underlying {@link SparkContext} is used
     * @throws IOException if the delegate fails
     */
    public static void writeStringToFile(String str, String str2, JavaSparkContext javaSparkContext) throws IOException {
        SparkContext underlyingContext = javaSparkContext.sc();
        writeStringToFile(str, str2, underlyingContext);
    }

    /**
     * Writes a string, encoded as UTF-8, to a file on the file system obtained from the
     * context's Hadoop configuration.
     *
     * @param str destination path
     * @param str2 text to write
     * @param sparkContext context supplying the Hadoop configuration
     * @throws IOException if the file cannot be created, written or closed
     */
    public static void writeStringToFile(String str, String str2, SparkContext sparkContext) throws IOException {
        // Only the stream is managed here; the FileSystem handle itself is not closed by this method.
        FileSystem fileSystem = FileSystem.get(sparkContext.hadoopConfiguration());
        try (BufferedOutputStream out = new BufferedOutputStream(fileSystem.create(new Path(str)))) {
            out.write(str2.getBytes(StandardCharsets.UTF_8));
        }
    }

    /**
     * Reads a file as a string, delegating to the {@link SparkContext} overload.
     *
     * @param str path of the file to read
     * @param javaSparkContext context whose underlying {@link SparkContext} is used
     * @return the file contents
     * @throws IOException if the delegate fails
     */
    public static String readStringFromFile(String str, JavaSparkContext javaSparkContext) throws IOException {
        SparkContext underlyingContext = javaSparkContext.sc();
        return readStringFromFile(str, underlyingContext);
    }

    /**
     * Reads an entire file from the file system obtained from the context's Hadoop configuration
     * and decodes its bytes as UTF-8.
     *
     * @param str path of the file to read
     * @param sparkContext context supplying the Hadoop configuration
     * @return the file contents as a string
     * @throws IOException if the file cannot be opened, read or closed
     */
    public static String readStringFromFile(String str, SparkContext sparkContext) throws IOException {
        // Only the stream is managed here; the FileSystem handle itself is not closed by this method.
        FileSystem fileSystem = FileSystem.get(sparkContext.hadoopConfiguration());
        try (BufferedInputStream in = new BufferedInputStream(fileSystem.open(new Path(str)))) {
            return new String(IOUtils.toByteArray(in), StandardCharsets.UTF_8);
        }
    }

    /**
     * Writes an object to a file, delegating to the {@link SparkContext} overload.
     *
     * @param str destination path
     * @param obj object to write
     * @param javaSparkContext context whose underlying {@link SparkContext} is used
     * @throws IOException if the delegate fails
     */
    public static void writeObjectToFile(String str, Object obj, JavaSparkContext javaSparkContext) throws IOException {
        SparkContext underlyingContext = javaSparkContext.sc();
        writeObjectToFile(str, obj, underlyingContext);
    }

    /**
     * Writes an object with Java serialization to a file on the file system obtained from the
     * context's Hadoop configuration.
     *
     * @param str destination path
     * @param obj object to serialize
     * @param sparkContext context supplying the Hadoop configuration
     * @throws IOException if the file cannot be created, written or closed
     */
    public static void writeObjectToFile(String str, Object obj, SparkContext sparkContext) throws IOException {
        FileSystem fileSystem = FileSystem.get(sparkContext.hadoopConfiguration());
        try (BufferedOutputStream buffered = new BufferedOutputStream(fileSystem.create(new Path(str)))) {
            ObjectOutputStream objectOut = new ObjectOutputStream(buffered);
            objectOut.writeObject(obj);
            // Push any bytes still held by the object stream into the buffered stream
            // before the latter is closed, rather than relying on it having drained itself.
            objectOut.flush();
        }
    }

    /**
     * Reads a serialized object from a file, delegating to the {@link SparkContext} overload.
     *
     * @param str path of the file to read
     * @param cls class token forwarded to the delegate
     * @param javaSparkContext context whose underlying {@link SparkContext} is used
     * @param <T> type the caller expects
     * @return the object returned by the delegate
     * @throws IOException if the delegate fails
     */
    public static <T> T readObjectFromFile(String str, Class<T> cls, JavaSparkContext javaSparkContext) throws IOException {
        SparkContext underlyingContext = javaSparkContext.sc();
        T loaded = readObjectFromFile(str, cls, underlyingContext);
        return loaded;
    }

    /**
     * Reads one Java-serialized object from a file on the file system obtained from the context's
     * Hadoop configuration.
     *
     * @param str path of the file to read
     * @param cls class token that fixes {@code T} for the caller; it is not consulted at runtime,
     *     the result is returned through an unchecked cast
     * @param sparkContext context supplying the Hadoop configuration
     * @param <T> type the caller expects
     * @return the deserialized object
     * @throws IOException if the file cannot be opened, read or closed
     * @throws RuntimeException wrapping a {@link ClassNotFoundException} raised during
     *     deserialization
     */
    public static <T> T readObjectFromFile(String str, Class<T> cls, SparkContext sparkContext) throws IOException {
        FileSystem fileSystem = FileSystem.get(sparkContext.hadoopConfiguration());
        // Two separate resources: the ObjectInputStream constructor reads the stream header and
        // may throw, in which case the already-opened file stream must still be closed.
        try (BufferedInputStream buffered = new BufferedInputStream(fileSystem.open(new Path(str)));
                ObjectInputStream objectIn = new ObjectInputStream(buffered)) {
            @SuppressWarnings("unchecked")
            T loaded = (T) objectIn.readObject();
            return loaded;
        } catch (ClassNotFoundException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Repartitions an RDD using the requested strategy.
     *
     * @param javaRDD RDD to repartition
     * @param repartition when to repartition; {@code Never} returns the input unchanged
     * @param repartitionStrategy which repartitioning implementation to use
     * @param i forwarded to the balanced strategy as its third argument
     * @param i2 target number of partitions
     * @param <T> element type
     * @return the input RDD or a repartitioned RDD
     * @throws RuntimeException if the strategy is not one of the handled values
     */
    public static <T> JavaRDD<T> repartition(JavaRDD<T> javaRDD, Repartition repartition, RepartitionStrategy repartitionStrategy, int i, int i2) {
        if (Repartition.Never == repartition) {
            return javaRDD;
        }
        JavaRDD<T> result;
        switch (repartitionStrategy) {
            case SparkDefault:
                // Skip the shuffle only when asked to repartition on mismatch and the counts already agree.
                if (repartition == Repartition.NumPartitionsWorkersDiffers && javaRDD.partitions().size() == i2) {
                    result = javaRDD;
                } else {
                    result = javaRDD.repartition(i2);
                }
                break;
            case Balanced:
                result = repartitionBalanceIfRequired(javaRDD, repartition, i, i2);
                break;
            case ApproximateBalanced:
                result = repartitionApproximateBalance(javaRDD, repartition, i2);
                break;
            default:
                throw new RuntimeException("Unknown repartition strategy: " + repartitionStrategy);
        }
        return result;
    }

    /**
     * Repartitions an RDD with a {@link HashingBalancedPartitioner} driven by per-partition weights.
     *
     * <p>Each partition's element count is collected and divided by the ideal count
     * ({@code total / i}) to obtain its weight. Index positions that exist only in the larger of
     * the old and new partition counts get weight {@code -1.0} (index at or beyond the target
     * count) or {@code 0.0} (index below the target count).
     *
     * @param javaRDD RDD to repartition
     * @param repartition when to repartition; {@code Never} returns the input unchanged, and
     *     {@code NumPartitionsWorkersDiffers} returns it unchanged when it already has {@code i}
     *     partitions
     * @param i target number of partitions
     * @param <T> element type
     * @return the input RDD or a repartitioned RDD
     * @throws RuntimeException if {@code repartition} is not one of the handled values
     */
    public static <T> JavaRDD<T> repartitionApproximateBalance(JavaRDD<T> javaRDD, Repartition repartition, int i) {
        int size = javaRDD.partitions().size();
        switch (repartition) {
            case Never:
                return javaRDD;
            case NumPartitionsWorkersDiffers:
                if (size == i) {
                    return javaRDD;
                }
                break;
            case Always:
                break;
            default:
                throw new RuntimeException("Unknown setting for repartition: " + repartition);
        }

        // Count the elements of every partition; the list is kept because the weights below need it.
        List<Integer> partitionCounts = javaRDD.mapPartitionsWithIndex(new Function2<Integer, Iterator<T>, Iterator<Integer>>() {
            @Override
            public Iterator<Integer> call(Integer partitionIndex, Iterator<T> elements) throws Exception {
                int count = 0;
                while (elements.hasNext()) {
                    elements.next();
                    count++;
                }
                return Collections.singletonList(Integer.valueOf(count)).iterator();
            }
        }, true).collect();

        int totalCount = 0;
        for (Integer partitionCount : partitionCounts) {
            totalCount += partitionCount.intValue();
        }

        // Floating-point division: integer division would truncate the ideal size and turn it
        // into 0 whenever there are fewer elements than target partitions.
        // NOTE(review): an empty RDD still gives ideal == 0.0 and therefore NaN weights - confirm
        // that HashingBalancedPartitioner tolerates this.
        double ideal = (double) totalCount / i;

        int smaller = Math.min(size, i);
        int larger = Math.max(size, i);
        List<Double> partitionWeights = new ArrayList<>(larger);
        for (int idx = 0; idx < smaller; idx++) {
            partitionWeights.add(Double.valueOf(partitionCounts.get(idx).intValue() / ideal));
        }
        for (int idx = smaller; idx < larger; idx++) {
            if (idx >= i) {
                // Index beyond the target count: the number of partitions is shrinking.
                partitionWeights.add(Double.valueOf(-1.0d));
            } else {
                // Index not present in the input: the number of partitions is growing.
                partitionWeights.add(Double.valueOf(0.0d));
            }
        }

        JavaPairRDD<Tuple2<Long, Integer>, T> keyed = javaRDD.zipWithUniqueId().mapToPair(new PairFunction<Tuple2<T, Long>, Tuple2<Long, Integer>, T>() {
            @Override
            public Tuple2<Tuple2<Long, Integer>, T> call(Tuple2<T, Long> elementWithId) {
                return new Tuple2<>(new Tuple2<Long, Integer>(elementWithId._2(), 0), elementWithId._1());
            }
        });

        return keyed.partitionBy(new HashingBalancedPartitioner(Collections.singletonList(partitionWeights))).map(new Function<Tuple2<Tuple2<Long, Integer>, T>, T>() {
            @Override
            public T call(Tuple2<Tuple2<Long, Integer>, T> keyAndElement) {
                return keyAndElement._2();
            }
        });
    }

    /**
     * Repartitions an RDD with a {@link BalancedPartitioner}, unless it is already laid out as
     * requested.
     *
     * <p>The element count of every partition is collected first. If the RDD already has
     * {@code i2} partitions and every partition holds exactly {@code i} elements, the input is
     * returned unchanged. Otherwise the elements are indexed and redistributed.
     *
     * @param javaRDD RDD to repartition
     * @param repartition when to repartition; {@code Never} returns the input unchanged, and
     *     {@code NumPartitionsWorkersDiffers} returns it unchanged when it already has {@code i2}
     *     partitions
     * @param i number of elements each partition is compared against
     * @param i2 target number of partitions
     * @param <T> element type
     * @return the input RDD or a repartitioned RDD
     * @throws RuntimeException if {@code repartition} is not one of the handled values
     */
    public static <T> JavaRDD<T> repartitionBalanceIfRequired(JavaRDD<T> javaRDD, Repartition repartition, int i, int i2) {
        int size = javaRDD.partitions().size();
        switch (repartition) {
            case Never:
                return javaRDD;
            case NumPartitionsWorkersDiffers:
                if (size == i2) {
                    return javaRDD;
                }
                break;
            case Always:
                break;
            default:
                throw new RuntimeException("Unknown setting for repartition: " + repartition);
        }

        List<Tuple2> partitionCounts = javaRDD.mapPartitionsWithIndex(new CountPartitionsFunction(), true).collect();
        int totalObjects = 0;
        boolean allCorrectSize = true;
        // A single pass is enough: it yields both the total and the "every partition has i elements" flag.
        for (Tuple2 partitionCount : partitionCounts) {
            int partitionSize = ((Integer) partitionCount._2()).intValue();
            allCorrectSize &= partitionSize == i;
            totalObjects += partitionSize;
        }
        if (partitionCounts.size() == i2 && allCorrectSize) {
            return javaRDD;
        }

        JavaPairRDD<Integer, T> indexed = indexedRDD(javaRDD);
        int remainder = (totalObjects - (i2 * i)) % i2;
        log.debug("About to rebalance: numPartitions={}, objectsPerPartition={}, remainder={}", new Object[]{Integer.valueOf(i2), Integer.valueOf(i), Integer.valueOf(remainder)});
        return indexed.partitionBy(new BalancedPartitioner(i2, i, remainder)).values();
    }

    /**
     * Pairs every element with its {@code zipWithIndex} position, using the position as key.
     *
     * <p>The {@code long} index is narrowed to {@code int}.
     *
     * @param javaRDD RDD to index
     * @param <T> element type
     * @return pair RDD of (index, element)
     */
    public static <T> JavaPairRDD<Integer, T> indexedRDD(JavaRDD<T> javaRDD) {
        JavaPairRDD<T, Long> withIndex = javaRDD.zipWithIndex();
        PairFunction<Tuple2<T, Long>, Integer, T> indexAsKey = new PairFunction<Tuple2<T, Long>, Integer, T>() { // from class: org.deeplearning4j.spark.util.SparkUtils.4
            public Tuple2<Integer, T> call(Tuple2<T, Long> elementAndIndex) {
                long index = elementAndIndex._2().longValue();
                T element = elementAndIndex._1();
                return new Tuple2<Integer, T>(Integer.valueOf((int) index), element);
            }
        };
        return withIndex.mapToPair(indexAsKey);
    }

    /**
     * Repartitions an RDD with an {@link EqualRepartitioner}.
     *
     * @param javaRDD RDD to repartition
     * @param repartition when to repartition; {@code Never} returns the input unchanged, and
     *     {@code NumPartitionsWorkersDiffers} returns it unchanged when it already has {@code i}
     *     partitions
     * @param i target number of partitions
     * @param <T> element type
     * @return the input RDD or a repartitioned RDD
     * @throws RuntimeException if {@code repartition} is not one of the handled values
     */
    public static <T> JavaRDD<T> repartitionEqually(JavaRDD<T> javaRDD, Repartition repartition, int i) {
        int currentPartitions = javaRDD.partitions().size();
        boolean keepAsIs;
        switch (repartition) {
            case Never:
                keepAsIs = true;
                break;
            case NumPartitionsWorkersDiffers:
                keepAsIs = currentPartitions == i;
                break;
            case Always:
                keepAsIs = false;
                break;
            default:
                throw new RuntimeException("Unknown setting for repartition: " + repartition);
        }
        if (keepAsIs) {
            return javaRDD;
        }
        EqualRepartitioner repartitioner = new EqualRepartitioner();
        return repartitioner.repartition(javaRDD, -1, i);
    }

    /**
     * Splits an RDD using a freshly generated random seed; see the seeded overload.
     *
     * @param i first size argument, forwarded unchanged
     * @param i2 second size argument, forwarded unchanged
     * @param javaRDD RDD to split
     * @param <T> element type
     * @return the array produced by the seeded overload
     */
    public static <T> JavaRDD<T>[] balancedRandomSplit(int i, int i2, JavaRDD<T> javaRDD) {
        long seed = new Random().nextLong();
        return balancedRandomSplit(i, i2, javaRDD, seed);
    }

    /**
     * Splits an RDD into {@code i / i2} RDDs, each produced by a {@link SplitPartitionsFunction}.
     *
     * <p>When {@code i <= i2} no split is performed and the result is a one-element array
     * holding the input RDD.
     *
     * @param i compared against and divided by {@code i2} to obtain the number of splits
     * @param i2 divisor for the number of splits
     * @param javaRDD RDD to split
     * @param j seed handed to every {@link SplitPartitionsFunction}
     * @param <T> element type
     * @return array of split RDDs
     */
    public static <T> JavaRDD<T>[] balancedRandomSplit(int i, int i2, JavaRDD<T> javaRDD, long j) {
        if (i <= i2) {
            JavaRDD<T>[] single = new JavaRDD[1];
            single[0] = javaRDD;
            return single;
        }
        int numSplits = i / i2;
        JavaRDD<T>[] splits = new JavaRDD[numSplits];
        int splitIndex = 0;
        while (splitIndex < numSplits) {
            splits[splitIndex] = javaRDD.mapPartitionsWithIndex(new SplitPartitionsFunction(splitIndex, numSplits, j), true);
            splitIndex++;
        }
        return splits;
    }

    /**
     * Splits a pair RDD using a freshly generated random seed; see the seeded overload.
     *
     * @param i first size argument, forwarded unchanged
     * @param i2 second size argument, forwarded unchanged
     * @param javaPairRDD pair RDD to split
     * @param <T> key type
     * @param <U> value type
     * @return the array produced by the seeded overload
     */
    public static <T, U> JavaPairRDD<T, U>[] balancedRandomSplit(int i, int i2, JavaPairRDD<T, U> javaPairRDD) {
        long seed = new Random().nextLong();
        return balancedRandomSplit(i, i2, javaPairRDD, seed);
    }

    /**
     * Splits a pair RDD into {@code i / i2} pair RDDs, each produced by a
     * {@link SplitPartitionsFunction2} followed by a {@link MapTupleToPairFlatMap}.
     *
     * <p>When {@code i <= i2} no split is performed and the result is a one-element array
     * holding the input RDD.
     *
     * @param i compared against and divided by {@code i2} to obtain the number of splits
     * @param i2 divisor for the number of splits
     * @param javaPairRDD pair RDD to split
     * @param j seed handed to every {@link SplitPartitionsFunction2}
     * @param <T> key type
     * @param <U> value type
     * @return array of split pair RDDs
     */
    public static <T, U> JavaPairRDD<T, U>[] balancedRandomSplit(int i, int i2, JavaPairRDD<T, U> javaPairRDD, long j) {
        if (i <= i2) {
            JavaPairRDD<T, U>[] single = new JavaPairRDD[1];
            single[0] = javaPairRDD;
            return single;
        }
        int numSplits = i / i2;
        JavaPairRDD<T, U>[] splits = new JavaPairRDD[numSplits];
        int splitIndex = 0;
        while (splitIndex < numSplits) {
            splits[splitIndex] = javaPairRDD.mapPartitionsWithIndex(new SplitPartitionsFunction2(splitIndex, numSplits, j), true).mapPartitionsToPair(new MapTupleToPairFlatMap(), true);
            splitIndex++;
        }
        return splits;
    }

    /**
     * Lists the files under a path without recursing, by delegating with {@code false} as the
     * recursion flag.
     *
     * @param javaSparkContext context used to build the resulting RDD
     * @param str path to list
     * @return RDD of file paths
     * @throws IOException if the delegate fails
     */
    public static JavaRDD<String> listPaths(JavaSparkContext javaSparkContext, String str) throws IOException {
        final boolean recursive = false;
        return listPaths(javaSparkContext, str, recursive);
    }

    /**
     * Lists the files under a path with no extension filter.
     *
     * @param javaSparkContext context used to build the resulting RDD
     * @param str path to list
     * @param z recursion flag forwarded to the delegate
     * @return RDD of file paths
     * @throws IOException if the delegate fails
     */
    public static JavaRDD<String> listPaths(JavaSparkContext javaSparkContext, String str, boolean z) throws IOException {
        // Typed null local: selects the Set<String> overload rather than the String[] one.
        Set<String> noExtensionFilter = null;
        return listPaths(javaSparkContext, str, z, noExtensionFilter);
    }

    /**
     * Lists the files under a path, taking the allowed extensions as an array.
     *
     * @param javaSparkContext context used to build the resulting RDD
     * @param str path to list
     * @param z recursion flag forwarded to the delegate
     * @param strArr allowed extensions, or {@code null} for no filter
     * @return RDD of file paths
     * @throws IOException if the delegate fails
     */
    public static JavaRDD<String> listPaths(JavaSparkContext javaSparkContext, String str, boolean z, String[] strArr) throws IOException {
        Set<String> allowedExtensions = null;
        if (strArr != null) {
            allowedExtensions = new HashSet<String>(Arrays.asList(strArr));
        }
        return listPaths(javaSparkContext, str, z, allowedExtensions);
    }

    /**
     * Lists the files under a path using the context's own Hadoop configuration.
     *
     * @param javaSparkContext context supplying the configuration and building the resulting RDD
     * @param str path to list
     * @param z recursion flag forwarded to the delegate
     * @param set allowed extensions, or {@code null} for no filter
     * @return RDD of file paths
     * @throws IOException if the delegate fails
     */
    public static JavaRDD<String> listPaths(JavaSparkContext javaSparkContext, String str, boolean z, Set<String> set) throws IOException {
        Configuration hadoopConfiguration = javaSparkContext.hadoopConfiguration();
        return listPaths(javaSparkContext, str, z, set, hadoopConfiguration);
    }

    /**
     * Lists the files under a path and returns their paths as an RDD of strings.
     *
     * @param javaSparkContext context used to parallelize the collected paths; must not be null
     * @param str path to list; also used to select the file system
     * @param z recursion flag passed to {@code FileSystem.listFiles}
     * @param set allowed file extensions (as returned by {@code FilenameUtils.getExtension});
     *     {@code null} disables filtering
     * @param configuration Hadoop configuration used to obtain the file system; must not be null
     * @return RDD containing the path of every listed file that passes the filter
     * @throws IOException if the file system cannot be obtained or listed
     * @throws NullPointerException if {@code javaSparkContext} or {@code configuration} is null
     */
    public static JavaRDD<String> listPaths(@NonNull JavaSparkContext javaSparkContext, String str, boolean z, Set<String> set, @NonNull Configuration configuration) throws IOException {
        if (javaSparkContext == null) {
            throw new NullPointerException("sc is marked @NonNull but is null");
        }
        if (configuration == null) {
            throw new NullPointerException("config is marked @NonNull but is null");
        }
        List<String> matchingPaths = new ArrayList<>();
        FileSystem fileSystem = FileSystem.get(URI.create(str), configuration);
        RemoteIterator<LocatedFileStatus> files = fileSystem.listFiles(new Path(str), z);
        while (files.hasNext()) {
            String filePath = files.next().getPath().toString();
            // The filter must look at each file's own extension, not at the root path argument.
            if (set == null || set.contains(FilenameUtils.getExtension(filePath))) {
                matchingPaths.add(filePath);
            }
        }
        return javaSparkContext.parallelize(matchingPaths);
    }

    /**
     * Redistributes DataSets: they are flat-mapped to pairs, hash-partitioned into {@code i2}
     * partitions, and the values are passed through a {@link BatchDataSetsFunction}.
     *
     * @param javaRDD DataSets to shuffle
     * @param i argument for the {@link BatchDataSetsFunction}
     * @param i2 argument for the split function and the number of hash partitions
     * @return the resulting RDD of DataSets
     */
    public static JavaRDD<DataSet> shuffleExamples(JavaRDD<DataSet> javaRDD, int i, int i2) {
        SplitDataSetExamplesPairFlatMapFunction splitter = new SplitDataSetExamplesPairFlatMapFunction(i2);
        HashPartitioner partitioner = new HashPartitioner(i2);
        BatchDataSetsFunction batcher = new BatchDataSetsFunction(i);
        return javaRDD.flatMapToPair(splitter)
                .partitionBy(partitioner)
                .values()
                .mapPartitions(batcher);
    }

    /**
     * Returns an identifier for the current Spark executor, computing and caching it on first use.
     *
     * <p>The identifier is the whitespace-delimited token that follows {@code executor-id} in the
     * {@code sun.java.command} system property. When the property is absent, does not contain
     * {@code executor-id}, or has no token after it, {@code UIDProvider.getJVMUID()} is used
     * instead.
     *
     * @return the cached or newly determined identifier
     */
    public static String getSparkExecutorId() {
        // Read the shared field once, so the value checked for null is the value returned.
        String cached = sparkExecutorId;
        if (cached != null) {
            return cached;
        }
        synchronized (SparkUtils.class) {
            cached = sparkExecutorId;
            if (cached != null) {
                return cached;
            }
            String resolved = null;
            String command = System.getProperty("sun.java.command");
            if (command != null && command.contains("executor-id")) {
                // Split on runs of whitespace so repeated separators cannot produce an empty id.
                String[] tokens = command.substring(command.indexOf("executor-id")).split("\\s+");
                if (tokens.length >= 2) {
                    resolved = tokens[1];
                }
            }
            if (resolved == null) {
                resolved = UIDProvider.getJVMUID();
            }
            sparkExecutorId = resolved;
            return resolved;
        }
    }

    /**
     * Writes an array to an in-memory byte buffer with {@code Nd4j.write} and broadcasts the
     * resulting bytes.
     *
     * @param javaSparkContext context used to create the broadcast
     * @param iNDArray array to write
     * @return broadcast holding the written bytes
     * @throws RuntimeException wrapping any {@link IOException} raised while writing
     */
    public static Broadcast<byte[]> asByteArrayBroadcast(JavaSparkContext javaSparkContext, INDArray iNDArray) {
        ByteArrayOutputStream buffer = new ByteArrayOutputStream();
        DataOutputStream sink = new DataOutputStream(buffer);
        try {
            Nd4j.write(iNDArray, sink);
            byte[] payload = buffer.toByteArray();
            return javaSparkContext.broadcast(payload);
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }
}
