package org.neo4j.gds.ml.training;

import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.SortedSet;
import org.eclipse.collections.api.block.function.primitive.LongToLongFunction;
import org.neo4j.gds.core.utils.paged.ReadOnlyHugeLongArray;
import org.neo4j.gds.core.utils.progress.tasks.LogLevel;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.core.utils.progress.tasks.Task;
import org.neo4j.gds.core.utils.progress.tasks.Tasks;
import org.neo4j.gds.ml.metrics.Metric;
import org.neo4j.gds.ml.metrics.MetricConsumer;
import org.neo4j.gds.ml.metrics.ModelCandidateStats;
import org.neo4j.gds.ml.metrics.ModelSpecificMetricsHandler;
import org.neo4j.gds.ml.metrics.ModelStatsBuilder;
import org.neo4j.gds.ml.models.TrainerConfig;
import org.neo4j.gds.ml.splitting.StratifiedKFoldSplitter;
import org.neo4j.gds.ml.splitting.TrainingExamplesSplit;
import org.neo4j.gds.termination.TerminationFlag;
import org.neo4j.gds.utils.StringFormatting;

/**
 * Selects among model candidates using stratified k-fold cross-validation.
 *
 * <p>The outer training set is split into {@code validationFolds} stratified folds. Each candidate
 * {@link TrainerConfig} is then trained once per fold on that fold's train set. The trained model is
 * evaluated on both the fold's held-out (validation) set and its train set, and the aggregated scores
 * are recorded in a {@link TrainingStatistics}.
 *
 * @param <MODEL_TYPE> type of the model produced by the supplied {@link ModelTrainer}
 */
public class CrossValidation<MODEL_TYPE> {
    private final ProgressTracker progressTracker;
    private final TerminationFlag terminationFlag;
    private final List<? extends Metric> metrics;
    private final int validationFolds;
    private final Optional<Long> randomSeed;
    private final ModelTrainer<MODEL_TYPE> modelTrainer;
    private final ModelEvaluator<MODEL_TYPE> modelEvaluator;

    /**
     * Evaluates a trained model on a set of ids and reports the resulting scores to a {@link MetricConsumer}.
     */
    @FunctionalInterface
    public interface ModelEvaluator<MODEL_TYPE> {
        void evaluate(ReadOnlyHugeLongArray evaluationSet, MODEL_TYPE trainedModel, MetricConsumer scoreConsumer);
    }

    /**
     * Trains a model on the given ids with the given candidate configuration.
     */
    @FunctionalInterface
    public interface ModelTrainer<MODEL_TYPE> {
        MODEL_TYPE train(
            ReadOnlyHugeLongArray trainSet,
            TrainerConfig trainerConfig,
            ModelSpecificMetricsHandler metricsHandler,
            LogLevel messageLogLevel
        );
    }

    /**
     * Builds the progress-task tree that {@link #selectModel} reports against.
     *
     * <p>The tree has two parts: a "Create validation folds" leaf, and a "Select best model" task
     * with one "Trial" leaf per candidate. The volume estimates here are only initial sizes;
     * {@code selectModel} resets each trial's steps to the actual number of folds.
     *
     * @param validationFolds              number of cross-validation folds
     * @param numberOfModelSelectionTrials number of candidate configurations that will be tried
     * @param trainSetSize                 size of the outer training set
     * @return the progress tasks, in the order {@code selectModel} begins them
     */
    public static List<Task> progressTasks(int validationFolds, int numberOfModelSelectionTrials, long trainSetSize) {
        return List.of(
            Tasks.leaf("Create validation folds", Math.max((long) (0.5d * trainSetSize), 1L)),
            Tasks.iterativeFixed(
                "Select best model",
                () -> List.of(Tasks.leaf("Trial", 5 * validationFolds * trainSetSize)),
                numberOfModelSelectionTrials
            )
        );
    }

    /**
     * @param progressTracker tracker that receives subtask, step and log events
     * @param terminationFlag checked once per trial to allow cancellation
     * @param metrics         metrics used to build the model-specific metrics handler for each candidate
     * @param validationFolds number of folds passed to the stratified k-fold splitter
     * @param randomSeed      optional seed forwarded to the splitter
     * @param modelTrainer    trains one model per (candidate, fold)
     * @param modelEvaluator  scores a trained model on a set of ids
     */
    public CrossValidation(
        ProgressTracker progressTracker,
        TerminationFlag terminationFlag,
        List<? extends Metric> metrics,
        int validationFolds,
        Optional<Long> randomSeed,
        ModelTrainer<MODEL_TYPE> modelTrainer,
        ModelEvaluator<MODEL_TYPE> modelEvaluator
    ) {
        this.progressTracker = progressTracker;
        this.terminationFlag = terminationFlag;
        this.metrics = metrics;
        this.validationFolds = validationFolds;
        this.randomSeed = randomSeed;
        this.modelTrainer = modelTrainer;
        this.modelEvaluator = modelEvaluator;
    }

