package org.apache.spark.ml.odkl;

import com.github.fommil.netlib.BLAS;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.internal.Logging;
import org.apache.spark.ml.attribute.Attribute;
import org.apache.spark.ml.attribute.AttributeGroup;
import org.apache.spark.ml.attribute.AttributeGroup$;
import org.apache.spark.ml.linalg.DenseMatrix;
import org.apache.spark.ml.linalg.DenseVector;
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.ml.linalg.Vectors$;
import org.apache.spark.ml.odkl.HasNetlibBlas;
import org.apache.spark.ml.odkl.MatrixLBFGS;
import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer;
import org.apache.spark.rdd.RDD;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.slf4j.Logger;
import scala.Array$;
import scala.Function0;
import scala.Function3;
import scala.MatchError;
import scala.Predef$;
import scala.Serializable;
import scala.Tuple2;
import scala.collection.immutable.Map;
import scala.collection.parallel.ParIterableLike;
import scala.collection.parallel.ThreadPoolTaskSupport;
import scala.collection.parallel.mutable.ParArray;
import scala.collection.parallel.mutable.ParArray$;
import scala.reflect.ClassTag$;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;
import scala.runtime.IntRef;
import scala.runtime.RichInt$;

/* compiled from: MatrixLBFGS.scala */
/* loaded from: input_file:org/apache/spark/ml/odkl/MatrixLBFGS$.class */
/**
 * Singleton ("companion object") of {@code MatrixLBFGS}, in the form emitted by the Scala compiler.
 *
 * <p>Provides:
 * <ul>
 *   <li>BLAS-backed gradient helpers ({@link #computeGradient}, {@link #computeGradientMatrix});</li>
 *   <li>a Spark-distributed gradient/loss aggregation ({@link #computeGradientAndLoss});</li>
 *   <li>a driver ({@link #multiClassLBFGS}) that creates one {@code MatrixLBFGS.LbfgsState} per
 *       attribute of the second column's attribute group and advances them concurrently on a
 *       dedicated thread pool.</li>
 * </ul>
 *
 * <p>Logging and BLAS access are delegated to the {@code Logging} and {@code HasNetlibBlas} trait
 * implementations.
 *
 * <p>NOTE(review): this source was produced by a decompiler. The {@code MatrixLBFGS$$anonfun$*}
 * classes it instantiates are compiler-generated closures defined elsewhere, so their exact
 * behaviour is not visible here. Comments describing them are hedged accordingly.
 */
public final class MatrixLBFGS$ implements Logging, HasNetlibBlas, Serializable {
    /** The singleton instance; assigned by the private constructor, which the static initializer runs. */
    public static final MatrixLBFGS$ MODULE$ = null;
    /** Backing field for the lazily-initialized logger of the {@code Logging} trait. */
    private transient Logger org$apache$spark$internal$Logging$$log_;

    static {
        new MatrixLBFGS$();
    }

    // ---- HasNetlibBlas trait forwarders ----

    @Override // org.apache.spark.ml.odkl.HasNetlibBlas
    public BLAS f2jBLAS() {
        return HasNetlibBlas.Cclass.f2jBLAS(this);
    }

    @Override // org.apache.spark.ml.odkl.HasNetlibBlas
    public BLAS blas() {
        return HasNetlibBlas.Cclass.blas(this);
    }

    @Override // org.apache.spark.ml.odkl.HasNetlibBlas
    public void dscal(double d, double[] dArr) {
        HasNetlibBlas.Cclass.dscal(this, d, dArr);
    }

    @Override // org.apache.spark.ml.odkl.HasNetlibBlas
    public void axpy(double d, double[] dArr, double[] dArr2) {
        HasNetlibBlas.Cclass.axpy(this, d, dArr, dArr2);
    }

    @Override // org.apache.spark.ml.odkl.HasNetlibBlas
    public void axpy(double d, Vector vector, double[] dArr) {
        HasNetlibBlas.Cclass.axpy(this, d, vector, dArr);
    }

    @Override // org.apache.spark.ml.odkl.HasNetlibBlas
    public void copy(double[] dArr, double[] dArr2) {
        HasNetlibBlas.Cclass.copy(this, dArr, dArr2);
    }

