package org.deeplearning4j.spark.impl.multilayer;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;
import lombok.NonNull;
import org.apache.spark.Accumulator;
import org.apache.spark.SparkContext;
import org.apache.spark.api.java.JavaDoubleRDD;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.DoubleFunction;
import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.mllib.linalg.Matrix;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.canova.api.records.reader.RecordReader;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.nn.api.Updater;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.updater.UpdaterCreator;
import org.deeplearning4j.nn.updater.aggregate.UpdaterAggregator;
import org.deeplearning4j.optimize.api.IterationListener;
import org.deeplearning4j.spark.canova.RecordReaderFunction;
import org.deeplearning4j.spark.impl.common.Adder;
import org.deeplearning4j.spark.impl.common.BestScoreAccumulator;
import org.deeplearning4j.spark.impl.common.gradient.GradientAdder;
import org.deeplearning4j.spark.impl.common.misc.GradientFromTupleFunction;
import org.deeplearning4j.spark.impl.common.misc.INDArrayFromTupleFunction;
import org.deeplearning4j.spark.impl.common.misc.ScoreReport;
import org.deeplearning4j.spark.impl.common.misc.UpdaterFromGradientTupleFunction;
import org.deeplearning4j.spark.impl.common.misc.UpdaterFromTupleFunction;
import org.deeplearning4j.spark.impl.common.updater.UpdaterAggregatorCombiner;
import org.deeplearning4j.spark.impl.common.updater.UpdaterElementCombiner;
import org.deeplearning4j.spark.impl.multilayer.evaluation.EvaluateFlatMapFunction;
import org.deeplearning4j.spark.impl.multilayer.evaluation.EvaluationReduceFunction;
import org.deeplearning4j.spark.impl.multilayer.gradientaccum.GradientAccumFlatMap;
import org.deeplearning4j.spark.impl.multilayer.scoring.ScoreExamplesFunction;
import org.deeplearning4j.spark.impl.multilayer.scoring.ScoreExamplesWithKeyFunction;
import org.deeplearning4j.spark.util.MLLibUtil;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.heartbeat.Heartbeat;
import org.nd4j.linalg.heartbeat.reports.Environment;
import org.nd4j.linalg.heartbeat.reports.Event;
import org.nd4j.linalg.heartbeat.utils.EnvironmentUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import scala.Tuple3;

/* loaded from: input_file:org/deeplearning4j/spark/impl/multilayer/SparkDl4jMultiLayer.class */
public class SparkDl4jMultiLayer implements Serializable {
    /** Batch size 50; the same value the {@code scoreExamples} and {@code evaluate} overloads without a batch-size argument pass on. */
    public static final int DEFAULT_EVAL_SCORE_BATCH_SIZE = 50;
    /** Underlying Spark context; transient, so it is not part of the serialized form of this object. */
    private transient SparkContext sparkContext;
    /** Java wrapper around the Spark context; used for broadcasts, accumulators and text-file loading. Transient. */
    private transient JavaSparkContext sc;
    /** Clone of the network configuration taken in the constructors; its JSON form is handed to the partition functions. */
    private MultiLayerConfiguration conf;
    /** Driver-side network; {@code runIteration} writes the combined parameters and updater back into it. */
    private MultiLayerNetwork network;
    /** Broadcast of the network parameters; re-created at the start of every {@code runIteration}. */
    private Broadcast<INDArray> params;
    /** Broadcast of the network updater; created in the constructors and re-created in {@code runIteration}. */
    private Broadcast<Updater> updater;
    /** When true, {@code fitDataSet} calls {@code runIteration} once per configured iteration; set from the Spark configuration in the constructors. */
    private boolean averageEachIteration;
    /** Spark configuration key (boolean, default false) that populates {@code averageEachIteration}. */
    public static final String AVERAGE_EACH_ITERATION = "org.deeplearning4j.spark.iteration.average";
    /** Spark configuration key (boolean, default false) that makes {@code runIteration} combine gradients instead of parameters. */
    public static final String ACCUM_GRADIENT = "org.deeplearning4j.spark.iteration.accumgrad";
    /** Spark configuration key (boolean, default false); when set, the summed gradient is divided by the number of result partitions. */
    public static final String DIVIDE_ACCUM_GRADIENT = "org.deeplearning4j.spark.iteration.dividegrad";
    /** Accumulator created by {@code BestScoreAccumulator.create} and passed to the parameter-averaging partition function. */
    private Accumulator<Double> bestScoreAcc;
    /** Mean of the per-result scores from the most recent {@code runIteration}; exposed through {@code getScore()}. */
    private double lastScore;
    /** Set to true after the first {@code runIteration}; gates the one-time heartbeat report. Transient. */
    private transient boolean initDone;
    /** Counter incremented each time the listeners are invoked. Transient, so NOTE(review): presumably null after Java deserialization - confirm before relying on it there. */
    private transient AtomicInteger iterationsCount;
    /** Listeners notified at the end of {@code runIteration}; replaced wholesale by {@code setListeners}. */
    private List<IterationListener> listeners;
    private static final Logger log = LoggerFactory.getLogger(SparkDl4jMultiLayer.class);

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/deeplearning4j/spark/impl/multilayer/SparkDl4jMultiLayer$SMapping.class */
    public static class SMapping implements DoubleFunction<Tuple3<INDArray, Updater, ScoreReport>> {
        private SMapping() {
        }

