package org.deeplearning4j.spark.ml.nn;

import org.apache.spark.Accumulator;
import org.apache.spark.SparkContext;
import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.rdd.RDD;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import scala.Function2;
import scala.Predef$;
import scala.Serializable;
import scala.reflect.ClassTag$;
import scala.runtime.AbstractFunction1;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;
import scala.runtime.ObjectRef;

/* compiled from: TrainingStrategy.scala */
/* loaded from: input_file:org/deeplearning4j/spark/ml/nn/ParameterAveragingTrainingStrategy$$anonfun$train$1.class */
/**
 * Per-iteration body of {@code ParameterAveragingTrainingStrategy.train}: one round of
 * parameter averaging over the partitions of an RDD.
 *
 * <p>NOTE(review): this is a decompiled Scala closure, not hand-written Java. The
 * {@code extends AbstractFunction1.mcVI.sp} clause appears to be the decompiler's rendering of
 * Scala's specialized {@code scala.runtime.AbstractFunction1$mcVI$sp} and does not compile as
 * Java as written. The source of truth is presumably TrainingStrategy.scala, so any fix made
 * here should be mirrored there.
 *
 * <p>Each call to {@link #apply(int)}:
 * <ol>
 *   <li>broadcasts the current network parameters ({@code networkParams$1.elem}),</li>
 *   <li>runs the nested per-partition closure over every partition of {@code rdd$1}, passing it
 *       the broadcast and an accumulator initialised to zeros of the parameters' shape,</li>
 *   <li>replaces {@code networkParams$1.elem} with the accumulated value divided by the number
 *       of partitions.</li>
 * </ol>
 * The nested closure presumably trains on its partition via {@code partitionTrainer$1} and adds
 * the resulting parameters into the accumulator. That closure is not visible here, so confirm it.
 */
public class ParameterAveragingTrainingStrategy$$anonfun$train$1 extends AbstractFunction1.mcVI.sp implements Serializable {
    public static final long serialVersionUID = 0;
    /** Training data; one averaging contribution is expected per partition. */
    private final RDD rdd$1;
    /** Per-partition trainer; public because the nested per-partition closure reads it. */
    public final Function2 partitionTrainer$1;
    /** Driver-side context used to create the broadcast and the accumulator. */
    private final SparkContext sc$1;
    /** Network configuration JSON; public because the nested per-partition closure reads it. */
    public final String confJson$1;
    /** Mutable cell holding the current parameter INDArray; updated after every iteration. */
    private final ObjectRef networkParams$1;

    /**
     * Runs one averaging iteration.
     *
     * @param i ignored by the body; presumably the iteration index (TODO confirm against caller)
     */
    public final void apply(int i) {
        apply$mcVI$sp(i);
    }

    /**
     * Specialized (unboxed int) implementation of one averaging iteration.
     *
     * <p>If the RDD has no partitions, the parameters are left unchanged. Previously the
     * averaging step divided by zero in that case.
     *
     * @param i ignored by the body
     */
    public void apply$mcVI$sp(int i) {
        // Read the partition count first. With zero partitions nothing is trained, the
        // accumulator stays at zero, and "sum / 0" would replace the parameters with
        // non-finite values. Keep the current parameters instead.
        int numPartitions = this.rdd$1.partitions().length;
        if (numPartitions == 0) {
            return;
        }

        INDArray currentParams = (INDArray) this.networkParams$1.elem;
        Broadcast broadcast = this.sc$1.broadcast(currentParams, ClassTag$.MODULE$.apply(INDArray.class));
        // Starts at zeros with the same shape as the parameters so the partitions' results can be summed into it.
        Accumulator accumulator = this.sc$1.accumulator(Nd4j.zeros(currentParams.shape()), INDArrayAccumulatorParam$.MODULE$);
        try {
            // foreachPartition is a Spark action, so the job has finished when this returns.
            this.rdd$1.foreachPartition(new ParameterAveragingTrainingStrategy$$anonfun$train$1$$anonfun$apply$mcVI$sp$1(this, broadcast, accumulator));
        } finally {
            // A fresh broadcast is created every iteration and is only used by the job above;
            // release the executor-side copies explicitly, even if the job failed.
            broadcast.unpersist();
        }
        // In-place divide of the accumulated sum gives the mean across partitions.
        this.networkParams$1.elem = ((INDArray) accumulator.value()).divi(Integer.valueOf(numPartitions));
    }

    /** Generic Function1 entry point: unboxes the argument and returns Scala's unit value. */
    public final /* bridge */ /* synthetic */ Object apply(Object obj) {
        apply(BoxesRunTime.unboxToInt(obj));
        return BoxedUnit.UNIT;
    }

    /**
     * Captures the closure's free variables.
     *
     * @param parameterAveragingTrainingStrategy enclosing strategy instance (not retained)
     * @param rdd training data
     * @param function2 per-partition trainer
     * @param sparkContext context used for the broadcast and the accumulator
     * @param str network configuration JSON
     * @param objectRef mutable cell holding the current parameter INDArray
     */
    public ParameterAveragingTrainingStrategy$$anonfun$train$1(ParameterAveragingTrainingStrategy parameterAveragingTrainingStrategy, RDD rdd, Function2 function2, SparkContext sparkContext, String str, ObjectRef objectRef) {
        this.rdd$1 = rdd;
        this.partitionTrainer$1 = function2;
        this.sc$1 = sparkContext;
        this.confJson$1 = str;
        this.networkParams$1 = objectRef;
    }
}