    // ---- Logging trait accessors and forwarders ----

    public Logger org$apache$spark$internal$Logging$$log_() {
        return this.org$apache$spark$internal$Logging$$log_;
    }

    public void org$apache$spark$internal$Logging$$log__$eq(Logger logger) {
        this.org$apache$spark$internal$Logging$$log_ = logger;
    }

    public String logName() {
        return Logging.class.logName(this);
    }

    public Logger log() {
        return Logging.class.log(this);
    }

    public void logInfo(Function0<String> function0) {
        Logging.class.logInfo(this, function0);
    }

    public void logDebug(Function0<String> function0) {
        Logging.class.logDebug(this, function0);
    }

    public void logTrace(Function0<String> function0) {
        Logging.class.logTrace(this, function0);
    }

    public void logWarning(Function0<String> function0) {
        Logging.class.logWarning(this, function0);
    }

    public void logError(Function0<String> function0) {
        Logging.class.logError(this, function0);
    }

    public void logInfo(Function0<String> function0, Throwable th) {
        Logging.class.logInfo(this, function0, th);
    }

    public void logDebug(Function0<String> function0, Throwable th) {
        Logging.class.logDebug(this, function0, th);
    }

    public void logTrace(Function0<String> function0, Throwable th) {
        Logging.class.logTrace(this, function0, th);
    }

    public void logWarning(Function0<String> function0, Throwable th) {
        Logging.class.logWarning(this, function0, th);
    }

    public void logError(Function0<String> function0, Throwable th) {
        Logging.class.logError(this, function0, th);
    }

    public boolean isTraceEnabled() {
        return Logging.class.isTraceEnabled(this);
    }

    public void initializeLogIfNecessary(boolean z) {
        Logging.class.initializeLogIfNecessary(this, z);
    }

    // ---- Gradient computation ----

    /**
     * Computes the gradient contribution of a single example.
     *
     * <p>This is {@link #computeGradientMatrix} with a batch of one example. It allocates a
     * fresh scratch buffer of {@code vector2.size()} doubles, which must therefore equal
     * {@code denseMatrix.numRows()}.
     *
     * @param vector       right-hand operand of {@code W * x} (presumably the feature vector)
     * @param vector2      forwarded to the per-entry closure (presumably the target vector)
     * @param denseMatrix  weight matrix {@code W}
     * @param denseMatrix2 accumulator updated in place with {@code += r * x^T}
     * @param denseVector  forwarded to the per-entry closure (presumably a loss accumulator)
     */
    public void computeGradient(Vector vector, Vector vector2, DenseMatrix denseMatrix, DenseMatrix denseMatrix2, DenseVector denseVector) {
        computeGradientMatrix(vector.toArray(), vector2.toArray(), denseMatrix, denseMatrix2, denseVector, new double[vector2.size()], 1);
    }

