package org.deeplearning4j.spark.impl.graph;

import java.io.IOException;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;
import org.apache.spark.SparkContext;
import org.apache.spark.api.java.JavaDoubleRDD;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.rdd.RDD;
import org.datavec.spark.util.BroadcastHadoopConfigHolder;
import org.deeplearning4j.core.loader.DataSetLoader;
import org.deeplearning4j.core.loader.MultiDataSetLoader;
import org.deeplearning4j.core.loader.impl.SerializedDataSetLoader;
import org.deeplearning4j.core.loader.impl.SerializedMultiDataSetLoader;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.layers.FeedForwardLayer;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.spark.api.TrainingMaster;
import org.deeplearning4j.spark.api.stats.SparkTrainingStats;
import org.deeplearning4j.spark.impl.SparkListenable;
import org.deeplearning4j.spark.impl.common.reduce.LongDoubleReduceFunction;
import org.deeplearning4j.spark.impl.graph.dataset.DataSetToMultiDataSetFn;
import org.deeplearning4j.spark.impl.graph.dataset.PairDataSetToMultiDataSetFn;
import org.deeplearning4j.spark.impl.graph.evaluation.IEvaluateMDSFlatMapFunction;
import org.deeplearning4j.spark.impl.graph.evaluation.IEvaluateMDSPathsFlatMapFunction;
import org.deeplearning4j.spark.impl.graph.scoring.ArrayPairToPair;
import org.deeplearning4j.spark.impl.graph.scoring.GraphFeedForwardWithKeyFunction;
import org.deeplearning4j.spark.impl.graph.scoring.PairToArrayPair;
import org.deeplearning4j.spark.impl.graph.scoring.ScoreExamplesFunction;
import org.deeplearning4j.spark.impl.graph.scoring.ScoreExamplesWithKeyFunction;
import org.deeplearning4j.spark.impl.graph.scoring.ScoreFlatMapFunctionCGDataSet;
import org.deeplearning4j.spark.impl.graph.scoring.ScoreFlatMapFunctionCGMultiDataSet;
import org.deeplearning4j.spark.impl.multilayer.evaluation.IEvaluateAggregateFunction;
import org.deeplearning4j.spark.impl.multilayer.evaluation.IEvaluateFlatMapFunction;
import org.deeplearning4j.spark.util.SparkUtils;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.common.base.Preconditions;
import org.nd4j.evaluation.IEvaluation;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.evaluation.classification.ROC;
import org.nd4j.evaluation.classification.ROCMultiClass;
import org.nd4j.evaluation.regression.RegressionEvaluation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.executioner.GridExecutioner;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.heartbeat.Heartbeat;
import org.nd4j.linalg.heartbeat.reports.Environment;
import org.nd4j.linalg.heartbeat.reports.Event;
import org.nd4j.linalg.heartbeat.utils.EnvironmentUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import scala.Tuple2;

/* loaded from: input_file:org/deeplearning4j/spark/impl/graph/SparkComputationGraph.class */
/**
 * Spark front-end for a {@link ComputationGraph}.
 *
 * <p>Training is delegated to the configured {@link TrainingMaster}. Scoring, feed-forward and
 * evaluation send the network configuration (as JSON) and the current network parameters to the
 * executors through Spark broadcasts, and run the work inside {@code mapPartitions}-style functions.
 */
public class SparkComputationGraph extends SparkListenable {
    private static final Logger log = LoggerFactory.getLogger(SparkComputationGraph.class);

    /** Default number of threshold steps used by the ROC evaluation helpers. */
    public static final int DEFAULT_ROC_THRESHOLD_STEPS = 32;
    /** Default minibatch size passed to the executor-side scoring/evaluation functions. */
    public static final int DEFAULT_EVAL_SCORE_BATCH_SIZE = 64;
    /** Default number of evaluation workers passed to the executor-side evaluation functions. */
    public static final int DEFAULT_EVAL_WORKERS = 4;

    private transient JavaSparkContext sc;
    /** Clone of the network configuration taken at construction; its JSON form is shipped to executors. */
    private ComputationGraphConfiguration conf;
    private ComputationGraph network;
    private double lastScore;
    private int defaultEvaluationWorkers = DEFAULT_EVAL_WORKERS;
    private transient AtomicInteger iterationsCount = new AtomicInteger(0);

