package org.neo4j.gds.ml.pipeline.linkPipeline.train;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.immutables.value.Value;
import org.jetbrains.annotations.NotNull;
import org.neo4j.gds.Algorithm;
import org.neo4j.gds.RelationshipType;
import org.neo4j.gds.annotation.ValueClass;
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.core.collections.ReadOnlyHugeLongIdentityArray;
import org.neo4j.gds.core.model.Model;
import org.neo4j.gds.core.utils.mem.MemoryEstimation;
import org.neo4j.gds.core.utils.mem.MemoryEstimations;
import org.neo4j.gds.core.utils.mem.MemoryRange;
import org.neo4j.gds.core.utils.paged.HugeLongArray;
import org.neo4j.gds.core.utils.paged.ReadOnlyHugeLongArray;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.core.utils.progress.tasks.Task;
import org.neo4j.gds.core.utils.progress.tasks.Tasks;
import org.neo4j.gds.ml.core.batch.BatchQueue;
import org.neo4j.gds.ml.core.batch.HugeBatchQueue;
import org.neo4j.gds.ml.core.subgraph.LocalIdMap;
import org.neo4j.gds.ml.linkmodels.metrics.LinkMetric;
import org.neo4j.gds.ml.nodemodels.BestMetricData;
import org.neo4j.gds.ml.nodemodels.BestModelStats;
import org.neo4j.gds.ml.nodemodels.ModelStats;
import org.neo4j.gds.ml.nodemodels.StatsMap;
import org.neo4j.gds.ml.pipeline.linkPipeline.LinkPredictionModelInfo;
import org.neo4j.gds.ml.pipeline.linkPipeline.LinkPredictionPipeline;
import org.neo4j.gds.ml.pipeline.linkPipeline.LinkPredictionSplitConfig;
import org.neo4j.gds.ml.splitting.StratifiedKFoldSplitter;
import org.neo4j.gds.ml.splitting.TrainingExamplesSplit;
import org.neo4j.gds.models.Classifier;
import org.neo4j.gds.models.Trainer;
import org.neo4j.gds.models.TrainerConfig;
import org.neo4j.gds.models.TrainerFactory;
import org.neo4j.gds.models.TrainingMethod;
import org.neo4j.gds.models.logisticregression.LogisticRegressionTrainConfig;
import org.neo4j.gds.models.logisticregression.LogisticRegressionTrainer;

/**
 * Trains a link prediction model from a {@link LinkPredictionPipeline}.
 *
 * <p>The training flow, mirrored by {@link #progressTask()}, is:
 * <ol>
 *   <li>extract features and labels from the train graph,</li>
 *   <li>select the best trainer configuration via stratified k-fold cross-validation over every
 *       candidate in the pipeline's training parameter space,</li>
 *   <li>retrain the winning configuration on all training examples,</li>
 *   <li>compute metrics of the retrained model on all training examples,</li>
 *   <li>extract features from the validation graph and compute metrics on it; these are reported
 *       as the "test" metrics.</li>
 * </ol>
 */
public class LinkPredictionTrain extends Algorithm<LinkPredictionTrainResult> {

    /** Model type under which trained link prediction models are registered. */
    public static final String MODEL_TYPE = "LinkPrediction";

    /** Graph whose relationships provide the training examples (and the cross-validation folds). */
    private final Graph trainGraph;
    /** Graph used for the final evaluation, reported as "test" metrics. */
    private final Graph validationGraph;
    private final LinkPredictionPipeline pipeline;
    private final LinkPredictionTrainConfig config;
    /** Class id mapping for the two label values 0 and 1, see {@link #makeClassIdMap()}. */
    private final LocalIdMap classIdMap;

    /**
     * Outcome of model selection: the winning trainer configuration together with, for every metric,
     * the statistics of each candidate on the train folds and on the validation folds.
     */
    @ValueClass
    public interface ModelSelectResult {
        TrainerConfig bestParameters();

        Map<LinkMetric, List<ModelStats>> trainStats();

        Map<LinkMetric, List<ModelStats>> validationStats();

        static ModelSelectResult of(
            TrainerConfig bestParameters,
            Map<LinkMetric, List<ModelStats>> trainStats,
            Map<LinkMetric, List<ModelStats>> validationStats
        ) {
            return ImmutableModelSelectResult.of(bestParameters, trainStats, validationStats);
        }

