package org.apache.spark.ml.odkl;

import com.github.fommil.netlib.BLAS;
import org.apache.spark.ml.odkl.HasNetlibBlas;
import org.apache.spark.ml.odkl.ModelWithSummary;
import org.apache.spark.mllib.linalg.BLAS$;
import org.apache.spark.mllib.linalg.DenseMatrix;
import org.apache.spark.mllib.linalg.DenseMatrix$;
import org.apache.spark.mllib.linalg.DenseVector;
import org.apache.spark.mllib.linalg.Matrix;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.odkl.MatrixUtils$;
import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer;
import org.apache.spark.rdd.RDD;
import org.apache.spark.sql.DataFrame;
import scala.Predef$;
import scala.Serializable;
import scala.reflect.ClassTag$;
import scala.runtime.DoubleRef;
import scala.runtime.RichInt$;

/* compiled from: DSVRGD.scala */
/* loaded from: input_file:org/apache/spark/ml/odkl/DSVRGD$.class */
/**
 * Singleton companion object of {@code DSVRGD} (decompiled from {@code DSVRGD.scala}).
 *
 * <p>Exposes the summary-block keys for the training histories, per-batch evaluation helpers for a
 * linear and a logistic objective, a weight-initialization helper for the logistic case, and two
 * weight-distance measures. The BLAS helpers ({@code f2jBLAS}, {@code blas}, {@code dscal},
 * {@code axpy}, {@code copy}) all delegate to the {@link HasNetlibBlas} trait implementation.
 */
public final class DSVRGD$ implements Serializable, HasNetlibBlas {
    /**
     * The singleton instance.
     *
     * <p>Assigned exactly once, here. A {@code static final} field cannot legally be assigned from
     * an instance constructor, so the instance is created in the field initializer. The reference
     * is published only after the constructor completes, so it is still {@code null} while the
     * constructor runs.
     */
    public static final DSVRGD$ MODULE$ = new DSVRGD$();

    /** Summary block named {@code "lossHistory"}. */
    private final ModelWithSummary.Block LossHistory;
    /** Summary block named {@code "weightDiffHistory"}. */
    private final ModelWithSummary.Block WeightDiffHistory;
    /** Summary block named {@code "weightNormHistory"}. */
    private final ModelWithSummary.Block WeightNormHistory;

    @Override // org.apache.spark.ml.odkl.HasNetlibBlas
    public BLAS f2jBLAS() {
        return HasNetlibBlas.Cclass.f2jBLAS(this);
    }

    @Override // org.apache.spark.ml.odkl.HasNetlibBlas
    public BLAS blas() {
        return HasNetlibBlas.Cclass.blas(this);
    }

    @Override // org.apache.spark.ml.odkl.HasNetlibBlas
    public void dscal(double d, double[] dArr) {
        HasNetlibBlas.Cclass.dscal(this, d, dArr);
    }

    @Override // org.apache.spark.ml.odkl.HasNetlibBlas
    public void axpy(double d, double[] dArr, double[] dArr2) {
        HasNetlibBlas.Cclass.axpy(this, d, dArr, dArr2);
    }

    @Override // org.apache.spark.ml.odkl.HasNetlibBlas
    public void axpy(double d, Vector vector, double[] dArr) {
        HasNetlibBlas.Cclass.axpy(this, d, vector, dArr);
    }

    @Override // org.apache.spark.ml.odkl.HasNetlibBlas
    public void copy(double[] dArr, double[] dArr2) {
        HasNetlibBlas.Cclass.copy(this, dArr, dArr2);
    }

    /** @return the summary block keyed {@code "lossHistory"} */
    public ModelWithSummary.Block LossHistory() {
        return this.LossHistory;
    }

    /** @return the summary block keyed {@code "weightDiffHistory"} */
    public ModelWithSummary.Block WeightDiffHistory() {
        return this.WeightDiffHistory;
    }

    /** @return the summary block keyed {@code "weightNormHistory"} */
    public ModelWithSummary.Block WeightNormHistory() {
        return this.WeightNormHistory;
    }

