package org.deeplearning4j.spark.impl.multilayer;

import java.io.Serializable;
import java.util.Iterator;
import org.apache.spark.Accumulator;
import org.apache.spark.SparkContext;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.mllib.linalg.Matrix;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.canova.api.records.reader.RecordReader;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.spark.canova.RecordReaderFunction;
import org.deeplearning4j.spark.impl.common.Adder;
import org.deeplearning4j.spark.impl.common.BestScoreAccumulator;
import org.deeplearning4j.spark.impl.common.gradient.GradientAdder;
import org.deeplearning4j.spark.impl.multilayer.gradientaccum.GradientAccumFlatMap;
import org.deeplearning4j.spark.util.MLLibUtil;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/spark/impl/multilayer/SparkDl4jMultiLayer.class */
/**
 * Drives training of a {@link MultiLayerNetwork} over a Spark {@link JavaRDD} of {@link DataSet}s.
 *
 * <p>For every training pass the driver broadcasts the current network parameters, runs a
 * per-partition function over the RDD and folds the per-partition results back into the
 * driver-side network. Two combination strategies are supported, selected through the Spark
 * configuration:
 * <ul>
 *   <li>parameter averaging (default): per-partition results are summed and divided by the
 *       number of partitions;</li>
 *   <li>gradient accumulation ({@link #ACCUM_GRADIENT}): per-partition results are summed,
 *       optionally divided by the number of partitions ({@link #DIVIDE_ACCUM_GRADIENT}), and
 *       added to the current parameters.</li>
 * </ul>
 *
 * <p>The Spark contexts are {@code transient}; they are not carried along when an instance is
 * serialized.
 */
public class SparkDl4jMultiLayer implements Serializable {
    private transient SparkContext sparkContext;
    private transient JavaSparkContext sc;
    private MultiLayerConfiguration conf;
    private MultiLayerNetwork network;
    private Broadcast<INDArray> params;
    private boolean averageEachIteration;

    /** Spark configuration key (boolean, default {@code false}): combine results after every single iteration. */
    public static final String AVERAGE_EACH_ITERATION = "org.deeplearning4j.spark.iteration.average";

    /** Spark configuration key (boolean, default {@code false}): accumulate gradients instead of averaging parameters. */
    public static final String ACCUM_GRADIENT = "org.deeplearning4j.spark.iteration.accumgrad";

    /** Spark configuration key (boolean, default {@code false}): divide the accumulated gradient by the partition count. */
    public static final String DIVIDE_ACCUM_GRADIENT = "org.deeplearning4j.spark.iteration.dividegrad";

    // Accumulator handed to the per-partition parameter-averaging function.
    private Accumulator<Double> best_score_acc;
    private static final Logger log = LoggerFactory.getLogger(SparkDl4jMultiLayer.class);

    /**
     * Creates a trainer that starts from an existing network.
     *
     * @param sparkContext      Spark context used for broadcasting and for reading the configuration flags
     * @param multiLayerNetwork network whose configuration is cloned and whose parameters are the
     *                          starting point for training
     */
    public SparkDl4jMultiLayer(SparkContext sparkContext, MultiLayerNetwork multiLayerNetwork) {
        this(sparkContext, multiLayerNetwork.getLayerWiseConfigurations());
        this.network = multiLayerNetwork;
        this.params = this.sc.broadcast(multiLayerNetwork.params());
    }

    /**
     * Creates a trainer from a configuration only; the network itself is built on the first training pass.
     *
     * @param sparkContext            Spark context used for broadcasting and for reading the configuration flags
     * @param multiLayerConfiguration configuration to train with; a clone is kept, the argument is not retained
     */
    public SparkDl4jMultiLayer(SparkContext sparkContext, MultiLayerConfiguration multiLayerConfiguration) {
        this.sparkContext = sparkContext;
        this.conf = multiLayerConfiguration.clone();
        this.averageEachIteration = sparkContext.conf().getBoolean(AVERAGE_EACH_ITERATION, false);
        this.sc = new JavaSparkContext(this.sparkContext);
        this.best_score_acc = BestScoreAccumulator.create(sparkContext);
    }

    /**
     * Convenience overload taking a {@link JavaSparkContext}.
     *
     * @param javaSparkContext        Java Spark context whose underlying {@link SparkContext} is used
     * @param multiLayerConfiguration configuration to train with
     */
    public SparkDl4jMultiLayer(JavaSparkContext javaSparkContext, MultiLayerConfiguration multiLayerConfiguration) {
        this(javaSparkContext.sc(), multiLayerConfiguration);
    }