        /**
         * Renders this result as a nested map with the keys {@code bestParameters}, {@code trainStats} and
         * {@code validationStats}. Stats are keyed by metric name and hold one map per model candidate.
         *
         * <p>The per-candidate stats are materialized into lists. Storing lazy {@code Stream}s here (as a
         * previous version did) leaves single-use, non-data objects inside a value that is computed once
         * and may be read many times.
         *
         * @return an immutable map view of this result
         */
        @Value.Derived
        default Map<String, Object> toMap() {
            Function<Map<LinkMetric, List<ModelStats>>, Map<String, Object>> statsConverter = stats -> {
                Map<String, Object> converted = new HashMap<>();
                stats.forEach((metric, candidateStats) -> converted.put(
                    metric.name(),
                    candidateStats.stream().map(ModelStats::toMap).collect(Collectors.toList())
                ));
                return converted;
            };
            return Map.of(
                "bestParameters", bestParameters().toMap(),
                "trainStats", statsConverter.apply(trainStats()),
                "validationStats", statsConverter.apply(validationStats())
            );
        }
    }

    /**
     * Creates the class id map for link prediction labels: the original class ids {@code 0} and
     * {@code 1} are registered, in that order.
     *
     * @return a fresh {@link LocalIdMap} containing both classes
     */
    public static LocalIdMap makeClassIdMap() {
        LocalIdMap idMap = new LocalIdMap();
        idMap.toMapped(0L);
        idMap.toMapped(1L);
        return idMap;
    }

    /**
     * @param trainGraph      graph providing the training examples
     * @param validationGraph graph used for the final ("test") evaluation
     * @param pipeline        feature steps, split config and training parameter space
     * @param config          train configuration (metrics, concurrency, random seed, model name, ...)
     * @param progressTracker tracker that receives the tasks described by {@link #progressTask()}
     */
    public LinkPredictionTrain(
        Graph trainGraph,
        Graph validationGraph,
        LinkPredictionPipeline pipeline,
        LinkPredictionTrainConfig config,
        ProgressTracker progressTracker
    ) {
        super(progressTracker);
        this.trainGraph = trainGraph;
        this.validationGraph = validationGraph;
        this.pipeline = pipeline;
        this.config = config;
        this.classIdMap = makeClassIdMap();
    }

    /**
     * Describes the progress task tree logged by {@link #compute()}. The sub-task names must match the
     * {@code beginSubTask}/{@code endSubTask} calls made during training.
     *
     * @return the root task of the training progress tree
     */
    public static Task progressTask() {
        return Tasks.task(
            LinkPredictionTrain.class.getSimpleName(),
            Tasks.leaf("extract train features"),
            Tasks.leaf("select model"),
            Trainer.progressTask("train best model"),
            Tasks.leaf("compute train metrics"),
            Tasks.task(
                "evaluate on test data",
                Tasks.leaf("extract test features"),
                Tasks.leaf("compute test metrics")
            )
        );
    }

    /**
     * Runs the full training flow (see the class documentation) and packages the best model.
     *
     * @return the trained model together with the model selection statistics
     */
    @Override
    public LinkPredictionTrainResult compute() {
        progressTracker.beginSubTask();

        progressTracker.beginSubTask("extract train features");
        FeaturesAndLabels trainData = LinkFeaturesAndLabelsExtractor.extractFeaturesAndLabels(
            trainGraph,
            pipeline.featureSteps(),
            config.concurrency(),
            progressTracker
        );
        // Every extracted example participates in training: ids 0 .. size-1.
        ReadOnlyHugeLongIdentityArray allTrainExamples = new ReadOnlyHugeLongIdentityArray(trainData.size());
        progressTracker.endSubTask("extract train features");

        progressTracker.beginSubTask("select model");
        ModelSelectResult modelSelectResult = modelSelect(trainData, allTrainExamples);
        TrainerConfig bestParameters = modelSelectResult.bestParameters();
        progressTracker.endSubTask("select model");

        progressTracker.beginSubTask("train best model");
        Classifier classifier = trainModel(trainData, allTrainExamples, bestParameters, progressTracker);
        progressTracker.endSubTask("train best model");

        progressTracker.beginSubTask("compute train metrics");
        Map<LinkMetric, Double> trainMetrics = computeTrainMetric(trainData, classifier, allTrainExamples, progressTracker);
        progressTracker.endSubTask("compute train metrics");

        progressTracker.beginSubTask("evaluate on test data");
        Map<LinkMetric, Double> testMetrics = computeTestMetric(classifier);
        progressTracker.endSubTask("evaluate on test data");

        Model<Classifier.ClassifierData, LinkPredictionTrainConfig, LinkPredictionModelInfo> model = createModel(
            bestParameters,
            classifier.data(),
            combineBestParameterMetrics(modelSelectResult, trainMetrics, testMetrics)
        );

        progressTracker.endSubTask();
        return LinkPredictionTrainResult.of(model, modelSelectResult);
    }