    /**
     * Evaluates the linear objective on one batch, writing into caller-supplied buffers.
     *
     * <p>Steps performed:
     * <ol>
     *   <li>{@code denseMatrix4 := matrix * denseMatrix} (gemm with beta = 0, so previous contents
     *       are overwritten);</li>
     *   <li>{@code denseMatrix4 := denseMatrix4 - denseMatrix2} (BLAS axpy with a = -1);</li>
     *   <li>{@code denseMatrix3 := (1 / n) * denseMatrix4 * transpose(denseMatrix)}, where
     *       {@code n = denseMatrix.numCols()};</li>
     *   <li>every active entry of {@code denseMatrix4} is handed to a separately compiled Scala
     *       closure that captures {@code denseVector} and {@code 1 / n}. Presumably it accumulates
     *       the batch loss into {@code denseVector}; TODO confirm.</li>
     * </ol>
     * If {@code matrix} holds weights, the columns of {@code denseMatrix} are samples and
     * {@code denseMatrix2} holds targets, then step 3 is the mean squared-error gradient. That
     * reading is an assumption; verify it against the callers.
     *
     * <p>NOTE(review): if {@code denseMatrix} has zero columns, the scale factor {@code 1 / n} is
     * infinite. Callers presumably never pass an empty batch.
     *
     * @param matrix       left operand of the product (presumably the weights)
     * @param denseMatrix  right operand of the product (presumably the batch features)
     * @param denseMatrix2 matrix subtracted from the product (presumably the targets)
     * @param denseMatrix3 output buffer; overwritten with the scaled product from step 3
     * @param denseMatrix4 scratch buffer; overwritten with the residual
     * @param denseVector  accumulator passed to the per-entry closure
     */
    public void linear(Matrix matrix, DenseMatrix denseMatrix, DenseMatrix denseMatrix2, DenseMatrix denseMatrix3, DenseMatrix denseMatrix4, DenseVector denseVector) {
        BLAS$.MODULE$.gemm(1.0d, matrix, denseMatrix, 0.0d, denseMatrix4);
        axpy(-1.0d, denseMatrix2.values(), denseMatrix4.values());
        double invBatchSize = 1.0d / denseMatrix.numCols();
        BLAS$.MODULE$.gemm(invBatchSize, denseMatrix4, denseMatrix.transpose(), 0.0d, denseMatrix3);
        denseMatrix4.foreachActive(new DSVRGD$$anonfun$linear$1(denseVector, invBatchSize));
    }

    /**
     * Evaluates the logistic objective on one batch, writing into caller-supplied buffers.
     *
     * <p>Steps performed:
     * <ol>
     *   <li>{@code denseMatrix4 := -(matrix * denseMatrix)} (gemm with alpha = -1, beta = 0);</li>
     *   <li>{@code MatrixUtils.applyNonZeros(denseMatrix2, denseMatrix4, f)} runs a separately
     *       compiled closure {@code f} that captures {@code denseVector} and {@code 1 / n}, with
     *       {@code n = denseMatrix.numCols()}. Presumably it rewrites {@code denseMatrix4} in place
     *       using the labels in {@code denseMatrix2} and accumulates the loss into
     *       {@code denseVector}; TODO confirm;</li>
     *   <li>{@code denseMatrix3 := (1 / n) * denseMatrix4 * transpose(denseMatrix)}. This uses
     *       {@code denseMatrix4} as left after step 2.</li>
     * </ol>
     *
     * <p>NOTE(review): if {@code denseMatrix} has zero columns, the scale factor {@code 1 / n} is
     * infinite.
     *
     * @param matrix       left operand of the product (presumably the weights)
     * @param denseMatrix  right operand of the product (presumably the batch features)
     * @param denseMatrix2 first argument to {@code applyNonZeros} (presumably the labels)
     * @param denseMatrix3 output buffer; overwritten in step 3
     * @param denseMatrix4 scratch buffer; overwritten in step 1 and passed through step 2
     * @param denseVector  accumulator captured by the step-2 closure
     */
    public void logistic(Matrix matrix, DenseMatrix denseMatrix, DenseMatrix denseMatrix2, DenseMatrix denseMatrix3, DenseMatrix denseMatrix4, DenseVector denseVector) {
        BLAS$.MODULE$.gemm(-1.0d, matrix, denseMatrix, 0.0d, denseMatrix4);
        double invBatchSize = 1.0d / denseMatrix.numCols();
        MatrixUtils$.MODULE$.applyNonZeros(denseMatrix2, denseMatrix4, new DSVRGD$$anonfun$logistic$1(denseVector, invBatchSize));
        BLAS$.MODULE$.gemm(invBatchSize, denseMatrix4, denseMatrix.transpose(), 0.0d, denseMatrix3);
    }