    /**
     * Trains on a text file whose lines are converted to {@link DataSet}s by a {@link RecordReaderFunction}.
     *
     * @param str          path handed to {@code JavaSparkContext.textFile}
     * @param i            second argument of {@link RecordReaderFunction}; presumably the label
     *                     column index -- confirm against that class
     * @param recordReader reader used to parse each line
     * @return the trained network
     */
    public MultiLayerNetwork fit(String str, int i, RecordReader recordReader) {
        return fitDataSet(this.sc.textFile(str).map(new RecordReaderFunction(recordReader, i, outputLayerSize())));
    }

    /** @return the driver-side network, or {@code null} if none was supplied and no training has run yet */
    public MultiLayerNetwork getNetwork() {
        return this.network;
    }

    /**
     * Replaces the driver-side network; subsequent training continues from its parameters.
     *
     * @param multiLayerNetwork network to use from now on
     */
    public void setNetwork(MultiLayerNetwork multiLayerNetwork) {
        this.network = multiLayerNetwork;
    }

    /**
     * Runs the driver-side network on an MLLib matrix.
     *
     * @param matrix input features
     * @return network output converted back to an MLLib matrix
     * @throws IllegalStateException if no network is available yet
     */
    public Matrix predict(Matrix matrix) {
        return MLLibUtil.toMatrix(requireNetwork().output(MLLibUtil.toMatrix(matrix)));
    }

    /**
     * Runs the driver-side network on an MLLib vector.
     *
     * @param vector input features
     * @return network output converted back to an MLLib vector
     * @throws IllegalStateException if no network is available yet
     */
    public Vector predict(Vector vector) {
        return MLLibUtil.toVector(requireNetwork().output(MLLibUtil.toVector(vector)));
    }

    /**
     * Trains on MLLib labeled points.
     *
     * @param javaRDD labeled points to train on
     * @param i       third argument of {@code MLLibUtil.fromLabeledPoint}; presumably a batch
     *                size -- confirm against that method
     * @return the trained network
     */
    public MultiLayerNetwork fit(JavaRDD<LabeledPoint> javaRDD, int i) {
        return fitDataSet(MLLibUtil.fromLabeledPoint(javaRDD, outputLayerSize(), i));
    }

    /**
     * Trains on MLLib labeled points, converting them with the supplied context.
     *
     * @param javaSparkContext context passed to {@code MLLibUtil.fromLabeledPoint}
     * @param javaRDD          labeled points to train on
     * @return the trained network
     */
    public MultiLayerNetwork fit(JavaSparkContext javaSparkContext, JavaRDD<LabeledPoint> javaRDD) {
        return fitDataSet(MLLibUtil.fromLabeledPoint(javaSparkContext, javaRDD, outputLayerSize()));
    }

    /**
     * Trains on an RDD of data sets.
     *
     * <p>Without {@link #AVERAGE_EACH_ITERATION} a single distributed pass is run. With it, every
     * layer configuration is switched to one iteration and the pass is repeated as many times as
     * the first layer's configured iteration count; each pass continues from the parameters
     * produced by the previous one. Note that the one-iteration setting is written into this
     * instance's configuration and is not restored afterwards.
     *
     * @param javaRDD data sets to train on
     * @return the driver-side network after training
     */
    public MultiLayerNetwork fitDataSet(JavaRDD<DataSet> javaRDD) {
        int numIterations = this.conf.getConf(0).getNumIterations();
        log.info("Running distributed training averaging each iteration {} and {} partitions",
                this.averageEachIteration, javaRDD.partitions().size());

        if (!this.averageEachIteration) {
            runIteration(javaRDD);
            return this.network;
        }

        // Workers rebuild the network from this configuration, so each pass must do one iteration only.
        int layerCount = this.conf.getConfs().size();
        for (int layer = 0; layer < layerCount; layer++) {
            this.conf.getConf(layer).setNumIterations(1);
        }
        for (int iteration = 0; iteration < numIterations; iteration++) {
            runIteration(javaRDD);
        }
        return this.network;
    }

