package org.apache.spark.ml.odkl;

import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.ml.linalg.DenseMatrix;
import org.apache.spark.ml.linalg.DenseMatrix$;
import org.apache.spark.ml.linalg.DenseVector;
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.ml.linalg.Vectors$;
import org.apache.spark.ml.odkl.DSVRGD;
import scala.Predef$;
import scala.Serializable;
import scala.Tuple2;
import scala.collection.Iterator;
import scala.collection.Iterator$;
import scala.collection.immutable.Nil$;
import scala.package$;
import scala.runtime.AbstractFunction1;
import scala.runtime.BoxedUnit;
import scala.runtime.IntRef;
import scala.util.Random$;

/* compiled from: DSVRGD.scala */
/* loaded from: input_file:org/apache/spark/ml/odkl/DSVRGD$$anonfun$16.class */
/**
 * Function object generated for a closure in {@code DSVRGD.scala}: it consumes an iterator of
 * {@code (Vector, DenseVector)} pairs and yields at most one {@link DSVRGD.DistributedSgdState}.
 *
 * <p>The instance only allocates working buffers and wires them into three sibling closure classes
 * ({@code $anonfun$apply$17}, {@code $anonfun$apply$18}, {@code $anonfun$apply$19}); the per-record
 * logic lives in those classes, which are not part of this file.
 *
 * <p>NOTE(review): local and parameter names below were chosen while cleaning up decompiler output.
 * Names that mirror a captured field ({@code weights}, {@code avgGradient}, {@code avgWeights}) are
 * grounded in this class; the remaining names (loss / gradient / "compensation" / scratch) are
 * inferred from the helper names {@code fullGradientAndLoss} and {@code axpyCompensated} and from
 * argument positions, assuming {@code $anonfun$apply$19} forwards its constructor arguments to
 * {@code minibatchStep} in the same order. Confirm against {@code DSVRGD.scala}.
 */
public class DSVRGD$$anonfun$16 extends AbstractFunction1<Iterator<Tuple2<Vector, DenseVector>>, Iterator<DSVRGD.DistributedSgdState>> implements Serializable {
    public static final long serialVersionUID = 0;
    // Enclosing DSVRGD instance; supplies toDense, dscal, fullGradientAndLoss, axpyCompensated and adjust.
    private final /* synthetic */ DSVRGD $outer;
    // Captured broadcasts; each is converted with $outer.toDense(...) at the start of apply.
    private final Broadcast weights$2;
    private final Broadcast avgWeights$1;
    private final Broadcast avgGradient$1;
    // Forwarded unchanged to $outer.fullGradientAndLoss as its first two arguments.
    private final Vector l1regParam$1;
    private final Vector l2regParam$1;
    // The extra calls in minibatchStep run only when this is greater than 1.
    private final int stepNum$1;
    // Public because the generated sibling closure classes read it directly (not used in this class).
    public final DenseVector labelLearningRates$1;
    // When true, -1 is used as the index handed to fullGradientAndLoss; otherwise the last column index.
    private final boolean doRegularizeLast$1;
    // Number of columns reserved per batch buffer; public for the generated sibling closure classes.
    public final int batchSize$1;

