package org.neo4j.gds.ml.pipeline.linkPipeline.train;

import java.lang.invoke.SerializedLambda;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.TreeSet;
import java.util.stream.Collectors;
import org.jetbrains.annotations.NotNull;
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.core.utils.TerminationFlag;
import org.neo4j.gds.core.utils.mem.MemoryEstimation;
import org.neo4j.gds.core.utils.mem.MemoryEstimations;
import org.neo4j.gds.core.utils.mem.MemoryRange;
import org.neo4j.gds.core.utils.paged.HugeIntArray;
import org.neo4j.gds.core.utils.paged.ReadOnlyHugeLongArray;
import org.neo4j.gds.core.utils.progress.tasks.LogLevel;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.core.utils.progress.tasks.Task;
import org.neo4j.gds.core.utils.progress.tasks.Tasks;
import org.neo4j.gds.mem.MemoryUsage;
import org.neo4j.gds.ml.core.ReadOnlyHugeLongIdentityArray;
import org.neo4j.gds.ml.core.batch.BatchQueue;
import org.neo4j.gds.ml.core.subgraph.LocalIdMap;
import org.neo4j.gds.ml.metrics.ImmutableModelStats;
import org.neo4j.gds.ml.metrics.Metric;
import org.neo4j.gds.ml.metrics.MetricConsumer;
import org.neo4j.gds.ml.metrics.ModelSpecificMetricsHandler;
import org.neo4j.gds.ml.metrics.ModelStatsBuilder;
import org.neo4j.gds.ml.metrics.SignedProbabilities;
import org.neo4j.gds.ml.models.Classifier;
import org.neo4j.gds.ml.models.ClassifierTrainer;
import org.neo4j.gds.ml.models.ClassifierTrainerFactory;
import org.neo4j.gds.ml.models.TrainerConfig;
import org.neo4j.gds.ml.models.automl.RandomSearch;
import org.neo4j.gds.ml.pipeline.linkPipeline.ExpectedSetSizes;
import org.neo4j.gds.ml.pipeline.linkPipeline.LinkPredictionSplitConfig;
import org.neo4j.gds.ml.pipeline.linkPipeline.LinkPredictionTrainingPipeline;
import org.neo4j.gds.ml.splitting.StratifiedKFoldSplitter;
import org.neo4j.gds.ml.training.CrossValidation;
import org.neo4j.gds.ml.training.TrainingStatistics;
import org.neo4j.gds.utils.StringFormatting;

/* loaded from: input_file:org/neo4j/gds/ml/pipeline/linkPipeline/train/LinkPredictionTrain.class */
/**
 * Trains a link prediction classifier for a {@link LinkPredictionTrainingPipeline}.
 *
 * <p>{@link #compute()} extracts link features and labels from the train graph, selects the best
 * model candidate via cross-validation, retrains that candidate on the full train set, and finally
 * scores the resulting classifier on both the train set and the second (test) graph.
 *
 * <p>{@link #progressTasks(long, LinkPredictionSplitConfig, int)} and
 * {@link #estimate(LinkPredictionTrainingPipeline, LinkPredictionTrainConfig)} describe the
 * progress-task layout and the memory footprint of that procedure.
 */
public final class LinkPredictionTrain {
    private final Graph trainGraph;
    // Graph the test features are extracted from when the final model is evaluated.
    private final Graph validationGraph;
    private final LinkPredictionTrainingPipeline pipeline;
    private final LinkPredictionTrainConfig config;
    // Maps the two class ids (0 and 1) to local ids; presumably 0 = negative and 1 = positive link -- TODO confirm.
    private final LocalIdMap classIdMap = makeClassIdMap();
    private final ProgressTracker progressTracker;
    private final TerminationFlag terminationFlag;

    /** Builds the id map containing exactly the class ids {@code 0} and {@code 1}. */
    private static LocalIdMap makeClassIdMap() {
        return LocalIdMap.of(new long[]{0L, 1L});
    }