    /**
     * Trains a classifier with the given parameters on the selected subset of examples.
     *
     * @param featuresAndLabels extracted features and labels of all train examples
     * @param trainSet          ids of the examples to train on
     * @param trainerConfig     trainer parameters (also selects the training method)
     * @param progressTracker   tracker for the trainer's own progress
     * @return the trained classifier
     */
    @NotNull
    private Classifier trainModel(
        FeaturesAndLabels featuresAndLabels,
        ReadOnlyHugeLongArray trainSet,
        TrainerConfig trainerConfig,
        ProgressTracker progressTracker
    ) {
        // NOTE(review): the meaning of the trailing `true` flag is defined by TrainerFactory.create — confirm.
        Trainer trainer = TrainerFactory.create(
            trainerConfig,
            classIdMap,
            terminationFlag,
            progressTracker,
            config.concurrency(),
            config.randomSeed(),
            true
        );
        return trainer.train(featuresAndLabels.features(), featuresAndLabels.labels(), trainSet);
    }

    /**
     * Evaluates every candidate configuration of the pipeline's training parameter space with stratified
     * k-fold cross-validation and picks the candidate that is best on the first configured metric,
     * according to {@link ModelStats#COMPARE_AVERAGE} over the validation folds.
     *
     * <p>Progress is logged once per candidate; the volume is set to the pipeline's candidate count.
     *
     * @param featuresAndLabels extracted features and labels of all train examples
     * @param allTrainExamples  ids of all train examples, split into the cross-validation folds
     * @return the best parameters plus per-metric train/validation stats of every candidate
     */
    private ModelSelectResult modelSelect(FeaturesAndLabels featuresAndLabels, ReadOnlyHugeLongArray allTrainExamples) {
        List<TrainingExamplesSplit> splits = trainValidationSplits(allTrainExamples, featuresAndLabels.labels());
        Map<LinkMetric, List<ModelStats>> trainStats = initStatsMap();
        Map<LinkMetric, List<ModelStats>> validationStats = initStatsMap();
        int numberOfFolds = pipeline.splitConfig().validationFolds();

        progressTracker.setVolume(pipeline.numberOfModelCandidates());
        pipeline.trainingParameterSpace().values().stream()
            .flatMap(candidates -> candidates.stream())
            .forEach(modelParams -> {
                LinkModelStatsBuilder trainStatsBuilder = new LinkModelStatsBuilder(modelParams, numberOfFolds);
                LinkModelStatsBuilder validationStatsBuilder = new LinkModelStatsBuilder(modelParams, numberOfFolds);

                for (TrainingExamplesSplit split : splits) {
                    ReadOnlyHugeLongArray foldTrainSet = ReadOnlyHugeLongArray.of(split.trainSet());
                    ReadOnlyHugeLongArray foldValidationSet = ReadOnlyHugeLongArray.of(split.testSet());

                    // Fold-level training is not reported to the progress tracker.
                    Classifier classifier = trainModel(
                        featuresAndLabels,
                        foldTrainSet,
                        modelParams,
                        ProgressTracker.NULL_TRACKER
                    );
                    computeTrainMetric(featuresAndLabels, classifier, foldTrainSet, ProgressTracker.NULL_TRACKER)
                        .forEach(trainStatsBuilder::update);
                    computeTrainMetric(featuresAndLabels, classifier, foldValidationSet, ProgressTracker.NULL_TRACKER)
                        .forEach(validationStatsBuilder::update);
                }

                config.metrics().forEach(metric -> {
                    validationStats.get(metric).add(validationStatsBuilder.modelStats(metric));
                    trainStats.get(metric).add(trainStatsBuilder.modelStats(metric));
                });
                progressTracker.logProgress();
            });

        LinkMetric mainMetric = config.metrics().get(0);
        ModelStats bestStats = Collections.max(validationStats.get(mainMetric), ModelStats.COMPARE_AVERAGE);
        return ModelSelectResult.of(bestStats.params(), trainStats, validationStats);
    }