    /**
     * Allocates all working state, feeds the shuffled and filtered input to
     * {@code $anonfun$apply$19}, and packages the outcome.
     *
     * @param iterator input pairs; shuffled via {@code scala.util.Random.shuffle} before use
     * @return an empty iterator when the step counter is still {@code <= 0} after the pass,
     *         otherwise a one-element iterator holding the resulting state
     */
    public final Iterator<DSVRGD.DistributedSgdState> apply(Iterator<Tuple2<Vector, DenseVector>> iterator) {
        final IntRef stepCount = new IntRef(0);
        final DenseMatrix weights = this.$outer.toDense(this.weights$2);
        final int rows = weights.numRows();
        final int cols = weights.numCols();

        // Vector-sized state: one zero vector on its own, plus four independent zero vectors of equal size.
        final DenseVector loss = Vectors$.MODULE$.zeros(rows).toDense();
        final DenseVector lossSum = Vectors$.MODULE$.zeros(rows).toDense();
        final DenseVector lossSumCompensation = lossSum.copy();
        final DenseVector vectorScratchA = lossSum.copy();
        final DenseVector vectorScratchB = lossSum.copy();

        // Matrix-sized state: seven independent zero matrices shaped like the weights.
        final DenseMatrix gradient = DenseMatrix$.MODULE$.zeros(rows, cols);
        final DenseMatrix gradientSum = gradient.copy();
        final DenseMatrix gradientSumCompensation = gradient.copy();
        final DenseMatrix weightsSum = gradient.copy();
        final DenseMatrix weightsSumCompensation = gradient.copy();
        final DenseMatrix matrixScratchA = gradient.copy();
        final DenseMatrix matrixScratchB = gradient.copy();

        // -1 when the last column takes part in regularization, else the index of the last column
        // (presumably the column fullGradientAndLoss leaves unregularized - confirm).
        final int skipRegularizationIndex;
        if (this.doRegularizeLast$1) {
            skipRegularizationIndex = -1;
        } else {
            skipRegularizationIndex = cols - 1;
        }

        final DenseMatrix avgGradient = this.$outer.toDense(this.avgGradient$1);
        final DenseMatrix avgWeights = this.$outer.toDense(this.avgWeights$1);

        // Flat backing arrays that minibatchStep wraps as (rows x n), (cols x n) and (rows x n) matrices.
        final double[] rowsByBatchA = new double[this.batchSize$1 * rows];
        final double[] colsByBatch = new double[this.batchSize$1 * cols];
        final double[] rowsByBatchB = new double[this.batchSize$1 * rows];

        // Counters shared with the per-record closure; minibatchStep resets all three to zero.
        final IntRef batchColumns = new IntRef(0);
        final IntRef auxCounterA = new IntRef(0);
        final IntRef auxCounterB = new IntRef(0);

        // Zero matrix rewritten in place by $anonfun$apply$17 before the pass; later the second
        // argument of every $outer.adjust call (presumably per-element step sizes - confirm).
        final DenseMatrix rates = DenseMatrix$.MODULE$.zeros(rows, cols);
        MatrixUtils$.MODULE$.transformDense(rates, new DSVRGD$$anonfun$16$$anonfun$apply$17(this));

        Random$.MODULE$.shuffle(iterator, Iterator$.MODULE$.IteratorCanBuildFrom()).withFilter(new DSVRGD$$anonfun$16$$anonfun$apply$18(this)).foreach(new DSVRGD$$anonfun$16$$anonfun$apply$19(this, stepCount, weights, loss, lossSum, lossSumCompensation, vectorScratchA, vectorScratchB, gradient, gradientSum, gradientSumCompensation, weightsSum, weightsSumCompensation, matrixScratchA, matrixScratchB, skipRegularizationIndex, avgGradient, avgWeights, rowsByBatchA, colsByBatch, rowsByBatchB, batchColumns, auxCounterA, auxCounterB, rates));

        if (stepCount.elem <= 0) {
            // Nothing was processed: contribute no state.
            return package$.MODULE$.Iterator().apply(Nil$.MODULE$);
        }

        // NOTE(review): dscal receives the raw step count. The decompiled code also computed
        // 1.0 / stepCount here without ever reading it (that dead local has been dropped);
        // confirm in DSVRGD.scala which factor is intended.
        this.$outer.dscal(stepCount.elem, weights.values());
        return package$.MODULE$.Iterator().apply(Predef$.MODULE$.wrapRefArray(new DSVRGD.DistributedSgdState[]{new DSVRGD.DistributedSgdState(weights, gradientSum, weightsSum, lossSum, stepCount.elem)}));
    }