    /**
     * Wraps an existing network, using a {@link JavaSparkContext} created from the given context.
     */
    public SparkComputationGraph(SparkContext sparkContext, ComputationGraph computationGraph, TrainingMaster trainingMaster) {
        this(new JavaSparkContext(sparkContext), computationGraph, trainingMaster);
    }

    /**
     * Wraps an existing network. The network instance itself is kept (and {@code init()} is called on
     * it); its configuration is cloned for shipping to executors.
     */
    public SparkComputationGraph(JavaSparkContext javaSparkContext, ComputationGraph computationGraph, TrainingMaster trainingMaster) {
        this.sc = javaSparkContext;
        this.trainingMaster = trainingMaster;
        this.conf = computationGraph.getConfiguration().clone();
        this.network = computationGraph;
        this.network.init();
        SparkUtils.checkKryoConfiguration(javaSparkContext, log);
    }

    /**
     * Builds a new network from a configuration, using a {@link JavaSparkContext} created from the
     * given context.
     */
    public SparkComputationGraph(SparkContext sparkContext, ComputationGraphConfiguration computationGraphConfiguration, TrainingMaster trainingMaster) {
        this(new JavaSparkContext(sparkContext), computationGraphConfiguration, trainingMaster);
    }

    /**
     * Builds and initializes a new {@link ComputationGraph} from the given configuration; a clone of
     * the configuration is kept for shipping to executors.
     */
    public SparkComputationGraph(JavaSparkContext javaSparkContext, ComputationGraphConfiguration computationGraphConfiguration, TrainingMaster trainingMaster) {
        this.sc = javaSparkContext;
        this.trainingMaster = trainingMaster;
        this.conf = computationGraphConfiguration.clone();
        this.network = new ComputationGraph(computationGraphConfiguration);
        this.network.init();
        SparkUtils.checkKryoConfiguration(javaSparkContext, log);
    }

    /** Returns the Spark context this instance submits jobs to. */
    public JavaSparkContext getSparkContext() {
        return this.sc;
    }

    /** Enables or disables training-statistics collection on the training master. */
    public void setCollectTrainingStats(boolean collectTrainingStats) {
        this.trainingMaster.setCollectTrainingStats(collectTrainingStats);
    }

    /** Returns the training statistics reported by the training master. */
    public SparkTrainingStats getSparkTrainingStats() {
        return this.trainingMaster.getTrainingStats();
    }

    /** Returns the network currently held by this instance. */
    public ComputationGraph getNetwork() {
        return this.network;
    }

    /** Returns the training master used for distributed training. */
    public TrainingMaster getTrainingMaster() {
        return this.trainingMaster;
    }

    /**
     * Replaces the held network.
     *
     * <p>NOTE(review): this does not refresh the cloned configuration used for executor-side
     * scoring/evaluation; callers presumably pass a network with the same configuration — confirm.
     */
    public void setNetwork(ComputationGraph computationGraph) {
        this.network = computationGraph;
    }

    /** Returns the number of evaluation workers used by the evaluation methods that do not take one. */
    public int getDefaultEvaluationWorkers() {
        return this.defaultEvaluationWorkers;
    }

    /**
     * Sets the number of evaluation workers used by the evaluation methods that do not take one.
     *
     * @throws IllegalArgumentException if {@code workers <= 0}
     */
    public void setDefaultEvaluationWorkers(int workers) {
        Preconditions.checkArgument(workers > 0, "Number of workers must be > 0: got %s", workers);
        this.defaultEvaluationWorkers = workers;
    }

    /** Flushes queued ND4J operations when a {@link GridExecutioner} is in use. */
    private static void flushGridExecutionerQueue() {
        if (Nd4j.getExecutioner() instanceof GridExecutioner) {
            Nd4j.getExecutioner().flushQueue();
        }
    }

    /** Trains on a Scala RDD of DataSets; see {@link #fit(JavaRDD)}. */
    public ComputationGraph fit(RDD<DataSet> rdd) {
        return fit(rdd.toJavaRDD());
    }

    /**
     * Trains the network for one pass over {@code javaRDD} via the training master, then increments
     * the network's epoch count.
     *
     * @return the held network
     */
    public ComputationGraph fit(JavaRDD<DataSet> javaRDD) {
        flushGridExecutionerQueue();
        this.trainingMaster.executeTraining(this, javaRDD);
        this.network.incrementEpochCount();
        return this.network;
    }