    /**
     * Extracts features from the validation graph and computes the configured metrics of the given
     * classifier on all of its examples.
     *
     * @param classifier the retrained best model
     * @return metric values on the validation graph (reported as test metrics)
     */
    private Map<LinkMetric, Double> computeTestMetric(Classifier classifier) {
        progressTracker.beginSubTask("extract test features");
        FeaturesAndLabels testData = LinkFeaturesAndLabelsExtractor.extractFeaturesAndLabels(
            validationGraph,
            pipeline.featureSteps(),
            config.concurrency(),
            progressTracker
        );
        progressTracker.endSubTask("extract test features");

        progressTracker.beginSubTask("compute test metrics");
        Map<LinkMetric, Double> testMetrics = LinkPredictionEvaluationMetricComputer.computeMetric(
            testData,
            classifier,
            new BatchQueue(testData.size()),
            config,
            progressTracker,
            terminationFlag
        );
        progressTracker.endSubTask("compute test metrics");
        return testMetrics;
    }

    /**
     * Combines, per metric, the cross-validation stats of the best parameters with the final train and
     * test scores of the retrained model.
     *
     * @param modelSelectResult result of {@link #modelSelect}
     * @param trainMetrics      metrics of the retrained model on all train examples
     * @param testMetrics       metrics of the retrained model on the validation graph
     * @return one {@link BestMetricData} per metric present in the validation stats
     */
    private static Map<LinkMetric, BestMetricData> combineBestParameterMetrics(
        ModelSelectResult modelSelectResult,
        Map<LinkMetric, Double> trainMetrics,
        Map<LinkMetric, Double> testMetrics
    ) {
        TrainerConfig bestParameters = modelSelectResult.bestParameters();
        return modelSelectResult.validationStats().keySet().stream().collect(Collectors.toMap(
            Function.identity(),
            metric -> BestMetricData.of(
                BestModelStats.findBestModelStats(modelSelectResult.trainStats().get(metric), bestParameters),
                BestModelStats.findBestModelStats(modelSelectResult.validationStats().get(metric), bestParameters),
                trainMetrics.get(metric),
                testMetrics.get(metric)
            )
        ));
    }

    /**
     * Splits the train examples into {@code validationFolds} stratified train/validation folds, stratified by
     * the given labels and seeded with the configured random seed.
     */
    private List<TrainingExamplesSplit> trainValidationSplits(ReadOnlyHugeLongArray allTrainExamples, HugeLongArray labels) {
        StratifiedKFoldSplitter splitter = new StratifiedKFoldSplitter(
            pipeline.splitConfig().validationFolds(),
            allTrainExamples,
            ReadOnlyHugeLongArray.of(labels),
            config.randomSeed()
        );
        return splitter.splits();
    }

    /**
     * Creates an empty stats list for every configured metric. The keys must cover {@code config.metrics()}
     * because {@link #modelSelect} appends stats for each of them; a map holding only a hard-coded metric
     * would throw a {@link NullPointerException} for any other configured metric.
     */
    private Map<LinkMetric, List<ModelStats>> initStatsMap() {
        Map<LinkMetric, List<ModelStats>> statsMap = new HashMap<>();
        config.metrics().forEach(metric -> statsMap.computeIfAbsent(metric, ignored -> new ArrayList<>()));
        return statsMap;
    }

    /**
     * Computes the configured metrics of {@code classifier} on the given subset of the train examples
     * (used both for train folds and validation folds, and for the final train score).
     */
    private Map<LinkMetric, Double> computeTrainMetric(
        FeaturesAndLabels featuresAndLabels,
        Classifier classifier,
        ReadOnlyHugeLongArray evaluationSet,
        ProgressTracker progressTracker
    ) {
        return LinkPredictionEvaluationMetricComputer.computeMetric(
            featuresAndLabels,
            classifier,
            new HugeBatchQueue(evaluationSet),
            config,
            progressTracker,
            terminationFlag
        );
    }

    /** Wraps the trained classifier data into a {@link Model} of type {@link #MODEL_TYPE}. */
    private Model<Classifier.ClassifierData, LinkPredictionTrainConfig, LinkPredictionModelInfo> createModel(
        TrainerConfig bestParameters,
        Classifier.ClassifierData classifierData,
        Map<LinkMetric, BestMetricData> bestMetrics
    ) {
        return Model.of(
            config.username(),
            config.modelName(),
            MODEL_TYPE,
            trainGraph.schema(),
            classifierData,
            config,
            LinkPredictionModelInfo.of(bestParameters, bestMetrics, pipeline.copy())
        );
    }

    @Override
    public void release() {
    }

