package org.neo4j.gds.ml.models.logisticregression;

import java.util.function.LongUnaryOperator;
import org.neo4j.gds.core.utils.TerminationFlag;
import org.neo4j.gds.core.utils.mem.MemoryEstimation;
import org.neo4j.gds.core.utils.mem.MemoryEstimations;
import org.neo4j.gds.core.utils.mem.MemoryRange;
import org.neo4j.gds.core.utils.paged.HugeIntArray;
import org.neo4j.gds.core.utils.paged.ReadOnlyHugeLongArray;
import org.neo4j.gds.core.utils.progress.tasks.LogLevel;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.ml.core.batch.BatchQueue;
import org.neo4j.gds.ml.gradientdescent.Training;
import org.neo4j.gds.ml.models.ClassifierTrainer;
import org.neo4j.gds.ml.models.Features;

/* loaded from: input_file:org/neo4j/gds/ml/models/logisticregression/LogisticRegressionTrainer.class */
/**
 * Trains a {@link LogisticRegressionClassifier} by running {@link Training} over a
 * {@link LogisticRegressionObjective}, feeding it batches drawn from the training set.
 */
public final class LogisticRegressionTrainer implements ClassifierTrainer {
    private final LogisticRegressionTrainConfig trainConfig;
    private final int numberOfClasses;
    private final ProgressTracker progressTracker;
    private final TerminationFlag terminationFlag;
    private final boolean reduceClassCount;
    private final LogLevel messageLogLevel;
    private final int concurrency;

    /**
     * Estimates the memory needed to train a logistic regression model.
     *
     * <p>The estimate is made of the model data, the weight-update state of {@link Training}, and one
     * computation graph per concurrently processed batch.
     *
     * @param isReduced        whether the model data is built with a reduced class count
     * @param numberOfClasses  number of target classes
     * @param featureDimension range of possible feature dimensions
     * @param batchSize        number of training examples per batch
     * @param trainSetSize     maps the graph's node count to the number of training examples
     * @return the composed memory estimation
     */
    public static MemoryEstimation memoryEstimation(
        boolean isReduced,
        int numberOfClasses,
        MemoryRange featureDimension,
        int batchSize,
        LongUnaryOperator trainSetSize
    ) {
        return MemoryEstimations.builder("train logistic regression")
            .add("model data", LogisticRegressionData.memoryEstimation(isReduced, numberOfClasses, featureDimension))
            .add("update weights", Training.memoryEstimation(featureDimension, numberOfClasses))
            .perGraphDimension("computation graph", (graphDimensions, threadCount) -> featureDimension
                .apply(dim -> sizeInBytesOfComputationGraph(isReduced, batchSize, (int) dim, numberOfClasses))
                .times(maxConcurrentBatches(
                    threadCount,
                    trainSetSize.applyAsLong(graphDimensions.nodeCount()),
                    batchSize
                )))
            .build();
    }

    /**
     * Returns how many computation graphs the estimate accounts for: the number of batches
     * (rounded up, so a partial trailing batch still counts) capped by the concurrency.
     */
    private static int maxConcurrentBatches(int concurrency, long trainSetSize, int batchSize) {
        // Divide in floating point: with integer division the quotient is truncated before
        // Math.ceil can round it up, which yields 0 batches for training sets smaller than a batch.
        double numberOfBatches = Math.ceil(trainSetSize / (double) batchSize);
        return (int) Math.min(concurrency, numberOfBatches);
    }

    /**
     * Size in bytes of the computation graph for a single batch; delegates to
     * {@link LogisticRegressionObjective#sizeOfBatchInBytes}.
     */
    // Kept public: the decompiled class exposes it publicly, and narrowing it could break callers.
    public static long sizeInBytesOfComputationGraph(boolean isReduced, int batchSize, int featureDimension, int numberOfClasses) {
        return LogisticRegressionObjective.sizeOfBatchInBytes(isReduced, batchSize, featureDimension, numberOfClasses);
    }

    /**
     * @param concurrency      number of threads passed to {@link Training#train}
     * @param trainConfig      supplies penalty, focus weight, class weights and batch size
     * @param numberOfClasses  number of target classes
     * @param reduceClassCount whether to build the model data with a reduced class count
     * @param terminationFlag  forwarded to {@link Training}
     * @param progressTracker  forwarded to {@link Training}
     * @param messageLogLevel  log level forwarded to {@link Training}
     */
    public LogisticRegressionTrainer(
        int concurrency,
        LogisticRegressionTrainConfig trainConfig,
        int numberOfClasses,
        boolean reduceClassCount,
        TerminationFlag terminationFlag,
        ProgressTracker progressTracker,
        LogLevel messageLogLevel
    ) {
        this.concurrency = concurrency;
        this.trainConfig = trainConfig;
        this.numberOfClasses = numberOfClasses;
        this.progressTracker = progressTracker;
        this.terminationFlag = terminationFlag;
        this.reduceClassCount = reduceClassCount;
        this.messageLogLevel = messageLogLevel;
    }

    /**
     * Creates a fresh classifier and trains it in place on the given training set.
     *
     * @param features features of all examples
     * @param labels   class label per example
     * @param trainSet ids of the examples to train on; batches are drawn from this array
     * @return the trained classifier
     */
    @Override
    public LogisticRegressionClassifier train(Features features, HugeIntArray labels, ReadOnlyHugeLongArray trainSet) {
        int featureDimension = features.featureDimension();
        LogisticRegressionClassifier classifier = LogisticRegressionClassifier.from(
            reduceClassCount
                ? LogisticRegressionData.withReducedClassCount(featureDimension, numberOfClasses)
                : LogisticRegressionData.standard(featureDimension, numberOfClasses)
        );

        LogisticRegressionObjective objective = new LogisticRegressionObjective(
            classifier,
            trainConfig.penalty(),
            features,
            labels,
            trainConfig.focusWeight(),
            trainConfig.initializeClassWeights(numberOfClasses)
        );

        Training training = new Training(trainConfig, progressTracker, messageLogLevel, trainSet.size(), terminationFlag);
        // A supplier is passed so Training can obtain a fresh batch queue whenever it needs one.
        training.train(objective, () -> BatchQueue.fromArray(trainSet, trainConfig.batchSize()), concurrency);

        // The objective wraps `classifier`, so training updates it in place.
        return classifier;
    }
}