    /**
     * Trains on the serialized DataSets found under {@code path}.
     *
     * @throws RuntimeException wrapping the {@link IOException} if listing the directory fails
     */
    public ComputationGraph fit(String path) {
        flushGridExecutionerQueue();
        try {
            return fitPaths(SparkUtils.listPaths(this.sc, path));
        } catch (IOException e) {
            throw new RuntimeException("Error listing paths in directory", e);
        }
    }

    /**
     * @deprecated Use {@link #fit(String)}; the {@code int} argument is ignored.
     */
    @Deprecated
    public ComputationGraph fit(String path, int ignored) {
        return fit(path);
    }

    /** Trains on the files listed in {@code paths}, loaded with a {@link SerializedDataSetLoader}. */
    public ComputationGraph fitPaths(JavaRDD<String> paths) {
        return fitPaths(paths, (DataSetLoader) new SerializedDataSetLoader());
    }

    /**
     * Trains on the files listed in {@code paths}, loading each with {@code dataSetLoader}, then
     * increments the network's epoch count.
     */
    public ComputationGraph fitPaths(JavaRDD<String> paths, DataSetLoader dataSetLoader) {
        this.trainingMaster.executeTrainingPaths(null, this, paths, dataSetLoader, null);
        this.network.incrementEpochCount();
        return this.network;
    }

    /** Trains on a Scala RDD of MultiDataSets; see {@link #fitMultiDataSet(JavaRDD)}. */
    public ComputationGraph fitMultiDataSet(RDD<MultiDataSet> rdd) {
        return fitMultiDataSet(rdd.toJavaRDD());
    }

    /**
     * Trains the network for one pass over {@code javaRDD} via the training master, then increments
     * the network's epoch count.
     */
    public ComputationGraph fitMultiDataSet(JavaRDD<MultiDataSet> javaRDD) {
        flushGridExecutionerQueue();
        this.trainingMaster.executeTrainingMDS(this, javaRDD);
        this.network.incrementEpochCount();
        return this.network;
    }

    /**
     * Trains on the serialized MultiDataSets found under {@code path}.
     *
     * @throws RuntimeException wrapping the {@link IOException} if listing the directory fails
     */
    public ComputationGraph fitMultiDataSet(String path) {
        flushGridExecutionerQueue();
        try {
            return fitPathsMultiDataSet(SparkUtils.listPaths(this.sc, path));
        } catch (IOException e) {
            throw new RuntimeException("Error listing paths in directory", e);
        }
    }

    /** Trains on the files listed in {@code paths}, loaded with a {@link SerializedMultiDataSetLoader}. */
    public ComputationGraph fitPathsMultiDataSet(JavaRDD<String> paths) {
        return fitPaths(paths, (MultiDataSetLoader) new SerializedMultiDataSetLoader());
    }

    /**
     * Trains on the files listed in {@code paths}, loading each with {@code multiDataSetLoader}, then
     * increments the network's epoch count.
     */
    public ComputationGraph fitPaths(JavaRDD<String> paths, MultiDataSetLoader multiDataSetLoader) {
        this.trainingMaster.executeTrainingPaths(null, this, paths, null, multiDataSetLoader);
        this.network.incrementEpochCount();
        return this.network;
    }

    /**
     * @deprecated Use {@link #fitMultiDataSet(String)}; the {@code int} argument is ignored.
     */
    @Deprecated
    public ComputationGraph fitMultiDataSet(String path, int ignored) {
        return fitMultiDataSet(path);
    }

    /** Returns the most recent score recorded via {@link #setScore(double)}. */
    public double getScore() {
        return this.lastScore;
    }

    /** Records a score value, later returned by {@link #getScore()}. */
    public void setScore(double score) {
        this.lastScore = score;
    }

    /**
     * Converts a reduced (count, summed score) tuple into either the average or the raw sum.
     * The count is presumably the number of scored examples — confirm against the score functions.
     */
    private static double toScore(Tuple2 countAndSum, boolean average) {
        double sum = ((Double) countAndSum._2()).doubleValue();
        return average ? sum / ((Long) countAndSum._1()).longValue() : sum;
    }

    /** Same as {@link #calculateScore(JavaRDD, boolean, int)} with {@link #DEFAULT_EVAL_SCORE_BATCH_SIZE}. */
    public double calculateScore(JavaRDD<DataSet> javaRDD, boolean average) {
        return calculateScore(javaRDD, average, DEFAULT_EVAL_SCORE_BATCH_SIZE);
    }