    /**
     * Cross-validates every candidate yielded by {@code modelCandidates}.
     *
     * <p>For each candidate, the per-fold train and validation scores are recorded in
     * {@code trainingStatistics}. Summaries are logged per trial, and the best trial is logged
     * at the end.
     *
     * <p>NOTE(review): trial {@code k} (0-based) is read back from {@code trainingStatistics} at
     * index {@code k}. This presumably requires {@code trainingStatistics} to hold no candidates
     * before this call. Confirm against callers.
     *
     * @param outerTrainSet           ids to split into validation folds
     * @param targets                 forwarded to the stratified splitter; presumably maps an id to its target class
     * @param distinctInternalTargets forwarded to the stratified splitter; presumably the set of distinct target classes
     * @param trainingStatistics      sink for each candidate's stats and source of the averaged metrics that are logged
     * @param modelCandidates         candidate trainer configurations, consumed lazily, one trial each
     */
    public void selectModel(
        ReadOnlyHugeLongArray outerTrainSet,
        LongToLongFunction targets,
        SortedSet<Long> distinctInternalTargets,
        TrainingStatistics trainingStatistics,
        Iterator<TrainerConfig> modelCandidates
    ) {
        progressTracker.beginSubTask("Create validation folds");
        List<TrainingExamplesSplit> folds = new StratifiedKFoldSplitter(
            validationFolds,
            outerTrainSet,
            targets,
            randomSeed,
            distinctInternalTargets
        ).splits();
        progressTracker.endSubTask("Create validation folds");

        progressTracker.beginSubTask("Select best model");
        int trialIdx = 0;
        while (modelCandidates.hasNext()) {
            progressTracker.beginSubTask("Trial");
            // One progress step per fold, regardless of the volume estimated in progressTasks.
            progressTracker.setSteps(folds.size());
            terminationFlag.assertRunning();

            TrainerConfig candidate = modelCandidates.next();
            progressTracker.logInfo(StringFormatting.formatWithLocale(
                "Method: %s, Parameters: %s",
                candidate.method(),
                candidate.toMap()
            ));

            trainingStatistics.addCandidateStats(evaluateCandidate(candidate, folds));
            logTrialSummary(trainingStatistics, trialIdx);

            trialIdx++;
            progressTracker.endSubTask("Trial");
        }

        progressTracker.logInfo(StringFormatting.formatWithLocale(
            "Best trial was Trial %d with main validation metric %.4f",
            trainingStatistics.getBestTrialIdx() + 1,
            trainingStatistics.getBestTrialScore()
        ));
        progressTracker.endSubTask("Select best model");
    }

    /** Trains and scores one candidate on every fold, returning its collected train/validation stats. */
    private ModelCandidateStats evaluateCandidate(TrainerConfig candidate, List<TrainingExamplesSplit> folds) {
        ModelStatsBuilder validationStatsBuilder = new ModelStatsBuilder(folds.size());
        ModelStatsBuilder trainStatsBuilder = new ModelStatsBuilder(folds.size());
        // Model-specific metrics reported by the trainer are collected together with the validation scores.
        ModelSpecificMetricsHandler metricsHandler = ModelSpecificMetricsHandler.of(metrics, validationStatsBuilder);

        int foldNumber = 1;
        for (TrainingExamplesSplit fold : folds) {
            ReadOnlyHugeLongArray trainSet = fold.trainSet();
            ReadOnlyHugeLongArray validationSet = fold.testSet();

            progressTracker.logDebug("Starting fold " + foldNumber + " training");
            MODEL_TYPE model = modelTrainer.train(trainSet, candidate, metricsHandler, LogLevel.DEBUG);
            progressTracker.logDebug("Finished fold " + foldNumber + " training");

            modelEvaluator.evaluate(validationSet, model, validationStatsBuilder::update);
            modelEvaluator.evaluate(trainSet, model, trainStatsBuilder::update);

            progressTracker.logSteps(1L);
            foldNumber++;
        }

        return ModelCandidateStats.of(candidate, trainStatsBuilder.build(), validationStatsBuilder.build());
    }

    /** Logs the main metric plus the averaged validation and training metrics of trial {@code trialIdx}. */
    private void logTrialSummary(TrainingStatistics trainingStatistics, int trialIdx) {
        Map<Metric, Double> validationMetricsAvg = trainingStatistics.validationMetricsAvg(trialIdx);
        Map<Metric, Double> trainMetricsAvg = trainingStatistics.trainMetricsAvg(trialIdx);

        progressTracker.logInfo(StringFormatting.formatWithLocale(
            "Main validation metric (%s): %.4f",
            trainingStatistics.evaluationMetric(),
            trainingStatistics.getMainMetric(trialIdx)
        ));
        progressTracker.logInfo(StringFormatting.formatWithLocale("Validation metrics: %s", validationMetricsAvg));
        progressTracker.logInfo(StringFormatting.formatWithLocale("Training metrics: %s", trainMetricsAvg));
    }
}