    /**
     * Estimates the memory needed to train a link prediction model with the given pipeline and config.
     *
     * @param pipeline    pipeline providing split config and training parameter space
     * @param trainConfig train configuration (its metric count sizes the stats maps)
     * @return the memory estimation tree
     */
    public static MemoryEstimation estimate(LinkPredictionPipeline pipeline, LinkPredictionTrainConfig trainConfig) {
        LinkPredictionSplitConfig splitConfig = pipeline.splitConfig();
        // NOTE(review): fixed range passed to the feature/trainer estimations; presumably an assumed
        // feature-dimension range — TODO confirm.
        MemoryRange featureDimensionRange = MemoryRange.of(10L, 500L);
        int numberOfMetrics = trainConfig.metrics().size();

        return MemoryEstimations.builder(LinkPredictionTrain.class)
            .max("Features and labels", List.of(
                LinkFeaturesAndLabelsExtractor.estimate(
                    featureDimensionRange,
                    relationshipCounts -> relationshipCounts.get(RelationshipType.of(splitConfig.trainRelationshipType())),
                    "Train"
                ),
                LinkFeaturesAndLabelsExtractor.estimate(
                    featureDimensionRange,
                    relationshipCounts -> relationshipCounts.get(RelationshipType.of(splitConfig.testRelationshipType())),
                    "Test"
                )
            ))
            .add(estimateModelSelection(pipeline, featureDimensionRange, numberOfMetrics))
            .add("Outer train stats map", StatsMap.memoryEstimation(numberOfMetrics, 1, 1))
            .add("Test stats map", StatsMap.memoryEstimation(numberOfMetrics, 1, 1))
            .fixed("Best model stats", numberOfMetrics * BestMetricData.estimateMemory())
            .build();
    }

    /**
     * Estimates the memory of model selection: fold splitting, the most expensive logistic regression
     * candidate, and the inner train/validation stats maps.
     *
     * <p>The stats maps hold one entry per model candidate. The count therefore sums the candidates of
     * all training methods (exactly what {@link #modelSelect} iterates) rather than using the number of
     * training methods, which is the size of the parameter-space map.
     */
    private static MemoryEstimation estimateModelSelection(
        LinkPredictionPipeline pipeline,
        MemoryRange featureDimensionRange,
        int numberOfMetrics
    ) {
        LinkPredictionSplitConfig splitConfig = pipeline.splitConfig();
        int numberOfModelCandidates = pipeline.trainingParameterSpace().values().stream()
            .mapToInt(candidates -> candidates.size())
            .sum();

        List<MemoryEstimation> candidateEstimations = pipeline.trainingParameterSpace()
            .get(TrainingMethod.LogisticRegression)
            .stream()
            .map(trainerConfig -> MemoryEstimations.builder("Train and evaluate model")
                .fixed("Stats map builder train", LinkModelStatsBuilder.sizeInBytes(numberOfMetrics))
                .fixed("Stats map builder validation", LinkModelStatsBuilder.sizeInBytes(numberOfMetrics))
                .max("Train model and compute train metrics", List.of(
                    LogisticRegressionTrainer.estimate((LogisticRegressionTrainConfig) trainerConfig, featureDimensionRange),
                    estimateComputeTrainMetrics(splitConfig)
                ))
                .build())
            .collect(Collectors.toList());

        return MemoryEstimations.builder("model selection")
            .add(
                "Cross-Validation splitting",
                StratifiedKFoldSplitter.memoryEstimation(
                    splitConfig.validationFolds(),
                    graphDimensions -> graphDimensions
                        .relationshipCounts()
                        .get(RelationshipType.of(splitConfig.trainRelationshipType()))
                )
            )
            .add(MemoryEstimations.maxEstimation("Max over model candidates", candidateEstimations))
            .add("Inner train stats map", StatsMap.memoryEstimation(numberOfMetrics, numberOfModelCandidates, 1))
            .add("Validation stats map", StatsMap.memoryEstimation(numberOfMetrics, numberOfModelCandidates, 1))
            .build();
    }

    /** Estimates metric computation memory, sized by the relationship count of the train relationship type. */
    private static MemoryEstimation estimateComputeTrainMetrics(LinkPredictionSplitConfig splitConfig) {
        return MemoryEstimations.builder("Compute train metrics")
            .perGraphDimension("Sorted probabilities", (graphDimensions, ignored) -> {
                long trainRelationshipCount = graphDimensions
                    .relationshipCounts()
                    .get(RelationshipType.of(splitConfig.trainRelationshipType()));
                return LinkPredictionEvaluationMetricComputer.estimate(trainRelationshipCount);
            })
            .build();
    }
}