    /**
     * Scores the network on {@code javaRDD} on the executors and reduces the partial results.
     *
     * @param average if true, the summed score is divided by the reduced count; otherwise the sum is returned
     * @param batchSize minibatch size forwarded to the executor-side score function
     */
    public double calculateScore(JavaRDD<DataSet> javaRDD, boolean average, int batchSize) {
        Tuple2 countAndSum = (Tuple2) javaRDD.mapPartitions(new ScoreFlatMapFunctionCGDataSet(this.conf.toJson(), this.sc.broadcast(this.network.params()), batchSize)).reduce(new LongDoubleReduceFunction());
        return toScore(countAndSum, average);
    }

    /** Same as {@link #calculateScoreMultiDataSet(JavaRDD, boolean, int)} with {@link #DEFAULT_EVAL_SCORE_BATCH_SIZE}. */
    public double calculateScoreMultiDataSet(JavaRDD<MultiDataSet> javaRDD, boolean average) {
        return calculateScoreMultiDataSet(javaRDD, average, DEFAULT_EVAL_SCORE_BATCH_SIZE);
    }

    /**
     * MultiDataSet variant of {@link #calculateScore(JavaRDD, boolean, int)}.
     *
     * @param average if true, the summed score is divided by the reduced count; otherwise the sum is returned
     * @param batchSize minibatch size forwarded to the executor-side score function
     */
    public double calculateScoreMultiDataSet(JavaRDD<MultiDataSet> javaRDD, boolean average, int batchSize) {
        Tuple2 countAndSum = (Tuple2) javaRDD.mapPartitions(new ScoreFlatMapFunctionCGMultiDataSet(this.conf.toJson(), this.sc.broadcast(this.network.params()), batchSize)).reduce(new LongDoubleReduceFunction());
        return toScore(countAndSum, average);
    }

    /**
     * Per-example scores for DataSets (converted to MultiDataSets first).
     *
     * @param includeRegularizationTerms flag forwarded to the executor-side scoring function
     *     (presumably whether regularization terms are included — confirm)
     */
    public JavaDoubleRDD scoreExamples(JavaRDD<DataSet> javaRDD, boolean includeRegularizationTerms) {
        return scoreExamplesMultiDataSet(javaRDD.map(new DataSetToMultiDataSetFn()), includeRegularizationTerms);
    }

    /** Per-example scores for DataSets with an explicit minibatch size. */
    public JavaDoubleRDD scoreExamples(JavaRDD<DataSet> javaRDD, boolean includeRegularizationTerms, int batchSize) {
        return scoreExamplesMultiDataSet(javaRDD.map(new DataSetToMultiDataSetFn()), includeRegularizationTerms, batchSize);
    }

    /** Keyed per-example scores for DataSets, using {@link #DEFAULT_EVAL_SCORE_BATCH_SIZE}. */
    public <K> JavaPairRDD<K, Double> scoreExamples(JavaPairRDD<K, DataSet> javaPairRDD, boolean includeRegularizationTerms) {
        return scoreExamplesMultiDataSet(javaPairRDD.mapToPair(new PairDataSetToMultiDataSetFn()), includeRegularizationTerms, DEFAULT_EVAL_SCORE_BATCH_SIZE);
    }

    /** Keyed per-example scores for DataSets with an explicit minibatch size. */
    public <K> JavaPairRDD<K, Double> scoreExamples(JavaPairRDD<K, DataSet> javaPairRDD, boolean includeRegularizationTerms, int batchSize) {
        return scoreExamplesMultiDataSet(javaPairRDD.mapToPair(new PairDataSetToMultiDataSetFn()), includeRegularizationTerms, batchSize);
    }

    /** Per-example scores for MultiDataSets, using {@link #DEFAULT_EVAL_SCORE_BATCH_SIZE}. */
    public JavaDoubleRDD scoreExamplesMultiDataSet(JavaRDD<MultiDataSet> javaRDD, boolean includeRegularizationTerms) {
        return scoreExamplesMultiDataSet(javaRDD, includeRegularizationTerms, DEFAULT_EVAL_SCORE_BATCH_SIZE);
    }

