package org.deeplearning4j.spark.impl.multilayer;

import java.io.Serializable;
import org.apache.spark.SparkContext;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function2;
import org.apache.spark.mllib.linalg.Matrix;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.canova.api.records.reader.RecordReader;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.spark.canova.RDDMiniBatches;
import org.deeplearning4j.spark.canova.RecordReaderFunction;
import org.deeplearning4j.spark.util.MLLibUtil;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;

/* loaded from: input_file:org/deeplearning4j/spark/impl/multilayer/SparkDl4jMultiLayer.class */
/**
 * Trains a {@link MultiLayerNetwork} on Spark.
 *
 * <p>Training ({@link #fitDataSet}) splits the input into mini-batches, maps every mini-batch
 * through a {@link DL4jWorker} that is given the same initial parameter vector (presumably each
 * worker trains a copy of the network on its batch — confirm against DL4jWorker), sums the
 * resulting parameter vectors and divides by the number of mini-batches.
 *
 * <p>The Spark contexts are {@code transient}; only the configuration and the most recently
 * trained (or supplied) network are serialized with this object.
 */
public class SparkDl4jMultiLayer implements Serializable {
    private transient SparkContext sparkContext;
    private transient JavaSparkContext sc;
    private MultiLayerConfiguration conf;
    private MultiLayerNetwork network;

    /**
     * Creates a trainer seeded with an existing network, so {@link #predict} works immediately.
     *
     * @param sparkContext the Spark context used for training
     * @param multiLayerNetwork the network whose configuration is copied; it also serves
     *     predictions until the next successful fit
     */
    public SparkDl4jMultiLayer(SparkContext sparkContext, MultiLayerNetwork multiLayerNetwork) {
        this.sparkContext = sparkContext;
        // this.conf is still null here; the configuration must come from the supplied network.
        this.conf = multiLayerNetwork.getLayerWiseConfigurations().clone();
        this.sc = new JavaSparkContext(this.sparkContext);
        this.network = multiLayerNetwork;
    }

    /**
     * Creates a trainer from a configuration; no network exists until a fit method is called.
     *
     * @param sparkContext the Spark context used for training
     * @param multiLayerConfiguration the network configuration (a clone is stored)
     */
    public SparkDl4jMultiLayer(SparkContext sparkContext, MultiLayerConfiguration multiLayerConfiguration) {
        this.sparkContext = sparkContext;
        this.conf = multiLayerConfiguration.clone();
        this.sc = new JavaSparkContext(this.sparkContext);
    }

    /**
     * Creates a trainer from a configuration and an existing Java Spark context; no network exists
     * until a fit method is called.
     *
     * @param javaSparkContext the Java Spark context used for training
     * @param multiLayerConfiguration the network configuration (a clone is stored)
     */
    public SparkDl4jMultiLayer(JavaSparkContext javaSparkContext, MultiLayerConfiguration multiLayerConfiguration) {
        this.sc = javaSparkContext;
        // Keep both context fields populated, consistent with the other constructors.
        this.sparkContext = javaSparkContext.sc();
        this.conf = multiLayerConfiguration.clone();
    }

    /**
     * Reads text records from {@code path}, converts each line to a {@link DataSet} with a
     * {@link RecordReaderFunction} and fits a new network on them.
     *
     * @param path location passed to {@code JavaSparkContext.textFile}
     * @param labelIndex passed to {@link RecordReaderFunction} (presumably the column holding the
     *     label — confirm)
     * @param recordReader reader used by {@link RecordReaderFunction} to parse each line
     * @return the newly trained network (also stored for {@link #predict})
     */
    public MultiLayerNetwork fit(String path, int labelIndex, RecordReader recordReader) {
        JavaRDD<DataSet> dataSets = this.sc.textFile(path)
            .map(new RecordReaderFunction(recordReader, labelIndex, outputLayerSize()));
        return fitDataSet(dataSets);
    }

    /**
     * Runs the current network on every row of {@code matrix}.
     *
     * @param matrix input features, converted with {@link MLLibUtil#toMatrix}
     * @return the network output converted back to an MLlib matrix
     * @throws IllegalStateException if no network has been supplied or trained yet
     */
    public Matrix predict(Matrix matrix) {
        return MLLibUtil.toMatrix(requireNetwork().output(MLLibUtil.toMatrix(matrix)));
    }