    /**
     * @param graph                          graph used for training (features, model selection, final training)
     * @param graph2                         graph used for the final "Evaluate on test data" step
     * @param linkPredictionTrainingPipeline pipeline providing feature steps, split config and parameter space
     * @param linkPredictionTrainConfig      train configuration (metrics, concurrency, random seed, ...)
     * @param progressTracker                tracker that receives the sub task begin/end events and log lines
     * @param terminationFlag                flag handed to all long running sub computations
     */
    public LinkPredictionTrain(Graph graph, Graph graph2, LinkPredictionTrainingPipeline linkPredictionTrainingPipeline, LinkPredictionTrainConfig linkPredictionTrainConfig, ProgressTracker progressTracker, TerminationFlag terminationFlag) {
        this.trainGraph = graph;
        this.validationGraph = graph2;
        this.pipeline = linkPredictionTrainingPipeline;
        this.config = linkPredictionTrainConfig;
        this.terminationFlag = terminationFlag;
        this.progressTracker = progressTracker;
    }

    /**
     * Lists the progress tasks in the order in which {@link #compute()} begins and ends them.
     *
     * @param j                         count handed to {@link LinkPredictionSplitConfig#expectedSetSizes(long)}
     *                                  (presumably the relationship count of the input graph -- TODO confirm)
     * @param linkPredictionSplitConfig split configuration providing the expected set sizes and the fold count
     * @param i                         forwarded to {@link CrossValidation#progressTasks}
     *                                  (presumably the number of model selection trials -- TODO confirm)
     * @return the tasks for feature extraction, cross-validation, final training and evaluation
     */
    public static List<Task> progressTasks(long j, LinkPredictionSplitConfig linkPredictionSplitConfig, int i) {
        ExpectedSetSizes sizes = linkPredictionSplitConfig.expectedSetSizes(j);
        long trainSize = sizes.trainSize();
        long testSize = sizes.testSize();

        List<Task> tasks = new ArrayList<>();
        tasks.add(Tasks.leaf("Extract train features", trainSize * 3));
        tasks.addAll(CrossValidation.progressTasks(linkPredictionSplitConfig.validationFolds(), i, trainSize));
        tasks.add(ClassifierTrainer.progressTask("Train best model", trainSize * 5));
        tasks.add(Tasks.leaf("Compute train metrics", trainSize));
        tasks.add(Tasks.task(
            "Evaluate on test data",
            Tasks.leaf("Extract test features", testSize * 3),
            Tasks.leaf("Compute test metrics", testSize)
        ));
        return tasks;
    }

    /**
     * Runs the complete training procedure.
     *
     * <ol>
     *   <li>extract features and labels from the train graph,</li>
     *   <li>select the best model candidate via cross-validation,</li>
     *   <li>retrain the winning parameters on the whole train set,</li>
     *   <li>compute the link metrics on the train set and on the test graph.</li>
     * </ol>
     *
     * @return the trained classifier together with all collected training statistics
     */
    public LinkPredictionTrainResult compute() {
        progressTracker.beginSubTask("Extract train features");
        FeaturesAndLabels trainData = LinkFeaturesAndLabelsExtractor.extractFeaturesAndLabels(
            trainGraph,
            pipeline.featureSteps(),
            config.concurrency(),
            progressTracker,
            terminationFlag
        );
        // Identity array over all extracted rows, i.e. the "train set" covers every extracted sample.
        ReadOnlyHugeLongArray trainSet = new ReadOnlyHugeLongIdentityArray(trainData.size());
        progressTracker.endSubTask("Extract train features");

        TrainingStatistics trainingStatistics = new TrainingStatistics(config.metrics());
        findBestModelCandidate(trainData, trainSet, trainingStatistics);

        progressTracker.beginSubTask("Train best model");
        TrainerConfig bestParameters = trainingStatistics.bestParameters();
        ModelSpecificMetricsHandler testScoreHandler =
            ModelSpecificMetricsHandler.of(config.metrics(), trainingStatistics::addTestScore);
        Classifier classifier = trainModel(trainData, trainSet, bestParameters, LogLevel.INFO, testScoreHandler);
        progressTracker.endSubTask("Train best model");

        progressTracker.beginSubTask("Compute train metrics");
        computeTrainMetric(trainData, classifier, trainSet, trainingStatistics::addOuterTrainScore, progressTracker);
        progressTracker.endSubTask("Compute train metrics");
        progressTracker.logInfo(StringFormatting.formatWithLocale(
            "Final model metrics on full train set: %s",
            trainingStatistics.winningModelOuterTrainMetrics()
        ));

        progressTracker.beginSubTask("Evaluate on test data");
        computeTestMetric(classifier, trainingStatistics);
        progressTracker.endSubTask("Evaluate on test data");
        progressTracker.logInfo(StringFormatting.formatWithLocale(
            "Final model metrics on test set: %s",
            trainingStatistics.winningModelTestMetrics()
        ));

        return ImmutableLinkPredictionTrainResult.of(classifier, trainingStatistics);
    }