        public double call(Tuple3<INDArray, Updater, ScoreReport> tuple3) throws Exception {
            return ((ScoreReport) tuple3._3()).getM();
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/deeplearning4j/spark/impl/multilayer/SparkDl4jMultiLayer$SMappingG.class */
    public static class SMappingG implements DoubleFunction<Tuple3<Gradient, Updater, ScoreReport>> {
        private SMappingG() {
        }

        public double call(Tuple3<Gradient, Updater, ScoreReport> tuple3) throws Exception {
            return ((ScoreReport) tuple3._3()).getM();
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/deeplearning4j/spark/impl/multilayer/SparkDl4jMultiLayer$ScoreMapping.class */
    public static class ScoreMapping implements DoubleFunction<Tuple3<INDArray, Updater, ScoreReport>> {
        private ScoreMapping() {
        }

        public double call(Tuple3<INDArray, Updater, ScoreReport> tuple3) throws Exception {
            return ((ScoreReport) tuple3._3()).getS();
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/deeplearning4j/spark/impl/multilayer/SparkDl4jMultiLayer$ScoreMappingG.class */
    public static class ScoreMappingG implements DoubleFunction<Tuple3<Gradient, Updater, ScoreReport>> {
        private ScoreMappingG() {
        }

        public double call(Tuple3<Gradient, Updater, ScoreReport> tuple3) throws Exception {
            return ((ScoreReport) tuple3._3()).getS();
        }
    }

    /**
     * Wraps an existing network for distributed training, given a plain {@link SparkContext}.
     * Delegates to the {@link JavaSparkContext} overload after wrapping the context.
     *
     * @param sparkContext      Spark context to run on
     * @param multiLayerNetwork network to train
     */
    public SparkDl4jMultiLayer(SparkContext sparkContext, MultiLayerNetwork multiLayerNetwork) {
        this(new JavaSparkContext(sparkContext), multiLayerNetwork);
    }

    /**
     * Wraps an existing network for distributed training.
     *
     * <p>The network's layer-wise configuration is cloned and kept separately, {@code init()} is called on
     * the supplied network, its updater is broadcast, and the averaging mode is read from the Spark
     * configuration key {@link #AVERAGE_EACH_ITERATION} (default {@code false}).
     *
     * @param javaSparkContext  Spark context to run on
     * @param multiLayerNetwork network to train; this instance is used directly, not copied
     */
    public SparkDl4jMultiLayer(JavaSparkContext javaSparkContext, MultiLayerNetwork multiLayerNetwork) {
        this.initDone = false;
        this.iterationsCount = new AtomicInteger(0);
        this.listeners = new ArrayList<IterationListener>();
        this.sparkContext = javaSparkContext.sc();
        this.sc = javaSparkContext;
        this.conf = multiLayerNetwork.getLayerWiseConfigurations().clone();
        this.network = multiLayerNetwork;
        this.network.init();
        this.updater = this.sc.broadcast(multiLayerNetwork.getUpdater());
        // Use the declared key constant rather than a duplicated string literal.
        this.averageEachIteration = this.sparkContext.conf().getBoolean(AVERAGE_EACH_ITERATION, false);
        this.bestScoreAcc = BestScoreAccumulator.create(this.sparkContext);
    }

    /**
     * Builds a new network from a configuration and prepares it for distributed training.
     *
     * <p>A clone of the configuration is kept on this object, while the network itself is constructed from
     * the configuration instance that was passed in and then initialised. The averaging mode is read from
     * the Spark configuration key {@link #AVERAGE_EACH_ITERATION} (default {@code false}), and the new
     * network's updater is broadcast.
     *
     * @param sparkContext            Spark context to run on
     * @param multiLayerConfiguration configuration describing the network to create
     */
    public SparkDl4jMultiLayer(SparkContext sparkContext, MultiLayerConfiguration multiLayerConfiguration) {
        this.initDone = false;
        this.iterationsCount = new AtomicInteger(0);
        this.listeners = new ArrayList<IterationListener>();
        this.sparkContext = sparkContext;
        this.sc = new JavaSparkContext(this.sparkContext);
        this.conf = multiLayerConfiguration.clone();
        this.network = new MultiLayerNetwork(multiLayerConfiguration);
        this.network.init();
        // Use the declared key constant rather than a duplicated string literal.
        this.averageEachIteration = sparkContext.conf().getBoolean(AVERAGE_EACH_ITERATION, false);
        this.bestScoreAcc = BestScoreAccumulator.create(sparkContext);
        this.updater = this.sc.broadcast(this.network.getUpdater());
    }

    /**
     * Builds a new network from a configuration, given a {@link JavaSparkContext}.
     * Delegates to the {@link SparkContext} overload using the wrapped context.
     *
     * @param javaSparkContext        Spark context to run on
     * @param multiLayerConfiguration configuration describing the network to create
     */
    public SparkDl4jMultiLayer(JavaSparkContext javaSparkContext, MultiLayerConfiguration multiLayerConfiguration) {
        this(javaSparkContext.sc(), multiLayerConfiguration);
    }

    /**
     * Trains on a text file whose lines are converted to data sets by the given record reader.
     *
     * @param str          path handed to {@code JavaSparkContext.textFile}
     * @param i            integer forwarded to {@code RecordReaderFunction} (presumably the label column index - confirm)
     * @param recordReader reader used to parse each line
     * @return the trained network held by this object
     */
    public MultiLayerNetwork fit(String str, int i, RecordReader recordReader) {
        JavaRDD<DataSet> data = loadFromTextFile(str, i, recordReader);
        return fitDataSet(data);
    }

    /**
     * Trains on a text file in several passes, counting the records itself.
     *
     * <p>The loaded RDD is cached before it is counted. The count is clamped to {@link Integer#MAX_VALUE}
     * rather than narrowed with a plain cast, so a very large data set can no longer wrap around to a
     * small or negative total and silently collapse the training into a single split.
     *
     * @param str          path handed to {@code JavaSparkContext.textFile}
     * @param i            integer forwarded to {@code RecordReaderFunction} (presumably the label column index - confirm)
     * @param recordReader reader used to parse each line
     * @param i2           number of examples per fit, used to decide how many splits to train on
     * @param i3           number of partitions each split is repartitioned to
     * @return the trained network held by this object
     */
    public MultiLayerNetwork fit(String str, int i, RecordReader recordReader, int i2, int i3) {
        JavaRDD<DataSet> data = loadFromTextFile(str, i, recordReader);
        data.cache();
        int totalExamples = (int) Math.min(data.count(), (long) Integer.MAX_VALUE);
        return fitDataSet(data, i2, totalExamples, i3);
    }

    /**
     * Trains on a text file in several passes, with the total number of examples supplied by the caller.
     *
     * @param str          path handed to {@code JavaSparkContext.textFile}
     * @param i            integer forwarded to {@code RecordReaderFunction} (presumably the label column index - confirm)
     * @param recordReader reader used to parse each line
     * @param i2           number of examples per fit
     * @param i3           total number of examples
     * @param i4           number of partitions each split is repartitioned to
     * @return the trained network held by this object
     */
    public MultiLayerNetwork fit(String str, int i, RecordReader recordReader, int i2, int i3, int i4) {
        JavaRDD<DataSet> data = loadFromTextFile(str, i, recordReader);
        return fitDataSet(data, i2, i3, i4);
    }

    /**
     * Reads a text file and maps every line to a {@link DataSet} through a {@code RecordReaderFunction}.
     * The function is given the {@code nOut} of the last configured layer as its third argument.
     *
     * @param str          path handed to {@code JavaSparkContext.textFile}
     * @param i            integer forwarded to {@code RecordReaderFunction} (presumably the label column index - confirm)
     * @param recordReader reader used to parse each line
     * @return RDD of converted data sets
     */
    private JavaRDD<DataSet> loadFromTextFile(String str, int i, RecordReader recordReader) {
        JavaRDD<String> lines = this.sc.textFile(str);
        int lastLayerIndex = this.conf.getConfs().size() - 1;
        int lastLayerOutputs = this.conf.getConf(lastLayerIndex).getLayer().getNOut();
        return lines.map(new RecordReaderFunction(recordReader, i, lastLayerOutputs));
    }

    /**
     * @return the driver-side network held by this object (the same instance, not a copy)
     */
    public MultiLayerNetwork getNetwork() {
        return this.network;
    }

    /**
     * Replaces the driver-side network. Only the reference is swapped; the stored configuration
     * and the existing broadcasts are left untouched by this method.
     *
     * @param multiLayerNetwork network to use from now on
     */
    public void setNetwork(MultiLayerNetwork multiLayerNetwork) {
        this.network = multiLayerNetwork;
    }

    /**
     * Runs the driver-side network on an MLlib matrix.
     *
     * @param matrix input features
     * @return network output converted back to an MLlib matrix
     */
    public Matrix predict(Matrix matrix) {
        INDArray features = MLLibUtil.toMatrix(matrix);
        INDArray output = this.network.output(features);
        return MLLibUtil.toMatrix(output);
    }

    /**
     * Runs the driver-side network on an MLlib vector.
     *
     * @param vector input features
     * @return network output converted back to an MLlib vector
     */
    public Vector predict(Vector vector) {
        INDArray features = MLLibUtil.toVector(vector);
        INDArray output = this.network.output(features);
        return MLLibUtil.toVector(output);
    }

    /**
     * Trains on MLlib labeled points. The points are converted with {@code MLLibUtil.fromLabeledPoint},
     * which receives the {@code nOut} of the last configured layer and the given integer.
     *
     * @param javaRDD labeled points to train on
     * @param i       integer forwarded as the third argument of {@code fromLabeledPoint} (presumably a batch size - confirm)
     * @return the trained network held by this object
     */
    public MultiLayerNetwork fit(JavaRDD<LabeledPoint> javaRDD, int i) {
        int lastLayerIndex = this.conf.getConfs().size() - 1;
        int lastLayerOutputs = this.conf.getConf(lastLayerIndex).getLayer().getNOut();
        JavaRDD<DataSet> data = MLLibUtil.fromLabeledPoint(javaRDD, lastLayerOutputs, i);
        return fitDataSet(data);
    }

    /**
     * Trains on MLlib labeled points, converting them with the context-taking overload of
     * {@code MLLibUtil.fromLabeledPoint}; the {@code nOut} of the last configured layer is passed along.
     *
     * @param javaSparkContext context handed to the conversion utility
     * @param javaRDD          labeled points to train on
     * @return the trained network held by this object
     */
    public MultiLayerNetwork fit(JavaSparkContext javaSparkContext, JavaRDD<LabeledPoint> javaRDD) {
        int lastLayerIndex = this.conf.getConfs().size() - 1;
        int lastLayerOutputs = this.conf.getConf(lastLayerIndex).getLayer().getNOut();
        JavaRDD<DataSet> data = MLLibUtil.fromLabeledPoint(javaSparkContext, javaRDD, lastLayerOutputs);
        return fitDataSet(data);
    }

    /**
     * Trains in several passes, counting the data sets itself.
     *
     * <p>The RDD is cached before it is counted. The count is clamped to {@link Integer#MAX_VALUE}
     * rather than narrowed with a plain cast, so a very large RDD can no longer wrap around to a
     * small or negative total and silently collapse the training into a single split.
     *
     * @param javaRDD data to train on
     * @param i       number of examples per fit, used to decide how many splits to train on
     * @param i2      number of partitions each split is repartitioned to
     * @return the trained network held by this object
     */
    public MultiLayerNetwork fitDataSet(JavaRDD<DataSet> javaRDD, int i, int i2) {
        javaRDD.cache();
        int totalExamples = (int) Math.min(javaRDD.count(), (long) Integer.MAX_VALUE);
        return fitDataSet(javaRDD, i, totalExamples, i2);
    }

    /**
     * Trains on the data in one or more sequential passes.
     *
     * <p>If the examples-per-fit value covers the whole data set, a single pass is run on the RDD as is.
     * Otherwise the RDD is split randomly into {@code ceil(total / perFit)} equally weighted subsets and
     * each subset is repartitioned and trained on in turn.
     *
     * @param javaRDD data to train on
     * @param i       number of examples per fit
     * @param i2      total number of examples in {@code javaRDD}
     * @param i3      number of partitions each subset is repartitioned to
     * @return the trained network held by this object
     * @throws IllegalArgumentException if splitting is required but {@code i} is not positive
     */
    public MultiLayerNetwork fitDataSet(JavaRDD<DataSet> javaRDD, int i, int i2, int i3) {
        // i == Integer.MAX_VALUE is already covered by i >= i2.
        if (i >= i2) {
            fitDataSet(javaRDD);
            return this.network;
        }
        if (i <= 0) {
            // Previously this surfaced as "/ by zero" or a NegativeArraySizeException further down.
            throw new IllegalArgumentException("Number of examples per fit must be positive, got " + i);
        }

        int numSplits = (i2 % i == 0) ? i2 / i : (i2 / i) + 1;
        double[] weights = new double[numSplits];
        for (int k = 0; k < numSplits; k++) {
            weights[k] = 1.0d / numSplits;
        }

        JavaRDD<DataSet>[] subsets = javaRDD.randomSplit(weights);
        for (int k = 0; k < subsets.length; k++) {
            log.info("Initiating distributed training of subset {} of {}", k + 1, subsets.length);
            fitDataSet(subsets[k].repartition(i3));
        }
        return this.network;
    }

    /**
     * Runs distributed training on the given data.
     *
     * <p>Without per-iteration averaging, a single {@code runIteration} is executed. With it, the
     * iteration count of every layer configuration is temporarily set to 1 and {@code runIteration} is
     * called once per originally configured iteration, so results are combined after each one.
     * The original iteration count is always restored afterwards, including when a training round throws.
     *
     * @param javaRDD data to train on
     * @return the trained network held by this object
     */
    public MultiLayerNetwork fitDataSet(JavaRDD<DataSet> javaRDD) {
        final int numIterations = this.conf.getConf(0).getNumIterations();
        log.info("Running distributed training:  (averaging each iteration = " + this.averageEachIteration + "), (iterations = " + numIterations + "), (num partions = " + javaRDD.partitions().size() + ")");

        if (!this.averageEachIteration) {
            runIteration(javaRDD);
            return this.network;
        }

        try {
            // Iterations are driven from here so that results can be combined between them.
            for (NeuralNetConfiguration layerConf : this.conf.getConfs()) {
                layerConf.setNumIterations(1);
            }
            for (int round = 0; round < numIterations; round++) {
                runIteration(javaRDD);
            }
        } finally {
            // Restore unconditionally: a failed round must not leave the configuration stuck at 1.
            for (NeuralNetConfiguration layerConf : this.conf.getConfs()) {
                layerConf.setNumIterations(numIterations);
            }
        }
        return this.network;
    }

    /**
     * Executes one distributed training round and folds the worker results back into the driver network.
     *
     * <p>The current parameters and updater are broadcast first (an updater is created if the network has
     * none). Depending on the Spark configuration key {@link #ACCUM_GRADIENT}, either
     * <ul>
     *   <li>gradients are summed across results, optionally divided by the number of result partitions
     *       ({@link #DIVIDE_ACCUM_GRADIENT}), and added to the current parameters, or</li>
     *   <li>parameters are summed across results and divided by the number of results.</li>
     * </ul>
     * In both cases the last score becomes the mean of the reported scores and the updaters are
     * aggregated and set on the network. Listeners are invoked if any are registered, and on the first
     * call a heartbeat event is reported.
     *
     * <p>The per-partition result RDD is cached because it is traversed several times; it is unpersisted
     * before this method returns so that repeated rounds do not pile up cached results on the executors.
     *
     * @param javaRDD data to train on in this round
     */
    protected void runIteration(JavaRDD<DataSet> javaRDD) {
        int workerCount = 0;
        long meanM = 0;
        int numParams = this.network.numParams(false);
        log.info("Broadcasting initial parameters of length " + numParams);
        this.params = this.sc.broadcast(this.network.params(false));

        Updater currentUpdater = this.network.getUpdater();
        if (currentUpdater == null) {
            this.network.setUpdater(UpdaterCreator.getUpdater(this.network));
            log.warn("Unable to propagate null updater");
            currentUpdater = this.network.getUpdater();
        }
        this.updater = this.sc.broadcast(currentUpdater);

        if (this.sc.getConf().getBoolean(ACCUM_GRADIENT, false)) {
            JavaRDD<Tuple3<Gradient, Updater, ScoreReport>> results = javaRDD
                    .mapPartitions(new GradientAccumFlatMap(this.conf.toJson(), this.params, this.updater), true)
                    .cache();
            try {
                JavaRDD<Gradient> gradients = results.map(new GradientFromTupleFunction());
                log.info("Ran iterative reduce... averaging results now.");
                GradientAdder gradientAdder = new GradientAdder(numParams);
                gradients.foreach(gradientAdder);
                INDArray summedGradient = gradientAdder.getAccumulator().value();
                if (this.sc.getConf().getBoolean(DIVIDE_ACCUM_GRADIENT, false)) {
                    workerCount = results.partitions().size();
                    summedGradient.divi(workerCount);
                }
                log.info("Accumulated parameters");
                log.info("Summed gradients.");
                this.network.setParameters(this.network.params(false).addi(summedGradient));
                log.info("Set parameters");
                this.lastScore = results.mapToDouble(new ScoreMappingG()).mean();
                if (!this.initDone) {
                    meanM = results.mapToDouble(new SMappingG()).mean().longValue();
                }
                log.info("Processing updaters");
                JavaRDD<Updater> updaters = results.map(new UpdaterFromGradientTupleFunction());
                UpdaterAggregator aggregator = updaters.aggregate(
                        updaters.first().getAggregator(false),
                        new UpdaterElementCombiner(),
                        new UpdaterAggregatorCombiner());
                this.network.setUpdater(aggregator.getUpdater());
                log.info("Set updater");
            } finally {
                // Nothing references the cached results after this point.
                results.unpersist();
            }
        } else {
            JavaRDD<Tuple3<INDArray, Updater, ScoreReport>> results = javaRDD
                    .mapPartitions(new IterativeReduceFlatMap(this.conf.toJson(), this.params, this.updater, this.bestScoreAcc), true)
                    .cache();
            try {
                JavaRDD<INDArray> workerParams = results.map(new INDArrayFromTupleFunction());
                log.info("Running iterative reduce and averaging parameters");
                Adder adder = new Adder(numParams, this.sc.accumulator(0));
                workerParams.foreach(adder);
                INDArray averagedParams = adder.getAccumulator().value();
                workerCount = adder.getCounter().value();
                averagedParams.divi(workerCount);
                this.network.setParameters(averagedParams);
                log.info("Accumulated and set parameters");
                this.lastScore = results.mapToDouble(new ScoreMapping()).mean();
                if (!this.initDone) {
                    meanM = results.mapToDouble(new SMapping()).mean().longValue();
                }
                JavaRDD<Updater> updaters = results.map(new UpdaterFromTupleFunction());
                UpdaterAggregator aggregator = updaters.aggregate(
                        (UpdaterAggregator) null,
                        new UpdaterElementCombiner(),
                        new UpdaterAggregatorCombiner());
                this.network.setUpdater(aggregator.getUpdater());
                log.info("Processed and set updater");
            } finally {
                // Nothing references the cached results after this point.
                results.unpersist();
            }
        }

        if (this.listeners.size() > 0) {
            log.debug("Invoking IterationListeners");
            this.network.setScore(this.lastScore);
            invokeListeners(this.network, this.iterationsCount.incrementAndGet());
        }

        // The heartbeat is reported once, after the first completed round.
        if (this.initDone) {
            return;
        }
        this.initDone = true;
        update(workerCount, meanM);
    }

    /**
     * Convenience entry point: builds a trainer on the RDD's own Spark context and fits it to the
     * labeled points.
     *
     * @param javaRDD                 labeled points to train on
     * @param multiLayerConfiguration configuration describing the network to create
     * @return the trained network
     */
    public static MultiLayerNetwork train(JavaRDD<LabeledPoint> javaRDD, MultiLayerConfiguration multiLayerConfiguration) {
        SparkDl4jMultiLayer trainer = new SparkDl4jMultiLayer(javaRDD.context(), multiLayerConfiguration);
        JavaSparkContext javaContext = new JavaSparkContext(javaRDD.context());
        return trainer.fit(javaContext, javaRDD);
    }

    /**
     * Replaces the registered iteration listeners with the given ones. The internal list is
     * emptied and refilled in place; the supplied collection itself is not retained.
     *
     * @param collection listeners to register; must not be null
     * @throws NullPointerException if {@code collection} is null
     */
    public void setListeners(@NonNull Collection<IterationListener> collection) {
        if (collection == null) {
            throw new NullPointerException("listeners");
        }
        List<IterationListener> registered = this.listeners;
        registered.clear();
        registered.addAll(collection);
    }

    /**
     * Notifies every registered listener that an iteration has finished.
     *
     * <p>A failing listener does not stop the remaining listeners from being called and does not abort
     * training; the failure is logged together with its stack trace through the class logger.
     *
     * @param multiLayerNetwork network passed to each listener
     * @param i                 iteration number passed to each listener
     */
    protected void invokeListeners(MultiLayerNetwork multiLayerNetwork, int i) {
        for (IterationListener listener : this.listeners) {
            try {
                listener.iterationDone(multiLayerNetwork, i);
            } catch (Exception e) {
                // Hand the exception to the logger so the stack trace lands in the log, not on stderr.
                log.error("Exception caught at IterationListener invocation", e);
            }
        }
    }

    /**
     * @return the score recorded by the most recent {@code runIteration} (the mean of the scores
     *         reported in the worker results); 0.0 if no iteration has been run yet
     */
    public double getScore() {
        return this.lastScore;
    }

    /**
     * Scores the given data with the current network parameters.
     *
     * <p>The values produced by {@code ScoreFlatMapFunction} are collected and summed on the driver.
     * The RDD is counted only when an average is requested, so the non-averaging case does not pay for
     * an extra pass over the data.
     *
     * @param javaRDD data to score
     * @param z       if true, the summed score is divided by the number of elements in {@code javaRDD}
     * @return the summed score, or that sum divided by the element count when {@code z} is true
     */
    public double calculateScore(JavaRDD<DataSet> javaRDD, boolean z) {
        List<Double> partialScores = javaRDD
                .mapPartitions(new ScoreFlatMapFunction(this.conf.toJson(), this.sc.broadcast(this.network.params(false))))
                .collect();
        double total = 0.0d;
        for (Double partial : partialScores) {
            total += partial.doubleValue();
        }
        if (!z) {
            return total;
        }
        return total / javaRDD.count();
    }

    /**
     * Scores each example individually, using the default batch size
     * ({@link #DEFAULT_EVAL_SCORE_BATCH_SIZE}).
     *
     * @param javaRDD data to score
     * @param z       flag forwarded to {@code ScoreExamplesFunction} (presumably whether regularization terms are included - confirm)
     * @return RDD of scores
     */
    public JavaDoubleRDD scoreExamples(JavaRDD<DataSet> javaRDD, boolean z) {
        return scoreExamples(javaRDD, z, DEFAULT_EVAL_SCORE_BATCH_SIZE);
    }

    /**
     * Scores each example individually. The current parameters and the configuration JSON are broadcast
     * and handed to a {@code ScoreExamplesFunction} that runs per partition.
     *
     * @param javaRDD data to score
     * @param z       flag forwarded to {@code ScoreExamplesFunction} (presumably whether regularization terms are included - confirm)
     * @param i       batch size forwarded to {@code ScoreExamplesFunction}
     * @return RDD of scores
     */
    public JavaDoubleRDD scoreExamples(JavaRDD<DataSet> javaRDD, boolean z, int i) {
        Broadcast<INDArray> paramsBroadcast = this.sc.broadcast(this.network.params());
        Broadcast<String> confBroadcast = this.sc.broadcast(this.conf.toJson());
        return javaRDD.mapPartitionsToDouble(new ScoreExamplesFunction(paramsBroadcast, confBroadcast, z, i));
    }

    /**
     * Scores keyed examples individually, using the default batch size
     * ({@link #DEFAULT_EVAL_SCORE_BATCH_SIZE}).
     *
     * @param javaPairRDD keyed data to score
     * @param z           flag forwarded to {@code ScoreExamplesWithKeyFunction} (presumably whether regularization terms are included - confirm)
     * @param <K>         key type
     * @return pair RDD of key and score
     */
    public <K> JavaPairRDD<K, Double> scoreExamples(JavaPairRDD<K, DataSet> javaPairRDD, boolean z) {
        return scoreExamples(javaPairRDD, z, DEFAULT_EVAL_SCORE_BATCH_SIZE);
    }

    /**
     * Scores keyed examples individually. The current parameters and the configuration JSON are
     * broadcast and handed to a {@code ScoreExamplesWithKeyFunction} that runs per partition.
     *
     * @param javaPairRDD keyed data to score
     * @param z           flag forwarded to {@code ScoreExamplesWithKeyFunction} (presumably whether regularization terms are included - confirm)
     * @param i           batch size forwarded to {@code ScoreExamplesWithKeyFunction}
     * @param <K>         key type
     * @return pair RDD of key and score
     */
    public <K> JavaPairRDD<K, Double> scoreExamples(JavaPairRDD<K, DataSet> javaPairRDD, boolean z, int i) {
        Broadcast<INDArray> paramsBroadcast = this.sc.broadcast(this.network.params());
        Broadcast<String> confBroadcast = this.sc.broadcast(this.conf.toJson());
        return javaPairRDD.mapPartitionsToPair(new ScoreExamplesWithKeyFunction(paramsBroadcast, confBroadcast, z, i));
    }

    /**
     * Evaluates the network on the given data without a label list.
     * Delegates to {@code evaluate(javaRDD, null)}.
     *
     * @param javaRDD data to evaluate on
     * @return the combined evaluation
     */
    public Evaluation evaluate(JavaRDD<DataSet> javaRDD) {
        return evaluate(javaRDD, null);
    }

    /**
     * Evaluates the network on the given data using the default batch size
     * ({@link #DEFAULT_EVAL_SCORE_BATCH_SIZE}).
     *
     * @param javaRDD data to evaluate on
     * @param list    list of strings broadcast to the evaluation function (presumably label names - confirm); may be null
     * @return the combined evaluation
     */
    public Evaluation evaluate(JavaRDD<DataSet> javaRDD, List<String> list) {
        return evaluate(javaRDD, list, DEFAULT_EVAL_SCORE_BATCH_SIZE);
    }

    /**
     * Reports a Spark heartbeat event describing this run.
     *
     * @param i value recorded as the environment's number of cores
     * @param j value recorded as the environment's available memory
     */
    private void update(int i, long j) {
        Environment environment = EnvironmentUtils.buildEnvironment();
        environment.setNumCores(i);
        environment.setAvailableMemory(j);
        Heartbeat heartbeat = Heartbeat.getInstance();
        heartbeat.reportEvent(Event.SPARK, environment, ModelSerializer.taskByModel(this.network));
    }

    /**
     * Evaluates the network on the given data. The configuration JSON, the current parameters and
     * (if present) the string list are broadcast; an {@code EvaluateFlatMapFunction} runs per partition
     * and the partial results are merged with an {@code EvaluationReduceFunction}.
     *
     * @param javaRDD data to evaluate on
     * @param list    list of strings broadcast to the evaluation function (presumably label names - confirm); may be null, in which case nothing is broadcast for it
     * @param i       batch size forwarded to {@code EvaluateFlatMapFunction}
     * @return the combined evaluation
     */
    public Evaluation evaluate(JavaRDD<DataSet> javaRDD, List<String> list, int i) {
        Broadcast<String> confBroadcast = this.sc.broadcast(this.conf.toJson());
        Broadcast<INDArray> paramsBroadcast = this.sc.broadcast(this.network.params());
        Broadcast<List<String>> listBroadcast = null;
        if (list != null) {
            listBroadcast = this.sc.broadcast(list);
        }
        EvaluateFlatMapFunction evaluator = new EvaluateFlatMapFunction(confBroadcast, paramsBroadcast, i, listBroadcast);
        return (Evaluation) javaRDD.mapPartitions(evaluator).reduce(new EvaluationReduceFunction());
    }
}