    /**
     * Per-example scores for MultiDataSets. The returned RDD is a Spark transformation and is only
     * computed when an action runs on it.
     */
    public JavaDoubleRDD scoreExamplesMultiDataSet(JavaRDD<MultiDataSet> javaRDD, boolean includeRegularizationTerms, int batchSize) {
        return javaRDD.mapPartitionsToDouble(new ScoreExamplesFunction(this.sc.broadcast(this.network.params()), this.sc.broadcast(this.conf.toJson()), includeRegularizationTerms, batchSize));
    }

    /** Keyed per-example scores for MultiDataSets, using {@link #DEFAULT_EVAL_SCORE_BATCH_SIZE}. */
    public <K> JavaPairRDD<K, Double> scoreExamplesMultiDataSet(JavaPairRDD<K, MultiDataSet> javaPairRDD, boolean includeRegularizationTerms) {
        return scoreExamplesMultiDataSet(javaPairRDD, includeRegularizationTerms, DEFAULT_EVAL_SCORE_BATCH_SIZE);
    }

    /** Keyed per-example scores for MultiDataSets; the result RDD is lazily evaluated. */
    public <K> JavaPairRDD<K, Double> scoreExamplesMultiDataSet(JavaPairRDD<K, MultiDataSet> javaPairRDD, boolean includeRegularizationTerms, int batchSize) {
        return javaPairRDD.mapPartitionsToPair(new ScoreExamplesWithKeyFunction(this.sc.broadcast(this.network.params()), this.sc.broadcast(this.conf.toJson()), includeRegularizationTerms, batchSize));
    }

    /**
     * Keyed feed-forward for graphs with exactly one input and one output array.
     *
     * @throws IllegalStateException if the graph has more than one input or output array
     */
    public <K> JavaPairRDD<K, INDArray> feedForwardWithKeySingle(JavaPairRDD<K, INDArray> javaPairRDD, int batchSize) {
        if (this.network.getNumInputArrays() == 1 && this.network.getNumOutputArrays() == 1) {
            return feedForwardWithKey(javaPairRDD.mapToPair(new PairToArrayPair()), batchSize).mapToPair(new ArrayPairToPair());
        }
        throw new IllegalStateException("Cannot use this method with computation graphs with more than 1 input or output ( has: " + this.network.getNumInputArrays() + " inputs, " + this.network.getNumOutputArrays() + " outputs)");
    }

    /**
     * Keyed feed-forward over arrays of inputs; the result RDD is lazily evaluated.
     *
     * @param batchSize forwarded to the executor-side feed-forward function (presumably the minibatch size)
     */
    public <K> JavaPairRDD<K, INDArray[]> feedForwardWithKey(JavaPairRDD<K, INDArray[]> javaPairRDD, int batchSize) {
        return javaPairRDD.mapPartitionsToPair(new GraphFeedForwardWithKeyFunction(this.sc.broadcast(this.network.params()), this.sc.broadcast(this.conf.toJson()), batchSize));
    }

    /**
     * Reports a Spark heartbeat event with the given core count and memory.
     * NOTE(review): not called anywhere in this class.
     */
    private void update(int numCores, long availableMemory) {
        Environment environment = EnvironmentUtils.buildEnvironment();
        environment.setNumCores(numCores);
        environment.setAvailableMemory(availableMemory);
        Heartbeat.getInstance().reportEvent(Event.SPARK, environment, ModelSerializer.taskByModel(this.network));
    }

    /**
     * Classification evaluation on DataSet files under {@code path}, using the default worker count
     * and batch size.
     *
     * @throws RuntimeException wrapping the {@link IOException} if listing the files fails
     */
    public Evaluation evaluate(String path, DataSetLoader dataSetLoader) {
        try {
            return (Evaluation) doEvaluation(SparkUtils.listPaths(this.sc, path), DEFAULT_EVAL_WORKERS, DEFAULT_EVAL_SCORE_BATCH_SIZE, dataSetLoader, (MultiDataSetLoader) null, new Evaluation())[0];
        } catch (IOException e) {
            throw new RuntimeException("Error listing files for evaluation of files at path: " + path, e);
        }
    }

