package org.deeplearning4j.spark.ml.nn;

import org.apache.spark.Accumulator;
import org.apache.spark.broadcast.Broadcast;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.nd4j.linalg.api.ndarray.INDArray;
import scala.Predef$;
import scala.Serializable;
import scala.collection.Iterator;
import scala.collection.JavaConversions$;
import scala.collection.immutable.List$;
import scala.runtime.AbstractFunction1;
import scala.runtime.BoxedUnit;

/* JADX INFO: Add missing generic type declarations: [RowType] */
/* compiled from: TrainingStrategy.scala */
/* loaded from: input_file:org/deeplearning4j/spark/ml/nn/ParameterAveragingTrainingStrategy$$anonfun$train$1$$anonfun$apply$mcVI$sp$1.class */
/**
 * Decompiled scalac closure: the per-iterator training step of an anonymous function inside
 * {@code ParameterAveragingTrainingStrategy.train} (original source: TrainingStrategy.scala).
 *
 * <p>Each call does the following:
 * <ol>
 *   <li>Rebuilds a {@link MultiLayerNetwork} from the JSON configuration held by the enclosing
 *       closure.</li>
 *   <li>Loads the broadcast parameters into that network.</li>
 *   <li>Hands the network and the supplied rows to the enclosing closure's
 *       {@code partitionTrainer$1}.</li>
 *   <li>Adds the resulting parameters to an accumulator.</li>
 * </ol>
 * Only the summation of parameters happens here. Any division that turns the sum into an
 * average is not visible in this class (NOTE(review): presumably done by the caller; verify).
 *
 * <p>NOTE(review): presumably passed to something like {@code RDD.foreachPartition}, so that
 * each call sees one partition's rows. Confirm at the call site.
 *
 * @param <RowType> element type of the rows fed to the partition trainer
 */
public class ParameterAveragingTrainingStrategy$$anonfun$train$1$$anonfun$apply$mcVI$sp$1<RowType> extends AbstractFunction1<Iterator<RowType>, BoxedUnit> implements Serializable {
    /** Serialization id of this closure class (the class is Serializable). */
    public static final long serialVersionUID = 0;
    /** Enclosing closure instance. It supplies {@code confJson$1} and {@code partitionTrainer$1}. */
    private final /* synthetic */ ParameterAveragingTrainingStrategy$$anonfun$train$1 $outer;
    /** Captured broadcast. Its value is cast to {@link INDArray} and used as the starting parameters. */
    private final Broadcast broadcastedParams$1;
    /** Captured accumulator. Each call adds its trained parameters to it with {@code +=}. */
    private final Accumulator accumulatedParams$1;

    /**
     * Trains a fresh network copy on {@code iterator} and accumulates its final parameters.
     *
     * @param iterator rows to train on. They are passed unchanged to {@code partitionTrainer$1}.
     */
    public final void apply(Iterator<RowType> iterator) {
        // Rebuild the network locally from its serialized JSON configuration.
        MultiLayerNetwork multiLayerNetwork = new MultiLayerNetwork(MultiLayerConfiguration.fromJson(this.$outer.confJson$1));
        // Initialise the network before its parameters are overwritten below.
        multiLayerNetwork.init();
        // Attach a single ScoreIterationListener(1). NOTE(review): the argument is presumably a
        // print frequency of every iteration; confirm. The Scala List is converted to a java.util.List.
        multiLayerNetwork.setListeners(JavaConversions$.MODULE$.seqAsJavaList(List$.MODULE$.apply(Predef$.MODULE$.wrapRefArray(new ScoreIterationListener[]{new ScoreIterationListener(1)}))));
        // Every call starts from the same broadcast parameter vector (an unchecked cast from Object).
        multiLayerNetwork.setParams((INDArray) this.broadcastedParams$1.value());
        // Delegate fitting over these rows to the configured trainer. NOTE(review): presumably it
        // mutates multiLayerNetwork in place, since params() is read right after; verify.
        this.$outer.partitionTrainer$1.apply(multiLayerNetwork, iterator);
        // Contribute this copy's trained parameters to the shared sum.
        this.accumulatedParams$1.$plus$eq(multiLayerNetwork.params());
    }

    /**
     * Erased bridge for {@code Function1.apply}. It casts {@code obj} to {@link Iterator} without
     * a check, delegates to the typed {@code apply}, and returns {@code BoxedUnit.UNIT} (Scala's Unit).
     *
     * <p>NOTE(review): this is a decompiler artifact. As Java source, the {@code void}
     * {@code apply(Iterator<RowType>)} above is likely not a legal override of the
     * {@code BoxedUnit}-returning {@code AbstractFunction1.apply}, and this method would then
     * clash by erasure. Treat this file as a bytecode listing rather than compilable source.
     *
     * @param obj the rows iterator, passed in erased form
     * @return {@code BoxedUnit.UNIT}, always
     */
    public final /* bridge */ /* synthetic */ Object apply(Object obj) {
        apply((Iterator) obj);
        return BoxedUnit.UNIT;
    }

    /**
     * Captures the enclosing closure together with the broadcast and the accumulator this step uses.
     *
     * @param parameterAveragingTrainingStrategy$$anonfun$train$1 enclosing closure; must not be null
     * @param broadcast broadcast holding the starting parameters
     * @param accumulator accumulator receiving the trained parameters
     * @throws NullPointerException if the enclosing closure is null
     */
    /* JADX WARN: Incorrect inner types in method signature: (Lorg/deeplearning4j/spark/ml/nn/ParameterAveragingTrainingStrategy<TRowType;>.$anonfun$train$1;)V */
    public ParameterAveragingTrainingStrategy$$anonfun$train$1$$anonfun$apply$mcVI$sp$1(ParameterAveragingTrainingStrategy$$anonfun$train$1 parameterAveragingTrainingStrategy$$anonfun$train$1, Broadcast broadcast, Accumulator accumulator) {
        // Generated guard: the closure cannot exist without its enclosing instance.
        if (parameterAveragingTrainingStrategy$$anonfun$train$1 == null) {
            throw new NullPointerException();
        }
        this.$outer = parameterAveragingTrainingStrategy$$anonfun$train$1;
        this.broadcastedParams$1 = broadcast;
        this.accumulatedParams$1 = accumulator;
    }
}