    /**
     * Runs the current network on a single input vector.
     *
     * @param vector input features, converted with {@link MLLibUtil#toVector}
     * @return the network output converted back to an MLlib vector
     * @throws IllegalStateException if no network has been supplied or trained yet
     */
    public Vector predict(Vector vector) {
        return MLLibUtil.toVector(requireNetwork().output(MLLibUtil.toVector(vector)));
    }

    /**
     * Converts MLlib labeled points to data sets and fits a new network on them.
     *
     * @param javaSparkContext context passed to {@link MLLibUtil#fromLabeledPoint}
     * @param javaRDD the labeled training points
     * @return the newly trained network (also stored for {@link #predict})
     */
    public MultiLayerNetwork fit(JavaSparkContext javaSparkContext, JavaRDD<LabeledPoint> javaRDD) {
        return fitDataSet(MLLibUtil.fromLabeledPoint(javaSparkContext, javaRDD, outputLayerSize()));
    }

    /**
     * Fits a freshly initialised network on {@code javaRDD} and averages the worker results.
     *
     * @param javaRDD the training data
     * @return the newly trained network (also stored for {@link #predict})
     * @throws IllegalStateException if the initialised network's parameter vector length does not
     *     match its reported parameter count
     * @throws IllegalArgumentException if the input produces no mini-batches
     */
    public MultiLayerNetwork fitDataSet(JavaRDD<DataSet> javaRDD) {
        JavaRDD<DataSet> miniBatches =
            new RDDMiniBatches(this.conf.getConf(0).getBatchSize(), javaRDD).miniBatchesJava();

        MultiLayerNetwork multiLayerNetwork = new MultiLayerNetwork(this.conf);
        multiLayerNetwork.init();
        INDArray initialParams = multiLayerNetwork.params();
        int numParams = multiLayerNetwork.numParams();
        if (initialParams.length() != numParams) {
            throw new IllegalStateException("Number of params " + numParams + " was not equal to " + initialParams.length());
        }

        // reduce() on an empty RDD fails with an opaque Spark error and the average would divide by
        // zero, so reject empty input up front with a clear message.
        long batchCount = miniBatches.count();
        if (batchCount == 0) {
            throw new IllegalArgumentException("Cannot fit on an empty dataset: no mini-batches were produced");
        }

        INDArray summedParams = miniBatches
            .map(new DL4jWorker(this.conf.toJson(), initialParams))
            .reduce(new SumParams());
        multiLayerNetwork.setParameters(summedParams.divi(Long.valueOf(batchCount)));
        this.network = multiLayerNetwork;
        return multiLayerNetwork;
    }

    /**
     * Convenience entry point: builds a trainer on the RDD's own Spark context and fits it.
     *
     * @param javaRDD the labeled training points
     * @param multiLayerConfiguration the network configuration
     * @return the trained network
     */
    public static MultiLayerNetwork train(JavaRDD<LabeledPoint> javaRDD, MultiLayerConfiguration multiLayerConfiguration) {
        SparkDl4jMultiLayer trainer = new SparkDl4jMultiLayer(javaRDD.context(), multiLayerConfiguration);
        // Reuse the trainer's JavaSparkContext instead of wrapping the same SparkContext a second time.
        return trainer.fit(trainer.sc, javaRDD);
    }

    /** Returns {@code nOut} of the last layer configuration, passed to the dataset converters. */
    private int outputLayerSize() {
        return this.conf.getConf(this.conf.getConfs().size() - 1).getnOut();
    }

    /** Returns the current network, or fails clearly if none has been supplied or trained. */
    private MultiLayerNetwork requireNetwork() {
        if (this.network == null) {
            throw new IllegalStateException("No network available: call fit(...) or fitDataSet(...) before predict(...)");
        }
        return this.network;
    }

    /**
     * Element-wise sum of two parameter vectors.
     *
     * <p>Declared {@code static} so the serialized Spark closure does not drag the enclosing
     * trainer (and its network) along with it, as an anonymous inner class would.
     */
    private static final class SumParams implements Function2<INDArray, INDArray, INDArray> {
        private static final long serialVersionUID = 1L;

        @Override
        public INDArray call(INDArray left, INDArray right) throws Exception {
            // Accumulates in place into the left operand to avoid allocating per merge; right is not modified.
            return left.addi(right);
        }
    }
}