    /**
     * Classification evaluation on MultiDataSet files under {@code path}, using the default worker
     * count and batch size.
     *
     * @throws RuntimeException wrapping the {@link IOException} if listing the files fails
     */
    public Evaluation evaluate(String path, MultiDataSetLoader multiDataSetLoader) {
        try {
            return (Evaluation) doEvaluation(SparkUtils.listPaths(this.sc, path), DEFAULT_EVAL_WORKERS, DEFAULT_EVAL_SCORE_BATCH_SIZE, (DataSetLoader) null, multiDataSetLoader, new Evaluation())[0];
        } catch (IOException e) {
            throw new RuntimeException("Error listing files for evaluation of files at path: " + path, e);
        }
    }

    /** Classification evaluation on a Scala RDD; see {@link #evaluate(JavaRDD)}. */
    public <T extends Evaluation> T evaluate(RDD<DataSet> rdd) {
        return (T) evaluate(rdd.toJavaRDD());
    }

    /** Classification evaluation without label names. */
    public <T extends Evaluation> T evaluate(JavaRDD<DataSet> javaRDD) {
        return (T) evaluate(javaRDD, (List<String>) null);
    }

    /** Classification evaluation on a Scala RDD with optional label names. */
    public <T extends Evaluation> T evaluate(RDD<DataSet> rdd, List<String> labelsList) {
        return (T) evaluate(rdd.toJavaRDD(), labelsList);
    }

    /** Regression evaluation using {@link #DEFAULT_EVAL_SCORE_BATCH_SIZE}. */
    public <T extends RegressionEvaluation> T evaluateRegression(JavaRDD<DataSet> javaRDD) {
        return (T) evaluateRegression(javaRDD, DEFAULT_EVAL_SCORE_BATCH_SIZE);
    }

    /**
     * Regression evaluation; the number of columns is taken from the first output layer's nOut.
     * Assumes that layer's configuration is a {@link FeedForwardLayer} (otherwise the cast fails).
     */
    public <T extends RegressionEvaluation> T evaluateRegression(JavaRDD<DataSet> javaRDD, int evalBatchSize) {
        return (T) doEvaluation(javaRDD, new org.deeplearning4j.eval.RegressionEvaluation(((FeedForwardLayer) this.network.getOutputLayer(0).conf().getLayer()).getNOut()), evalBatchSize);
    }

    /** Classification evaluation with optional label names, using {@link #DEFAULT_EVAL_SCORE_BATCH_SIZE}. */
    public <T extends Evaluation> T evaluate(JavaRDD<DataSet> javaRDD, List<String> labelsList) {
        return (T) evaluate(javaRDD, labelsList, DEFAULT_EVAL_SCORE_BATCH_SIZE);
    }

    /** ROC evaluation with default threshold steps and batch size. */
    public <T extends ROC> T evaluateROC(JavaRDD<DataSet> javaRDD) {
        return (T) evaluateROC(javaRDD, DEFAULT_ROC_THRESHOLD_STEPS, DEFAULT_EVAL_SCORE_BATCH_SIZE);
    }

    /** ROC evaluation with the given threshold steps and batch size. */
    public <T extends ROC> T evaluateROC(JavaRDD<DataSet> javaRDD, int thresholdSteps, int evalBatchSize) {
        return (T) doEvaluation(javaRDD, new org.deeplearning4j.eval.ROC(thresholdSteps), evalBatchSize);
    }

    /** Multi-class ROC evaluation with default threshold steps and batch size. */
    public <T extends ROCMultiClass> T evaluateROCMultiClass(JavaRDD<DataSet> javaRDD) {
        return (T) evaluateROCMultiClass(javaRDD, DEFAULT_ROC_THRESHOLD_STEPS, DEFAULT_EVAL_SCORE_BATCH_SIZE);
    }

    /** Multi-class ROC evaluation with the given threshold steps and batch size. */
    public <T extends ROCMultiClass> T evaluateROCMultiClass(JavaRDD<DataSet> javaRDD, int thresholdSteps, int evalBatchSize) {
        return (T) doEvaluation(javaRDD, new org.deeplearning4j.eval.ROCMultiClass(thresholdSteps), evalBatchSize);
    }

    /**
     * Classification evaluation; if {@code labelsList} is non-null it is set on the result.
     */
    public <T extends Evaluation> T evaluate(JavaRDD<DataSet> javaRDD, List<String> labelsList, int evalBatchSize) {
        T evaluation = (T) doEvaluation(javaRDD, new org.deeplearning4j.eval.Evaluation(), evalBatchSize);
        if (labelsList != null) {
            evaluation.setLabelsList(labelsList);
        }
        return evaluation;
    }