    /**
     * Lifted local method {@code minibatchStep} of the closure; public with a mangled name so the
     * generated per-record closure can invoke it.
     *
     * <p>Sequence of calls, exactly as performed below:
     * <ol>
     *   <li>wrap the three flat arrays as matrices with {@code batchColumns.elem} columns;</li>
     *   <li>{@code fullGradientAndLoss} with {@code weights}, writing into {@code loss}/{@code gradient}
     *       (presumably - the helper is not visible here);</li>
     *   <li>{@code axpyCompensated} for the vector pair and for the gradient pair;</li>
     *   <li>{@code adjust(-1, rates, gradient, weights)};</li>
     *   <li>only when {@code stepNum$1 > 1}: {@code fullGradientAndLoss} again with {@code avgWeights}
     *       in place of {@code weights}, then {@code adjust(1, ...)} with {@code gradient} and
     *       {@code adjust(-1, ...)} with {@code avgGradient};</li>
     *   <li>{@code axpyCompensated} for {@code weights} into {@code weightsSum};</li>
     *   <li>increment {@code stepCount} and zero the three counters.</li>
     * </ol>
     *
     * <p>All arguments are mutable buffers/counters owned by the caller; see the class-level note
     * about how the parameter names were derived.
     */
    public final void org$apache$spark$ml$odkl$DSVRGD$$anonfun$$minibatchStep$1(IntRef stepCount, DenseMatrix weights, DenseVector loss, DenseVector lossSum, DenseVector lossSumCompensation, DenseVector vectorScratchA, DenseVector vectorScratchB, DenseMatrix gradient, DenseMatrix gradientSum, DenseMatrix gradientSumCompensation, DenseMatrix weightsSum, DenseMatrix weightsSumCompensation, DenseMatrix matrixScratchA, DenseMatrix matrixScratchB, int skipRegularizationIndex, DenseMatrix avgGradient, DenseMatrix avgWeights, double[] rowsByBatchA, double[] colsByBatch, double[] rowsByBatchB, IntRef batchColumns, IntRef auxCounterA, IntRef auxCounterB, DenseMatrix rates) {
        final int rows = weights.numRows();
        final int columnsInBatch = batchColumns.elem;

        // Matrix views over the flat batch buffers; construction order matches the original.
        final DenseMatrix colsView = new DenseMatrix(weights.numCols(), columnsInBatch, colsByBatch);
        final DenseMatrix rowsViewB = new DenseMatrix(rows, columnsInBatch, rowsByBatchB);
        final DenseMatrix rowsViewA = new DenseMatrix(rows, columnsInBatch, rowsByBatchA);

        this.$outer.fullGradientAndLoss(this.l1regParam$1, this.l2regParam$1, weights, rowsViewA, loss, gradient, skipRegularizationIndex, colsView, rowsViewB);

        // Note the scratch order: the "B" scratch buffer is passed before the "A" one in every call.
        this.$outer.axpyCompensated(loss.values(), lossSum.values(), lossSumCompensation.values(), vectorScratchB.values(), vectorScratchA.values());
        this.$outer.axpyCompensated(gradient.values(), gradientSum.values(), gradientSumCompensation.values(), matrixScratchB.values(), matrixScratchA.values());

        this.$outer.adjust(-1, rates, gradient, weights);

        if (this.stepNum$1 > 1) {
            // Re-evaluates into the same loss/gradient buffers, this time at avgWeights.
            this.$outer.fullGradientAndLoss(this.l1regParam$1, this.l2regParam$1, avgWeights, rowsViewA, loss, gradient, skipRegularizationIndex, colsView, rowsViewB);
            this.$outer.adjust(1, rates, gradient, weights);
            this.$outer.adjust(-1, rates, avgGradient, weights);
        }

        this.$outer.axpyCompensated(weights.values(), weightsSum.values(), weightsSumCompensation.values(), matrixScratchB.values(), matrixScratchA.values());

        // One more step done; start the next batch from empty counters.
        stepCount.elem++;
        batchColumns.elem = 0;
        auxCounterA.elem = 0;
        auxCounterB.elem = 0;
    }

    /**
     * Captures the enclosing instance and the closed-over values.
     *
     * @param dsvrgd enclosing {@code DSVRGD} instance; must not be {@code null}
     * @param weights stored as {@code weights$2}
     * @param avgWeights stored as {@code avgWeights$1}
     * @param avgGradient stored as {@code avgGradient$1}
     * @param l1regParam stored as {@code l1regParam$1}
     * @param l2regParam stored as {@code l2regParam$1}
     * @param stepNum stored as {@code stepNum$1}
     * @param labelLearningRates stored as {@code labelLearningRates$1}
     * @param doRegularizeLast stored as {@code doRegularizeLast$1}
     * @param batchSize stored as {@code batchSize$1}
     * @throws NullPointerException if {@code dsvrgd} is {@code null}
     */
    public DSVRGD$$anonfun$16(DSVRGD dsvrgd, Broadcast weights, Broadcast avgWeights, Broadcast avgGradient, Vector l1regParam, Vector l2regParam, int stepNum, DenseVector labelLearningRates, boolean doRegularizeLast, int batchSize) {
        if (dsvrgd == null) {
            throw new NullPointerException();
        }
        this.$outer = dsvrgd;
        this.weights$2 = weights;
        this.avgWeights$1 = avgWeights;
        this.avgGradient$1 = avgGradient;
        this.l1regParam$1 = l1regParam;
        this.l2regParam$1 = l2regParam;
        this.stepNum$1 = stepNum;
        this.labelLearningRates$1 = labelLearningRates;
        this.doRegularizeLast$1 = doRegularizeLast;
        this.batchSize$1 = batchSize;
    }
}