    /**
     * Runs cross-validation over the candidates drawn by a {@link RandomSearch} and records the
     * outcome in {@code trainingStatistics}.
     *
     * @param trainData          features and labels extracted from the train graph
     * @param trainSet           ids of the rows of {@code trainData} that take part in model selection
     * @param trainingStatistics statistics object that is filled by the cross-validation
     */
    private void findBestModelCandidate(FeaturesAndLabels trainData, ReadOnlyHugeLongArray trainSet, TrainingStatistics trainingStatistics) {
        RandomSearch candidateSearch = new RandomSearch(
            pipeline.trainingParameterSpace(),
            pipeline.autoTuningConfig().maxTrials(),
            config.randomSeed()
        );
        CrossValidation crossValidation = new CrossValidation(
            progressTracker,
            terminationFlag,
            config.metrics(),
            pipeline.splitConfig().validationFolds(),
            config.randomSeed(),
            (foldTrainSet, candidateConfig, metricsHandler, logLevel) ->
                trainModel(trainData, foldTrainSet, candidateConfig, logLevel, metricsHandler),
            // Evaluation inside the folds must not advance the outer progress, hence the null tracker.
            (evaluationSet, classifier, scoreConsumer) ->
                computeTrainMetric(trainData, classifier, evaluationSet, scoreConsumer, ProgressTracker.NULL_TRACKER)
        );

        HugeIntArray labels = trainData.labels();
        crossValidation.selectModel(
            trainSet,
            labels::get,
            new TreeSet<>(classIdMap.originalIdsList()),
            trainingStatistics,
            candidateSearch
        );
    }

    /**
     * Creates a trainer for {@code trainerConfig} and trains it on the rows of {@code trainData}
     * selected by {@code trainSet}.
     *
     * @param trainData      features and labels to train on
     * @param trainSet       ids of the rows used for training
     * @param trainerConfig  parameters of the model to train
     * @param logLevel       log level handed to the trainer factory
     * @param metricsHandler handler for model specific metrics, handed to the trainer factory
     * @return the trained classifier
     */
    @NotNull
    private Classifier trainModel(FeaturesAndLabels trainData, ReadOnlyHugeLongArray trainSet, TrainerConfig trainerConfig, LogLevel logLevel, ModelSpecificMetricsHandler metricsHandler) {
        ClassifierTrainer trainer = ClassifierTrainerFactory.create(
            trainerConfig,
            classIdMap.size(),
            terminationFlag,
            progressTracker,
            logLevel,
            config.concurrency(),
            config.randomSeed(),
            true,
            metricsHandler
        );
        return trainer.train(trainData.features(), trainData.labels(), trainSet);
    }