    /**
     * Accumulates the gradient for a batch of {@code i} examples into {@code denseMatrix2}.
     *
     * <p>The computation has three steps:
     * <ol>
     *   <li>{@code dArr3 = -(W * X)}, where {@code W} is {@code denseMatrix} (numRows x numCols)
     *       and {@code X} is {@code dArr} read column-major as numCols x i (one contiguous column
     *       per example).</li>
     *   <li>A closure visits each of the {@code numRows * i} entries of {@code dArr3} together with
     *       {@code dArr2} and {@code denseVector}. Presumably it turns margins into residuals and
     *       accumulates loss; TODO confirm in {@code MatrixLBFGS$$anonfun$computeGradientMatrix$1}.</li>
     *   <li>{@code G += dArr3 * X^T}, where {@code G} is {@code denseMatrix2.values()} written
     *       column-major with leading dimension numRows.</li>
     * </ol>
     *
     * <p>NOTE(review): step 3 ignores {@code denseMatrix2.isTransposed()}, so it assumes the
     * accumulator is stored column-major. Confirm with callers.
     *
     * @param dArr         batch of inputs, at least numCols * i doubles, column-major
     * @param dArr2        forwarded to the per-entry closure (presumably targets)
     * @param denseMatrix  weight matrix; its transposed flag is honoured
     * @param denseMatrix2 accumulator, updated in place (beta = 1)
     * @param denseVector  forwarded to the per-entry closure
     * @param dArr3        scratch buffer of at least numRows * i doubles; overwritten (beta = 0)
     * @param i            number of examples in the batch
     */
    public void computeGradientMatrix(double[] dArr, double[] dArr2, DenseMatrix denseMatrix, DenseMatrix denseMatrix2, DenseVector denseVector, double[] dArr3, int i) {
        int numRows = denseMatrix.numRows();
        int numCols = denseMatrix.numCols();
        org$apache$spark$ml$odkl$MatrixLBFGS$$gemm(-1.0d, 0.0d, denseMatrix.values(), dArr, dArr3, denseMatrix.numRows(), denseMatrix.numCols(), numCols, i, denseMatrix.isTransposed(), false);
        RichInt$.MODULE$.until$extension0(Predef$.MODULE$.intWrapper(0), numRows * i).foreach$mVc$sp(new MatrixLBFGS$$anonfun$computeGradientMatrix$1(dArr2, denseVector, dArr3, numRows));
        org$apache$spark$ml$odkl$MatrixLBFGS$$gemm(1.0d, 1.0d, dArr3, dArr, denseMatrix2.values(), numRows, i, i, numCols, false, true);
    }

    /**
     * Distributes a gradient/loss computation over an RDD and reduces the per-partition results.
     *
     * <p>{@code denseMatrix} is broadcast to the executors. Each partition is processed by a
     * {@code mapPartitions} closure that receives the broadcast, {@code i}, {@code function3} and
     * the matrix dimensions. The resulting tuples are combined with {@code treeReduce}. The
     * broadcast is destroyed in a {@code finally} block, so it is released even if the job fails.
     *
     * @param rdd         input examples as {@code (Vector, T)} pairs
     * @param denseMatrix current weight matrix; broadcast to executors
     * @param i           forwarded to the per-partition closure; defaults to 10 (presumably a batch size)
     * @param function3   forwarded to the per-partition closure (presumably fills per-example targets)
     * @return the reduced {@code (DenseMatrix, DenseVector)} pair, presumably (gradient, loss)
     */
    public <T> Tuple2<DenseMatrix, DenseVector> computeGradientAndLoss(RDD<Tuple2<Vector, T>> rdd, DenseMatrix denseMatrix, int i, Function3<Object, T, double[], BoxedUnit> function3) {
        Broadcast broadcast = rdd.sparkContext().broadcast(denseMatrix, ClassTag$.MODULE$.apply(DenseMatrix.class));
        try {
            RDD mapPartitions = rdd.mapPartitions(new MatrixLBFGS$$anonfun$5(denseMatrix, i, function3, broadcast, denseMatrix.numRows(), denseMatrix.numCols()), rdd.mapPartitions$default$2(), ClassTag$.MODULE$.apply(Tuple2.class));
            return (Tuple2) mapPartitions.treeReduce(new MatrixLBFGS$$anonfun$computeGradientAndLoss$1(), mapPartitions.treeReduce$default$2());
        } finally {
            broadcast.destroy();
        }
    }

    /** Default value (10) of the third parameter of {@link #computeGradientAndLoss}. */
    public <T> int computeGradientAndLoss$default$3() {
        return 10;
    }