    /**
     * Builds an initial {@code i x i2} matrix from column statistics of {@code dataFrame}.
     *
     * <p>Each row is mapped to a {@link Vector} by a separately compiled closure. The vectors are
     * folded into a {@link MultivariateOnlineSummarizer} with {@code treeAggregate} at the RDD's
     * default depth. A {@code zeros(i, i2)} matrix is then transformed by a closure that captures
     * {@code i2} and the aggregated summarizer. Which statistics are used, and how, is defined in
     * that closure and is not visible here.
     *
     * @param dataFrame source data; triggers one distributed aggregation job
     * @param i         number of rows of the returned matrix
     * @param i2        number of columns of the returned matrix (also passed to the closure)
     * @return the transformed {@code i x i2} dense matrix
     */
    public Matrix logisticInitialization(DataFrame dataFrame, int i, int i2) {
        RDD vectors = dataFrame.map(new DSVRGD$$anonfun$18(), ClassTag$.MODULE$.apply(Vector.class));
        MultivariateOnlineSummarizer zero = new MultivariateOnlineSummarizer();
        MultivariateOnlineSummarizer summary = (MultivariateOnlineSummarizer) vectors.treeAggregate(
                zero,
                new DSVRGD$$anonfun$19(),
                new DSVRGD$$anonfun$20(),
                vectors.treeAggregate$default$4(zero),
                ClassTag$.MODULE$.apply(MultivariateOnlineSummarizer.class));
        return MatrixUtils$.MODULE$.transformDense(
                DenseMatrix$.MODULE$.zeros(i, i2),
                new DSVRGD$$anonfun$logisticInitialization$1(i2, summary));
    }

    /**
     * Returns {@code sqrt(a) / sqrt(b)}. The accumulators {@code a} and {@code b} are filled by a
     * separately compiled closure that is invoked once per column index
     * {@code 0 .. denseMatrix.numCols() - 1} with {@code matrix}, {@code denseMatrix} and
     * {@code i}. Presumably this is a relative weight-change norm; TODO confirm.
     *
     * <p>NOTE(review): unlike {@link #logisticWeightsDistance}, there is no guard for {@code b == 0},
     * which yields {@code Infinity} or {@code NaN}.
     *
     * @param matrix      first weight matrix passed to the closure
     * @param denseMatrix second weight matrix; its column count drives the iteration
     * @param i           extra index passed to the closure
     * @return {@code sqrt(a) / sqrt(b)}
     */
    public double linearWeightsDistance(Matrix matrix, DenseMatrix denseMatrix, int i) {
        DoubleRef numerator = new DoubleRef(0.0d);
        DoubleRef denominator = new DoubleRef(0.0d);
        RichInt$.MODULE$.until$extension0(Predef$.MODULE$.intWrapper(0), denseMatrix.numCols())
                .foreach$mVc$sp(new DSVRGD$$anonfun$linearWeightsDistance$1(matrix, denseMatrix, i, numerator, denominator));
        return Math.sqrt(numerator.elem) / Math.sqrt(denominator.elem);
    }

    /**
     * Returns {@code 1 - a / sqrt(b * c)} when {@code b * c > 0}, otherwise {@code 2.0}. The
     * accumulators {@code a}, {@code b} and {@code c} are filled by a separately compiled closure
     * that is invoked once per column index {@code 0 .. denseMatrix.numCols() - 1}. If {@code a}
     * is a dot product and {@code b} and {@code c} are squared norms, this is the cosine distance
     * and {@code 2.0} is its maximum. That reading is an assumption; TODO confirm.
     *
     * @param matrix      first weight matrix passed to the closure
     * @param denseMatrix second weight matrix; its column count drives the iteration
     * @param i           extra index passed to the closure
     * @return the distance as described above
     */
    public double logisticWeightsDistance(Matrix matrix, DenseMatrix denseMatrix, int i) {
        DoubleRef cross = new DoubleRef(0.0d);
        DoubleRef firstSq = new DoubleRef(0.0d);
        DoubleRef secondSq = new DoubleRef(0.0d);
        RichInt$.MODULE$.until$extension0(Predef$.MODULE$.intWrapper(0), denseMatrix.numCols())
                .foreach$mVc$sp(new DSVRGD$$anonfun$logisticWeightsDistance$1(matrix, denseMatrix, i, cross, firstSq, secondSq));
        double normProduct = firstSq.elem * secondSq.elem;
        if (normProduct > 0.0d) {
            return 1.0d - (cross.elem / Math.sqrt(normProduct));
        }
        // Degenerate (zero) accumulators: report the sentinel value 2.0.
        return 2.0d;
    }

    /** Preserves the singleton across Java serialization by returning the canonical instance. */
    private Object readResolve() {
        return MODULE$;
    }

    private DSVRGD$() {
        HasNetlibBlas.Cclass.$init$(this);
        this.LossHistory = new ModelWithSummary.Block("lossHistory");
        this.WeightDiffHistory = new ModelWithSummary.Block("weightDiffHistory");
        this.WeightNormHistory = new ModelWithSummary.Block("weightNormHistory");
    }
}