    /**
     * Extracts features and labels from the second graph, predicts them with {@code classifier}
     * and stores one test score per configured link metric in {@code trainingStatistics}.
     */
    private void computeTestMetric(Classifier classifier, TrainingStatistics trainingStatistics) {
        progressTracker.beginSubTask("Extract test features");
        FeaturesAndLabels testData = LinkFeaturesAndLabelsExtractor.extractFeaturesAndLabels(
            validationGraph,
            pipeline.featureSteps(),
            config.concurrency(),
            progressTracker,
            terminationFlag
        );
        progressTracker.endSubTask("Extract test features");

        progressTracker.beginSubTask("Compute test metrics");
        SignedProbabilities testProbabilities = SignedProbabilities.computeFromLabeledData(
            testData.features(),
            testData.labels(),
            classifier,
            BatchQueue.consecutive(testData.size()),
            config.concurrency(),
            terminationFlag,
            progressTracker
        );
        config.linkMetrics().forEach(linkMetric -> trainingStatistics.addTestScore(
            linkMetric,
            linkMetric.compute(testProbabilities, config.negativeClassWeight())
        ));
        progressTracker.endSubTask("Compute test metrics");
    }

    /**
     * Predicts the rows of {@code trainData} selected by {@code evaluationSet} and hands one score
     * per configured link metric to {@code scoreConsumer}.
     *
     * @param tracker progress tracker used while computing the probabilities
     */
    private void computeTrainMetric(FeaturesAndLabels trainData, Classifier classifier, ReadOnlyHugeLongArray evaluationSet, MetricConsumer scoreConsumer, ProgressTracker tracker) {
        SignedProbabilities probabilities = SignedProbabilities.computeFromLabeledData(
            trainData.features(),
            trainData.labels(),
            classifier,
            BatchQueue.fromArray(evaluationSet),
            config.concurrency(),
            terminationFlag,
            tracker
        );
        config.linkMetrics().forEach(linkMetric -> scoreConsumer.consume(
            linkMetric,
            linkMetric.compute(probabilities, config.negativeClassWeight())
        ));
    }

    /**
     * Estimates the memory needed by {@link #compute()}.
     *
     * <p>The estimation looks up relationship counts by the train and test relationship types of
     * the pipeline's split config, so the dimensions it is evaluated against must contain both.
     *
     * @param linkPredictionTrainingPipeline pipeline providing split config and parameter space
     * @param linkPredictionTrainConfig      train configuration; only the number of link metrics is read
     * @return the composed memory estimation
     */
    public static MemoryEstimation estimate(LinkPredictionTrainingPipeline linkPredictionTrainingPipeline, LinkPredictionTrainConfig linkPredictionTrainConfig) {
        LinkPredictionSplitConfig splitConfig = linkPredictionTrainingPipeline.splitConfig();
        // Assumed bounds for the link feature dimension, which is not known at estimation time.
        MemoryRange linkFeatureDimension = MemoryRange.of(10L, 500L);
        int numberOfMetrics = linkPredictionTrainConfig.linkMetrics().size();

        MemoryEstimation trainFeatures = LinkFeaturesAndLabelsExtractor.estimate(
            linkFeatureDimension,
            relCounts -> ((Long) relCounts.get(splitConfig.trainRelationshipType())).longValue(),
            "Train"
        );
        MemoryEstimation testFeatures = LinkFeaturesAndLabelsExtractor.estimate(
            linkFeatureDimension,
            relCounts -> ((Long) relCounts.get(splitConfig.testRelationshipType())).longValue(),
            "Test"
        );
        MemoryRange bestModelStats = MemoryRange
            .of(MemoryUsage.sizeOfInstance(ImmutableModelStats.class))
            .times(2L)
            .add(16L)
            .times(numberOfMetrics);

        return MemoryEstimations.builder(LinkPredictionTrain.class.getSimpleName())
            // Only the larger of the train and test data contributes to the estimate.
            .max("Features and labels", List.of(trainFeatures, testFeatures))
            .add(estimateTrainingAndEvaluation(linkPredictionTrainingPipeline, linkFeatureDimension, numberOfMetrics))
            .add("Outer train stats map", TrainingStatistics.memoryEstimationStatsMap(numberOfMetrics, 1, 1))
            .add("Test stats map", TrainingStatistics.memoryEstimationStatsMap(numberOfMetrics, 1, 1))
            .fixed("Best model stats", bestModelStats)
            .build();
    }