    /**
     * Trains one L-BFGS state per attribute of column {@code str2} and returns their results by key.
     *
     * <p>Steps:
     * <ol>
     *   <li>If {@code d2 > 0}, {@link #evaluateMaxRegularization(Dataset, String, String, boolean)}
     *       provides the row count and a per-attribute vector. Otherwise the row count comes from
     *       {@code dataset.count()} and the vector is all zeros.</li>
     *   <li>A shared {@code BatchCostFunction} is built over the {@code (str, str2)} rows.</li>
     *   <li>A fixed thread pool with one thread per attribute of {@code str2} is created. The
     *       states are initialized through a parallel map on that pool.</li>
     *   <li>The loop runs in rounds. Each round resets the cost function and hands every state to
     *       a closure together with the pool, a counter and a latch. It then waits on the latch.
     *       Rounds repeat while the counter is positive after the latch opens (presumably the
     *       number of states that made progress).</li>
     * </ol>
     *
     * <p>The pool is shut down on every exit path, including failures.
     *
     * <p>NOTE(review): {@code AttributeGroup.size()} may be -1 when the size is unknown. The pool
     * constructor then throws {@code IllegalArgumentException}.
     *
     * @param dataset input data
     * @param str     column whose attribute-group size is used as the first dimension of the
     *                cost function (presumably features)
     * @param str2    column whose attributes each get their own optimizer (presumably labels)
     * @param i       forwarded to the state-initialization closure
     * @param d       forwarded to the state-initialization closure
     * @param i2      forwarded to the state-initialization closure
     * @param i3      forwarded to {@code BatchCostFunction} (presumably a batch size)
     * @param d2      forwarded to the state-initialization closure; a positive value also enables
     *                evaluation of the per-attribute maximum regularization. Default 0.0.
     * @param z       forwarded to the regularization evaluation and the state closure; default true
     * @return a map built from {@code (key, Vector)} pairs, one per state
     */
    public Map<String, Vector> multiClassLBFGS(Dataset<Row> dataset, String str, String str2, int i, double d, int i2, int i3, double d2, boolean z) {
        AttributeGroup fromStructField = AttributeGroup$.MODULE$.fromStructField(dataset.schema().apply(str2));
        AttributeGroup fromStructField2 = AttributeGroup$.MODULE$.fromStructField(dataset.schema().apply(str));
        // Max-regularization statistics are only needed when d2 > 0; otherwise only the row count is required.
        Tuple2<Object, Vector> evaluateMaxRegularization = d2 > ((double) 0) ? evaluateMaxRegularization(dataset, str, str2, z) : new Tuple2<>(BoxesRunTime.boxToLong(dataset.count()), Vectors$.MODULE$.zeros(fromStructField.size()));
        if (evaluateMaxRegularization == null) {
            throw new MatchError(evaluateMaxRegularization);
        }
        Tuple2 tuple2 = new Tuple2(BoxesRunTime.boxToLong(evaluateMaxRegularization._1$mcJ$sp()), (Vector) evaluateMaxRegularization._2());
        long _1$mcJ$sp = tuple2._1$mcJ$sp();
        Vector vector = (Vector) tuple2._2();
        logInfo(new MatrixLBFGS$$anonfun$multiClassLBFGS$1(fromStructField, vector));
        MatrixLBFGS.BatchCostFunction batchCostFunction = new MatrixLBFGS.BatchCostFunction(dataset.select(str, Predef$.MODULE$.wrapRefArray(new String[]{str2})).rdd().map(new MatrixLBFGS$$anonfun$6(), ClassTag$.MODULE$.apply(Tuple2.class)), fromStructField2.size(), fromStructField.size(), _1$mcJ$sp, i3);
        Attribute[] attributeArr = (Attribute[]) fromStructField.attributes().getOrElse(new MatrixLBFGS$$anonfun$7(fromStructField));
        batchCostFunction.reset();
        ParArray par = Predef$.MODULE$.refArrayOps(attributeArr).par();
        int numLabels = fromStructField.size();
        // One thread per attribute, with room to queue one task per attribute.
        ThreadPoolTaskSupport threadPoolTaskSupport = new ThreadPoolTaskSupport(new ThreadPoolExecutor(numLabels, numLabels, Long.MAX_VALUE, TimeUnit.DAYS, new ArrayBlockingQueue(numLabels)));
        MatrixLBFGS.LbfgsState[] lbfgsStateArr;
        // The pool's core threads are prestarted, come from the default (non-daemon) thread factory
        // and never time out. It must therefore be shut down on every exit path, including
        // exceptions and an interrupted await(). Otherwise its threads leak and can keep the JVM alive.
        try {
            threadPoolTaskSupport.environment().prestartAllCoreThreads();
            par.tasksupport_$eq(threadPoolTaskSupport);
            logInfo(new MatrixLBFGS$$anonfun$multiClassLBFGS$2());
            lbfgsStateArr = (MatrixLBFGS.LbfgsState[]) ((ParIterableLike) par.map(new MatrixLBFGS$$anonfun$8(i, d, i2, d2, z, fromStructField2, vector, batchCostFunction), ParArray$.MODULE$.canBuildFrom())).toArray(ClassTag$.MODULE$.apply(MatrixLBFGS.LbfgsState.class));
            logInfo(new MatrixLBFGS$$anonfun$multiClassLBFGS$3());
            // Iteration number, boxed in an IntRef because the logging closures capture it.
            IntRef intRef = new IntRef(1);
            // Filled by the per-state closures during a round; a positive value triggers another round.
            AtomicInteger atomicInteger = new AtomicInteger(0);
            do {
                logInfo(new MatrixLBFGS$$anonfun$multiClassLBFGS$4(intRef));
                batchCostFunction.reset();
                atomicInteger.set(0);
                // Presumably counted down once per state by the closure below; await() blocks until the round finishes.
                CountDownLatch countDownLatch = new CountDownLatch(numLabels);
                Predef$.MODULE$.refArrayOps(lbfgsStateArr).foreach(new MatrixLBFGS$$anonfun$multiClassLBFGS$5(batchCostFunction, threadPoolTaskSupport, atomicInteger, countDownLatch));
                countDownLatch.await();
                logInfo(new MatrixLBFGS$$anonfun$multiClassLBFGS$6(intRef, atomicInteger));
                intRef.elem++;
            } while (atomicInteger.get() > 0);
        } finally {
            threadPoolTaskSupport.environment().shutdown();
        }
        return Predef$.MODULE$.refArrayOps((Object[]) Predef$.MODULE$.refArrayOps(lbfgsStateArr).map(new MatrixLBFGS$$anonfun$multiClassLBFGS$7(attributeArr), Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(Tuple2.class)))).toMap(Predef$.MODULE$.conforms());
    }

    /** Default value (0.0) of the eighth parameter ({@code d2}) of {@link #multiClassLBFGS}. */
    public double multiClassLBFGS$default$8() {
        return 0.0d;
    }

    /** Default value ({@code true}) of the ninth parameter ({@code z}) of {@link #multiClassLBFGS}. */
    public boolean multiClassLBFGS$default$9() {
        return true;
    }

    /**
     * Computes the row count and a per-attribute regularization vector from {@code dataset}.
     *
     * <p>The {@code (str, str2)} columns are mapped to vector pairs. A vector extracted from each
     * pair by a closure is summarized with {@code MultivariateOnlineSummarizer}: per-partition
     * summarizers are merged with {@code treeReduce}. The summary's count and mean then feed
     * {@link #evaluateMaxRegularization(RDD, boolean, int, DenseVector, long)}.
     *
     * @param dataset input data
     * @param str     column whose attribute-group size is passed on as {@code i}
     * @param str2    second selected column
     * @param z       forwarded unchanged
     * @return {@code (count, vector)} where count is the summarizer's row count
     */
    public Tuple2<Object, Vector> evaluateMaxRegularization(Dataset<Row> dataset, String str, String str2, boolean z) {
        // NOTE(review): the result of this lookup is unused. It only resolves str2 in the schema,
        // which throws if the column is absent.
        AttributeGroup$.MODULE$.fromStructField(dataset.schema().apply(str2));
        AttributeGroup fromStructField = AttributeGroup$.MODULE$.fromStructField(dataset.schema().apply(str));
        RDD<Tuple2<Vector, Vector>> map = dataset.toDF().select(str, Predef$.MODULE$.wrapRefArray(new String[]{str2})).rdd().map(new MatrixLBFGS$$anonfun$9(), ClassTag$.MODULE$.apply(Tuple2.class));
        RDD map2 = map.map(new MatrixLBFGS$$anonfun$10(), ClassTag$.MODULE$.apply(Vector.class));
        RDD mapPartitions = map2.mapPartitions(new MatrixLBFGS$$anonfun$11(), map2.mapPartitions$default$2(), ClassTag$.MODULE$.apply(MultivariateOnlineSummarizer.class));
        MultivariateOnlineSummarizer multivariateOnlineSummarizer = (MultivariateOnlineSummarizer) mapPartitions.treeReduce(new MatrixLBFGS$$anonfun$12(), mapPartitions.treeReduce$default$2());
        return new Tuple2<>(BoxesRunTime.boxToLong(multivariateOnlineSummarizer.count()), evaluateMaxRegularization(map, z, fromStructField.size(), multivariateOnlineSummarizer.mean().asML().toDense(), multivariateOnlineSummarizer.count()));
    }

    /**
     * Builds the per-attribute regularization vector from pre-computed mean statistics.
     *
     * <p>Each partition produces a {@code DenseMatrix} via a closure that receives {@code i}, the
     * mean vector, its size and {@code 1.0 / j}. The matrices are combined with {@code treeReduce}.
     * The output vector has one entry per element of {@code denseVector}. Each entry is tabulated
     * from the reduced matrix and from {@code i} when {@code z} is true, or {@code i - 1} otherwise
     * (presumably to exclude a trailing intercept column; confirm).
     *
     * @param rdd         vector pairs
     * @param z           selects {@code i} versus {@code i - 1} for the tabulation closure
     * @param i           attribute-group size of the first selected column
     * @param denseVector mean vector; its size is the size of the result
     * @param j           row count; passed to the partition closure as its reciprocal
     * @return a dense vector of {@code denseVector.size()} entries
     */
    public Vector evaluateMaxRegularization(RDD<Tuple2<Vector, Vector>> rdd, boolean z, int i, DenseVector denseVector, long j) {
        int size = denseVector.size();
        RDD mapPartitions = rdd.mapPartitions(new MatrixLBFGS$$anonfun$13(i, denseVector, size, 1.0d / j), rdd.mapPartitions$default$2(), ClassTag$.MODULE$.apply(DenseMatrix.class));
        return Vectors$.MODULE$.dense((double[]) Array$.MODULE$.tabulate(size, new MatrixLBFGS$$anonfun$1((DenseMatrix) mapPartitions.treeReduce(new MatrixLBFGS$$anonfun$14(), mapPartitions.treeReduce$default$2()), z ? i : i - 1), ClassTag$.MODULE$.Double()));
    }

    /**
     * Thin wrapper over BLAS {@code dgemm} computing {@code C = d * op(A) * op(B) + d2 * C}.
     *
     * <p>All matrices are column-major. {@code C} uses leading dimension {@code i}, so it must be
     * an unpadded M x N buffer.
     *
     * @param d    alpha
     * @param d2   beta (0 overwrites {@code C}, 1 accumulates into it)
     * @param dArr  A; read transposed when {@code z} (leading dimension {@code i2} if transposed, else {@code i})
     * @param dArr2 B; read transposed when {@code z2} (leading dimension {@code i4} if transposed, else {@code i3})
     * @param dArr3 C; updated in place
     * @param i    M: rows of op(A) and of C
     * @param i2   K: columns of op(A), and rows of op(B)
     * @param i3   leading dimension of B when it is not transposed (normally equal to K)
     * @param i4   N: columns of op(B) and of C
     * @param z    whether A is transposed
     * @param z2   whether B is transposed
     */
    public void org$apache$spark$ml$odkl$MatrixLBFGS$$gemm(double d, double d2, double[] dArr, double[] dArr2, double[] dArr3, int i, int i2, int i3, int i4, boolean z, boolean z2) {
        blas().dgemm(z ? "T" : "N", z2 ? "T" : "N", i, i4, i2, d, dArr, z ? i2 : i, dArr2, z2 ? i4 : i3, d2, dArr3, i);
    }

    /** Preserves singleton identity across Java deserialization. */
    private Object readResolve() {
        return MODULE$;
    }

    /** Publishes the singleton and runs the trait initializers; invoked once from the static initializer. */
    private MatrixLBFGS$() {
        MODULE$ = this;
        Logging.class.$init$(this);
        HasNetlibBlas.Cclass.$init$(this);
    }
}