    /** Classification evaluation on MultiDataSets using {@link #DEFAULT_EVAL_SCORE_BATCH_SIZE}. */
    public <T extends Evaluation> T evaluateMDS(JavaRDD<MultiDataSet> javaRDD) {
        return (T) evaluateMDS(javaRDD, DEFAULT_EVAL_SCORE_BATCH_SIZE);
    }

    /** Classification evaluation on MultiDataSets. */
    public <T extends Evaluation> T evaluateMDS(JavaRDD<MultiDataSet> javaRDD, int evalBatchSize) {
        return (T) doEvaluationMDS(javaRDD, evalBatchSize, new org.deeplearning4j.eval.Evaluation())[0];
    }

    /** Regression evaluation on MultiDataSets using {@link #DEFAULT_EVAL_SCORE_BATCH_SIZE}. */
    public <T extends RegressionEvaluation> T evaluateRegressionMDS(JavaRDD<MultiDataSet> javaRDD) {
        return (T) evaluateRegressionMDS(javaRDD, DEFAULT_EVAL_SCORE_BATCH_SIZE);
    }

    /** Regression evaluation on MultiDataSets. */
    public <T extends RegressionEvaluation> T evaluateRegressionMDS(JavaRDD<MultiDataSet> javaRDD, int evalBatchSize) {
        return (T) doEvaluationMDS(javaRDD, evalBatchSize, new org.deeplearning4j.eval.RegressionEvaluation())[0];
    }

    /** ROC evaluation on MultiDataSets with default threshold steps and batch size. */
    public ROC evaluateROCMDS(JavaRDD<MultiDataSet> javaRDD) {
        return evaluateROCMDS(javaRDD, DEFAULT_ROC_THRESHOLD_STEPS, DEFAULT_EVAL_SCORE_BATCH_SIZE);
    }

    /** ROC evaluation on MultiDataSets with the given threshold steps and batch size. */
    public <T extends ROC> T evaluateROCMDS(JavaRDD<MultiDataSet> javaRDD, int thresholdSteps, int evalBatchSize) {
        return (T) doEvaluationMDS(javaRDD, evalBatchSize, new org.deeplearning4j.eval.ROC(thresholdSteps))[0];
    }

    /** Runs a single evaluation with the default worker count and returns it. */
    public <T extends IEvaluation> T doEvaluation(JavaRDD<DataSet> javaRDD, T emptyEvaluation, int evalBatchSize) {
        return doEvaluation(javaRDD, evalBatchSize, emptyEvaluation)[0];
    }

    /** Runs the given evaluations using {@link #getDefaultEvaluationWorkers()} workers. */
    public <T extends IEvaluation> T[] doEvaluation(JavaRDD<DataSet> javaRDD, int evalBatchSize, T... emptyEvaluations) {
        return doEvaluation(javaRDD, getDefaultEvaluationWorkers(), evalBatchSize, emptyEvaluations);
    }

    /**
     * Runs the given evaluations on the executors and merges the per-partition results with
     * {@code treeAggregate}.
     *
     * @throws IllegalArgumentException if {@code evalNumWorkers <= 0}
     */
    public <T extends IEvaluation> T[] doEvaluation(JavaRDD<DataSet> javaRDD, int evalNumWorkers, int evalBatchSize, T... emptyEvaluations) {
        // Validate before creating any broadcast, so a bad argument does not leave broadcasts behind.
        Preconditions.checkArgument(evalNumWorkers > 0, "Invalid number of evaluation workers: require at least 1 - got %s", evalNumWorkers);
        IEvaluateFlatMapFunction evalFn = new IEvaluateFlatMapFunction(true, this.sc.broadcast(this.conf.toJson()), SparkUtils.asByteArrayBroadcast(this.sc, this.network.params()), evalNumWorkers, evalBatchSize, emptyEvaluations);
        return (T[]) ((IEvaluation[]) javaRDD.mapPartitions(evalFn).treeAggregate((Object) null, new IEvaluateAggregateFunction(), new IEvaluateAggregateFunction()));
    }

    /** Runs the given MultiDataSet evaluations using {@link #getDefaultEvaluationWorkers()} workers. */
    public <T extends IEvaluation> T[] doEvaluationMDS(JavaRDD<MultiDataSet> javaRDD, int evalBatchSize, T... emptyEvaluations) {
        return doEvaluationMDS(javaRDD, getDefaultEvaluationWorkers(), evalBatchSize, emptyEvaluations);
    }