    /**
     * Estimates model selection: the cross-validation split, the most expensive model candidate
     * and the statistics maps filled during selection.
     */
    private static MemoryEstimation estimateTrainingAndEvaluation(LinkPredictionTrainingPipeline pipeline, MemoryRange linkFeatureDimension, int numberOfMetrics) {
        LinkPredictionSplitConfig splitConfig = pipeline.splitConfig();

        List<MemoryEstimation> candidateEstimations = pipeline
            .trainingParameterSpace()
            .values()
            .stream()
            .flatMap(tunableConfigs -> tunableConfigs.stream())
            .flatMap(tunableConfig -> tunableConfig.streamCornerCaseConfigs())
            .map(trainerConfig -> estimateCandidate(splitConfig, trainerConfig, linkFeatureDimension, numberOfMetrics))
            .collect(Collectors.toList());

        int numberOfTrials = pipeline.numberOfModelSelectionTrials();
        return MemoryEstimations.builder("model selection")
            .add(
                "Cross-Validation splitting",
                StratifiedKFoldSplitter.memoryEstimation(
                    splitConfig.validationFolds(),
                    dimensions -> ((Long) dimensions.relationshipCounts().get(splitConfig.trainRelationshipType())).longValue()
                )
            )
            .add(MemoryEstimations.maxEstimation("Max over model candidates", candidateEstimations))
            .add("Inner train stats map", TrainingStatistics.memoryEstimationStatsMap(numberOfMetrics, numberOfTrials, 1))
            .add("Validation stats map", TrainingStatistics.memoryEstimationStatsMap(numberOfMetrics, numberOfTrials, 1))
            .build();
    }

    /** Estimates training and evaluating a single model candidate. */
    private static MemoryEstimation estimateCandidate(LinkPredictionSplitConfig splitConfig, TrainerConfig trainerConfig, MemoryRange linkFeatureDimension, int numberOfMetrics) {
        return MemoryEstimations.builder("Train and evaluate model")
            .fixed("Stats map builder train", ModelStatsBuilder.sizeInBytes(numberOfMetrics))
            .fixed("Stats map builder validation", ModelStatsBuilder.sizeInBytes(numberOfMetrics))
            .max(
                "Train model and compute train metrics",
                List.of(
                    estimateTraining(splitConfig, trainerConfig, linkFeatureDimension),
                    estimateComputeTrainMetrics(splitConfig)
                )
            )
            .build();
    }

    /**
     * Estimates training one model with {@code trainerConfig}; the number of training samples is
     * taken to be the relationship count of the train relationship type.
     */
    private static MemoryEstimation estimateTraining(LinkPredictionSplitConfig splitConfig, TrainerConfig trainerConfig, MemoryRange linkFeatureDimension) {
        return MemoryEstimations.setup("Training", dimensions -> ClassifierTrainerFactory.memoryEstimation(
            trainerConfig,
            // The operator's argument is ignored; the sample count comes from the graph dimensions.
            ignored -> ((Long) dimensions.relationshipCounts().get(splitConfig.trainRelationshipType())).longValue(),
            2,
            linkFeatureDimension,
            true
        ));
    }

    /** Estimates the signed probabilities kept while computing the train metrics. */
    private static MemoryEstimation estimateComputeTrainMetrics(LinkPredictionSplitConfig splitConfig) {
        return MemoryEstimations.builder("Compute train metrics")
            .perGraphDimension("Sorted probabilities", (dimensions, concurrency) -> MemoryRange.of(
                SignedProbabilities.estimateMemory(
                    ((Long) dimensions.relationshipCounts().get(splitConfig.trainRelationshipType())).longValue()
                )
            ))
            .build();
    }
}