    /**
     * Runs one distributed pass: broadcasts the current parameters, combines the per-partition
     * results and stores them in the driver-side network.
     *
     * <p>The existing network is reused so that consecutive passes build on each other; a new
     * network is created and initialised only when none exists yet.
     *
     * @param javaRDD data sets to train on
     * @throws IllegalStateException if the flattened parameter length disagrees with the network's parameter count
     */
    private void runIteration(JavaRDD<DataSet> javaRDD) {
        MultiLayerNetwork current = this.network;
        if (current == null) {
            current = new MultiLayerNetwork(this.conf);
            current.init();
        }

        INDArray currentParams = current.params();
        int numParams = current.numParams();
        if (currentParams.length() != numParams) {
            throw new IllegalStateException("Number of params " + numParams + " was not equal to " + currentParams.length());
        }

        int partitionCount = javaRDD.partitions().size();
        if (partitionCount == 0) {
            // Nothing to train on; averaging would divide the parameters by zero.
            log.warn("RDD has no partitions; leaving network parameters unchanged");
            this.network = current;
            return;
        }

        this.params = this.sc.broadcast(currentParams);
        log.info("Broadcasting initial parameters of length {}", currentParams.length());

        if (this.sc.getConf().getBoolean(ACCUM_GRADIENT, false)) {
            INDArray gradient = accumulateGradients(javaRDD, currentParams);
            log.info("Accumulated parameters");
            log.info("Summed gradients.");
            current.setParameters(current.params().addi(gradient));
        } else {
            current.setParameters(averageParameters(javaRDD, currentParams, partitionCount));
        }
        log.info("Set parameters");
        this.network = current;
    }

    /** Sums the per-partition parameters produced by {@code IterativeReduceFlatMap} and divides by the partition count. */
    @SuppressWarnings({"rawtypes", "unchecked"}) // element type is declared by IterativeReduceFlatMap / Adder, outside this file
    private INDArray averageParameters(JavaRDD<DataSet> javaRDD, INDArray currentParams, int partitionCount) {
        JavaRDD perPartition = javaRDD.mapPartitions(
                new IterativeReduceFlatMap(this.conf.toJson(), this.params, this.best_score_acc), true).cache();
        try {
            log.info("Ran iterative reduce...averaging results now.");
            Adder adder = new Adder(currentParams.length());
            perPartition.foreach(adder);
            INDArray summed = (INDArray) adder.getAccumulator().value();
            log.info("Accumulated parameters");
            summed.divi(Integer.valueOf(partitionCount));
            log.info("Divided by partitions");
            return summed;
        } finally {
            // The cached results are consumed exactly once; release them instead of holding executor memory.
            perPartition.unpersist();
        }
    }

    /** Sums the per-partition results of {@code GradientAccumFlatMap}, dividing by the partition count when configured. */
    @SuppressWarnings({"rawtypes", "unchecked"}) // element type is declared by GradientAccumFlatMap / GradientAdder, outside this file
    private INDArray accumulateGradients(JavaRDD<DataSet> javaRDD, INDArray currentParams) {
        JavaRDD perPartition = javaRDD.mapPartitions(
                new GradientAccumFlatMap(this.conf.toJson(), this.params), true).cache();
        try {
            log.info("Ran iterative reduce...averaging results now.");
            GradientAdder gradientAdder = new GradientAdder(currentParams.length());
            perPartition.foreach(gradientAdder);
            INDArray summed = (INDArray) gradientAdder.getAccumulator().value();
            if (this.sc.getConf().getBoolean(DIVIDE_ACCUM_GRADIENT, false)) {
                summed.divi(Integer.valueOf(perPartition.partitions().size()));
            }
            return summed;
        } finally {
            // The cached results are consumed exactly once; release them instead of holding executor memory.
            perPartition.unpersist();
        }
    }

    /** Returns the {@code nOut} of the last layer configuration, used as the label width when building data sets. */
    private int outputLayerSize() {
        return this.conf.getConf(this.conf.getConfs().size() - 1).getLayer().getNOut();
    }

    /** Returns the driver-side network or fails with a descriptive message when there is none yet. */
    private MultiLayerNetwork requireNetwork() {
        if (this.network == null) {
            throw new IllegalStateException("No network available: supply one or call fit before predict");
        }
        return this.network;
    }

    /**
     * Builds a trainer on the RDD's own Spark context and trains it on the given labeled points.
     *
     * @param javaRDD                 labeled points to train on
     * @param multiLayerConfiguration configuration of the network to train
     * @return the trained network
     */
    public static MultiLayerNetwork train(JavaRDD<LabeledPoint> javaRDD, MultiLayerConfiguration multiLayerConfiguration) {
        SparkDl4jMultiLayer trainer = new SparkDl4jMultiLayer(javaRDD.context(), multiLayerConfiguration);
        return trainer.fit(new JavaSparkContext(javaRDD.context()), javaRDD);
    }
}