    /**
     * Runs the given evaluations on MultiDataSets on the executors and merges the results.
     *
     * @throws IllegalArgumentException if {@code evalNumWorkers <= 0}
     */
    public <T extends IEvaluation> T[] doEvaluationMDS(JavaRDD<MultiDataSet> javaRDD, int evalNumWorkers, int evalBatchSize, T... emptyEvaluations) {
        Preconditions.checkArgument(evalNumWorkers > 0, "Invalid number of evaluation workers: require at least 1 - got %s", evalNumWorkers);
        IEvaluateMDSFlatMapFunction evalFn = new IEvaluateMDSFlatMapFunction(this.sc.broadcast(this.conf.toJson()), SparkUtils.asByteArrayBroadcast(this.sc, this.network.params()), evalNumWorkers, evalBatchSize, emptyEvaluations);
        return (T[]) ((IEvaluation[]) javaRDD.mapPartitions(evalFn).treeAggregate((Object) null, new IEvaluateAggregateFunction(), new IEvaluateAggregateFunction()));
    }

    /** Evaluates DataSet files listed in {@code paths} with the default worker count and batch size. */
    public IEvaluation[] doEvaluation(JavaRDD<String> paths, DataSetLoader dataSetLoader, IEvaluation... emptyEvaluations) {
        return doEvaluation(paths, DEFAULT_EVAL_WORKERS, DEFAULT_EVAL_SCORE_BATCH_SIZE, dataSetLoader, emptyEvaluations);
    }

    /** Evaluates DataSet files listed in {@code paths}. */
    public IEvaluation[] doEvaluation(JavaRDD<String> paths, int evalNumWorkers, int evalBatchSize, DataSetLoader dataSetLoader, IEvaluation... emptyEvaluations) {
        return doEvaluation(paths, evalNumWorkers, evalBatchSize, dataSetLoader, (MultiDataSetLoader) null, emptyEvaluations);
    }

    /** Evaluates MultiDataSet files listed in {@code paths} with the default worker count and batch size. */
    public IEvaluation[] doEvaluation(JavaRDD<String> paths, MultiDataSetLoader multiDataSetLoader, IEvaluation... emptyEvaluations) {
        return doEvaluation(paths, DEFAULT_EVAL_WORKERS, DEFAULT_EVAL_SCORE_BATCH_SIZE, (DataSetLoader) null, multiDataSetLoader, emptyEvaluations);
    }

    /** Evaluates MultiDataSet files listed in {@code paths}. */
    public IEvaluation[] doEvaluation(JavaRDD<String> paths, int evalNumWorkers, int evalBatchSize, MultiDataSetLoader multiDataSetLoader, IEvaluation... emptyEvaluations) {
        return doEvaluation(paths, evalNumWorkers, evalBatchSize, (DataSetLoader) null, multiDataSetLoader, emptyEvaluations);
    }

    /**
     * Evaluates the files listed in {@code paths}; exactly one of the two loaders is expected to be
     * non-null (both are forwarded to the executor-side function).
     *
     * @throws IllegalArgumentException if {@code evalNumWorkers <= 0}
     */
    protected IEvaluation[] doEvaluation(JavaRDD<String> paths, int evalNumWorkers, int evalBatchSize, DataSetLoader dataSetLoader, MultiDataSetLoader multiDataSetLoader, IEvaluation... emptyEvaluations) {
        // Validate first: building the function below creates Spark broadcasts.
        Preconditions.checkArgument(evalNumWorkers > 0, "Invalid number of evaluation workers: require at least 1 - got %s", evalNumWorkers);
        IEvaluateMDSPathsFlatMapFunction evalFn = new IEvaluateMDSPathsFlatMapFunction(this.sc.broadcast(this.conf.toJson()), SparkUtils.asByteArrayBroadcast(this.sc, this.network.params()), evalNumWorkers, evalBatchSize, dataSetLoader, multiDataSetLoader, BroadcastHadoopConfigHolder.get(this.sc), emptyEvaluations);
        return (IEvaluation[]) paths.mapPartitions(evalFn).treeAggregate((Object) null, new IEvaluateAggregateFunction(), new IEvaluateAggregateFunction());
    }
}
