package org.deeplearning4j.spark.api;

import java.io.Serializable;
import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.spark.api.TrainingResult;
import org.deeplearning4j.spark.api.stats.SparkTrainingStats;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.dataset.api.MultiDataSet;

/* loaded from: input_file:org/deeplearning4j/spark/api/TrainingWorker.class */
/**
 * Contract for a serializable training worker that consumes minibatches and produces
 * results of type {@code R}.
 *
 * <p>The interface offers parallel method families for the two model types it references,
 * {@link MultiLayerNetwork} and {@link ComputationGraph}, and for the two data containers,
 * {@link DataSet} and {@link MultiDataSet}. Most result-producing methods also have a
 * {@code ...WithStats} variant returning a {@link Pair} of the result and a
 * {@link SparkTrainingStats}.
 *
 * <p>NOTE(review): this file appears to be decompiler output, so the original Javadoc and
 * parameter names are not available. Descriptions below are inferred from the signatures
 * only and should be confirmed against the implementations.
 *
 * @param <R> the result type produced by this worker; must extend {@link TrainingResult}
 */
public interface TrainingWorker<R extends TrainingResult> extends Serializable {
    /**
     * Removes a training hook from this worker.
     *
     * @param trainingHook the hook to remove (presumably one previously passed to
     *                     {@link #addHook(TrainingHook)} — confirm)
     */
    void removeHook(TrainingHook trainingHook);

    /**
     * Adds a training hook to this worker.
     *
     * @param trainingHook the hook to add; when and how hooks are invoked is not visible
     *                     from this interface
     */
    void addHook(TrainingHook trainingHook);

    /**
     * Returns the initial model as a {@link MultiLayerNetwork}.
     *
     * @return the initial network (presumably the starting point for training on this
     *         worker — TODO confirm; nullability is not specified here)
     */
    MultiLayerNetwork getInitialModel();

    /**
     * Returns the initial model as a {@link ComputationGraph}; the graph counterpart of
     * {@link #getInitialModel()}.
     *
     * @return the initial graph (nullability is not specified here)
     */
    ComputationGraph getInitialModelGraph();

    /**
     * Processes one minibatch against a {@link MultiLayerNetwork}.
     *
     * @param dataSet           the minibatch to process
     * @param multiLayerNetwork the network the minibatch is processed with
     * @param z                 NOTE(review): decompiler-generated name; the flag's meaning is
     *                          not visible here (possibly "is last minibatch" — confirm)
     * @return a result of type {@code R}; whether it may be {@code null} is not specified here
     */
    R processMinibatch(DataSet dataSet, MultiLayerNetwork multiLayerNetwork, boolean z);

    /**
     * Processes one {@link DataSet} minibatch against a {@link ComputationGraph}.
     *
     * @param dataSet          the minibatch to process
     * @param computationGraph the graph the minibatch is processed with
     * @param z                NOTE(review): decompiler-generated name; meaning not visible
     *                         here (possibly "is last minibatch" — confirm)
     * @return a result of type {@code R}; whether it may be {@code null} is not specified here
     */
    R processMinibatch(DataSet dataSet, ComputationGraph computationGraph, boolean z);

    /**
     * Processes one {@link MultiDataSet} minibatch against a {@link ComputationGraph}.
     *
     * @param multiDataSet     the minibatch to process
     * @param computationGraph the graph the minibatch is processed with
     * @param z                NOTE(review): decompiler-generated name; meaning not visible
     *                         here (possibly "is last minibatch" — confirm)
     * @return a result of type {@code R}; whether it may be {@code null} is not specified here
     */
    R processMinibatch(MultiDataSet multiDataSet, ComputationGraph computationGraph, boolean z);

    /**
     * Stats-returning variant of
     * {@link #processMinibatch(DataSet, MultiLayerNetwork, boolean)}.
     *
     * @param dataSet           the minibatch to process
     * @param multiLayerNetwork the network the minibatch is processed with
     * @param z                 NOTE(review): decompiler-generated name; meaning not visible
     *                          here — confirm
     * @return a pair of the result and its associated {@link SparkTrainingStats}
     */
    Pair<R, SparkTrainingStats> processMinibatchWithStats(DataSet dataSet, MultiLayerNetwork multiLayerNetwork, boolean z);

    /**
     * Stats-returning variant of
     * {@link #processMinibatch(DataSet, ComputationGraph, boolean)}.
     *
     * @param dataSet          the minibatch to process
     * @param computationGraph the graph the minibatch is processed with
     * @param z                NOTE(review): decompiler-generated name; meaning not visible
     *                         here — confirm
     * @return a pair of the result and its associated {@link SparkTrainingStats}
     */
    Pair<R, SparkTrainingStats> processMinibatchWithStats(DataSet dataSet, ComputationGraph computationGraph, boolean z);

    /**
     * Stats-returning variant of
     * {@link #processMinibatch(MultiDataSet, ComputationGraph, boolean)}.
     *
     * @param multiDataSet     the minibatch to process
     * @param computationGraph the graph the minibatch is processed with
     * @param z                NOTE(review): decompiler-generated name; meaning not visible
     *                         here — confirm
     * @return a pair of the result and its associated {@link SparkTrainingStats}
     */
    Pair<R, SparkTrainingStats> processMinibatchWithStats(MultiDataSet multiDataSet, ComputationGraph computationGraph, boolean z);

    /**
     * Produces the final result for a {@link MultiLayerNetwork}.
     *
     * @param multiLayerNetwork the network to produce the final result from
     * @return the final result of type {@code R}
     */
    R getFinalResult(MultiLayerNetwork multiLayerNetwork);

    /**
     * Produces the final result for a {@link ComputationGraph}.
     *
     * @param computationGraph the graph to produce the final result from
     * @return the final result of type {@code R}
     */
    R getFinalResult(ComputationGraph computationGraph);

    /**
     * Produces a final result without taking a model argument.
     *
     * <p>Presumably used when the worker received no data to train on — the name suggests
     * this, but it is not established by this interface; verify against callers.
     *
     * @return the final result of type {@code R}
     */
    R getFinalResultNoData();

    /**
     * Stats-returning variant of {@link #getFinalResultNoData()}.
     *
     * @return a pair of the result and its associated {@link SparkTrainingStats}
     */
    Pair<R, SparkTrainingStats> getFinalResultNoDataWithStats();

    /**
     * Stats-returning variant of {@link #getFinalResult(MultiLayerNetwork)}.
     *
     * @param multiLayerNetwork the network to produce the final result from
     * @return a pair of the result and its associated {@link SparkTrainingStats}
     */
    Pair<R, SparkTrainingStats> getFinalResultWithStats(MultiLayerNetwork multiLayerNetwork);

    /**
     * Stats-returning variant of {@link #getFinalResult(ComputationGraph)}.
     *
     * @param computationGraph the graph to produce the final result from
     * @return a pair of the result and its associated {@link SparkTrainingStats}
     */
    Pair<R, SparkTrainingStats> getFinalResultWithStats(ComputationGraph computationGraph);

    /**
     * Returns the {@link WorkerConfiguration} associated with this worker.
     *
     * @return the worker's configuration (its contents are not visible from this interface)
     */
    WorkerConfiguration getDataConfiguration();
}
