package org.neo4j.gds.ml.pipeline.nodePipeline.classification.train;

import java.lang.invoke.SerializedLambda;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.TreeSet;
import java.util.function.LongUnaryOperator;
import java.util.stream.Collectors;
import org.jetbrains.annotations.NotNull;
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.api.GraphStore;
import org.neo4j.gds.api.IdMap;
import org.neo4j.gds.collections.LongMultiSet;
import org.neo4j.gds.core.model.ModelCatalog;
import org.neo4j.gds.core.utils.TerminationFlag;
import org.neo4j.gds.core.utils.mem.MemoryEstimation;
import org.neo4j.gds.core.utils.mem.MemoryEstimations;
import org.neo4j.gds.core.utils.mem.MemoryRange;
import org.neo4j.gds.core.utils.paged.HugeIntArray;
import org.neo4j.gds.core.utils.paged.HugeLongArray;
import org.neo4j.gds.core.utils.paged.HugeObjectArray;
import org.neo4j.gds.core.utils.paged.ReadOnlyHugeLongArray;
import org.neo4j.gds.core.utils.progress.tasks.LogLevel;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.core.utils.progress.tasks.Task;
import org.neo4j.gds.core.utils.progress.tasks.Tasks;
import org.neo4j.gds.mem.MemoryUsage;
import org.neo4j.gds.ml.core.subgraph.LocalIdMap;
import org.neo4j.gds.ml.metrics.Metric;
import org.neo4j.gds.ml.metrics.MetricConsumer;
import org.neo4j.gds.ml.metrics.ModelCandidateStats;
import org.neo4j.gds.ml.metrics.ModelSpecificMetricsHandler;
import org.neo4j.gds.ml.metrics.classification.ClassificationMetric;
import org.neo4j.gds.ml.metrics.classification.ClassificationMetricSpecification;
import org.neo4j.gds.ml.models.Classifier;
import org.neo4j.gds.ml.models.ClassifierTrainer;
import org.neo4j.gds.ml.models.ClassifierTrainerFactory;
import org.neo4j.gds.ml.models.Features;
import org.neo4j.gds.ml.models.TrainerConfig;
import org.neo4j.gds.ml.models.TrainingMethod;
import org.neo4j.gds.ml.models.automl.RandomSearch;
import org.neo4j.gds.ml.models.logisticregression.LogisticRegressionTrainConfig;
import org.neo4j.gds.ml.nodeClassification.ClassificationMetricComputer;
import org.neo4j.gds.ml.nodePropertyPrediction.NodeSplitter;
import org.neo4j.gds.ml.pipeline.NodePropertyStepExecutor;
import org.neo4j.gds.ml.pipeline.PipelineTrainer;
import org.neo4j.gds.ml.pipeline.nodePipeline.NodeFeatureProducer;
import org.neo4j.gds.ml.pipeline.nodePipeline.NodePropertyPredictionSplitConfig;
import org.neo4j.gds.ml.pipeline.nodePipeline.classification.NodeClassificationTrainingPipeline;
import org.neo4j.gds.ml.pipeline.nodePipeline.classification.train.LabelsAndClassCountsExtractor;
import org.neo4j.gds.ml.splitting.FractionSplitter;
import org.neo4j.gds.ml.splitting.StratifiedKFoldSplitter;
import org.neo4j.gds.ml.splitting.TrainingExamplesSplit;
import org.neo4j.gds.ml.training.CrossValidation;
import org.neo4j.gds.ml.training.TrainingStatistics;
import org.neo4j.gds.utils.StringFormatting;

/* loaded from: input_file:org/neo4j/gds/ml/pipeline/nodePipeline/classification/train/NodeClassificationTrain.class */
public final class NodeClassificationTrain implements PipelineTrainer<NodeClassificationTrainResult> {
    // Pipeline definition: supplies the split configuration, the training parameter space and the
    // number of model selection trials.
    private final NodeClassificationTrainingPipeline pipeline;
    // Training configuration: read for concurrency and the random seed.
    private final NodeClassificationPipelineTrainConfig trainConfig;
    // Label array handed to every trainer and metric computer; obtained in create() from the
    // labels-and-class-counts extraction of the target property.
    private final HugeIntArray targets;
    // Built in create() via LocalIdMap.ofSorted(classCounts.keys()); its size is passed to the
    // trainer factory as the number of classes.
    private final LocalIdMap classIdMap;
    // Id mapping of the graph being trained on; provides the node count and the
    // original <-> mapped node id translation used when splitting.
    private final IdMap nodeIdMap;
    // Metrics tracked by the training statistics and by cross-validation.
    private final List<Metric> metrics;
    // Metrics scored against an evaluation set in registerMetricScores().
    private final List<ClassificationMetric> classificationMetrics;
    // Class counts from the target property; the keys feed model selection and the whole
    // multiset is returned as part of the training result.
    private final LongMultiSet classCounts;
    // Produces the feature vectors for the pipeline (see run()).
    private final NodeFeatureProducer<NodeClassificationPipelineTrainConfig> nodeFeatureProducer;
    // Receives all sub-task begin/end notifications and log messages of the training run.
    private final ProgressTracker progressTracker;
    // Defaults to RUNNING_TRUE until a caller installs another flag via setTerminationFlag().
    private TerminationFlag terminationFlag = TerminationFlag.RUNNING_TRUE;

    /**
     * Builds the memory estimation for a complete node classification training run.
     *
     * <p>The training parameter space of the pipeline is validated first. The result combines, through
     * {@code MemoryEstimations.maxEstimation}, the estimation of the node property steps and the
     * estimation of the training phase itself.
     *
     * @param nodeClassificationTrainingPipeline the pipeline whose steps and parameter space are estimated
     * @param nodeClassificationPipelineTrainConfig the train configuration providing node labels and relationship types
     * @param modelCatalog the model catalog forwarded to the node property step estimation
     * @return the combined memory estimation
     */
    public static MemoryEstimation estimate(NodeClassificationTrainingPipeline nodeClassificationTrainingPipeline, NodeClassificationPipelineTrainConfig nodeClassificationPipelineTrainConfig, ModelCatalog modelCatalog) {
        nodeClassificationTrainingPipeline.validateTrainingParameterSpace();

        MemoryEstimation nodePropertyStepsEstimation = NodePropertyStepExecutor.estimateNodePropertySteps(
            modelCatalog,
            nodeClassificationTrainingPipeline.nodePropertySteps(),
            nodeClassificationPipelineTrainConfig.nodeLabels(),
            nodeClassificationPipelineTrainConfig.relationshipTypes()
        );

        MemoryEstimation trainingEstimation = MemoryEstimations.builder()
            .add("Training", estimateExcludingNodePropertySteps(nodeClassificationTrainingPipeline, nodeClassificationPipelineTrainConfig))
            .build();

        return MemoryEstimations.maxEstimation(
            "Node Classification Train Pipeline",
            List.of(nodePropertyStepsEstimation, trainingEstimation)
        );
    }

    /**
     * Estimates the memory needed by the training phase, without the node property steps.
     *
     * <p>The estimation adds up the targets, class counts, metrics, node ids, the outer and inner split,
     * the train/validation statistics maps and the larger of "model selection" and
     * "best model evaluation". When the parameter space contains random forest candidates, an
     * additional entry for cached feature vectors is added.
     *
     * @param nodeClassificationTrainingPipeline the pipeline providing split config and parameter space
     * @param nodeClassificationPipelineTrainConfig the train configuration providing the metric list
     * @return the memory estimation of the training phase
     */
    private static MemoryEstimation estimateExcludingNodePropertySteps(NodeClassificationTrainingPipeline nodeClassificationTrainingPipeline, NodeClassificationPipelineTrainConfig nodeClassificationPipelineTrainConfig) {
        // Fixed stand-in sizes: the first bounds the class-count and metric estimations, the second is
        // the upper feature-vector length. NOTE(review): presumably chosen as generous guesses because
        // the real values are unknown at estimation time - confirm.
        final int assumedClassCount = 1000;
        final int assumedFeatureCount = 500;

        NodePropertyPredictionSplitConfig splitConfig = nodeClassificationTrainingPipeline.splitConfig();
        double testFraction = splitConfig.testFraction();
        double trainFraction = 1.0d - testFraction;
        int metricCount = nodeClassificationPipelineTrainConfig.metrics().size();

        MemoryEstimation modelSelection = modelTrainAndEvaluateMemoryUsage(
            nodeClassificationTrainingPipeline,
            assumedClassCount,
            assumedFeatureCount,
            splitConfig::foldTrainSetSize,
            splitConfig::foldTestSetSize
        );

        MemoryEstimations.Builder builder = MemoryEstimations.builder()
            .perNode("global targets", HugeIntArray::memoryEstimation)
            .rangePerNode("global class counts", nodeCount -> MemoryRange.of(16L, assumedClassCount * 8))
            .add("metrics", ClassificationMetricSpecification.memoryEstimation(assumedClassCount))
            .perNode("node IDs", HugeLongArray::memoryEstimation)
            .add("outer split", FractionSplitter.estimate(trainFraction))
            .add("inner split", StratifiedKFoldSplitter.memoryEstimationForNodeSet(splitConfig.validationFolds(), trainFraction))
            .add("stats map train", TrainingStatistics.memoryEstimationStatsMap(metricCount, nodeClassificationTrainingPipeline.numberOfModelSelectionTrials()))
            .add("stats map validation", TrainingStatistics.memoryEstimationStatsMap(metricCount, nodeClassificationTrainingPipeline.numberOfModelSelectionTrials()))
            .add(
                "max of model selection and best model evaluation",
                MemoryEstimations.maxEstimation(List.of(
                    modelSelection,
                    MemoryEstimations.delegateEstimation(
                        modelTrainAndEvaluateMemoryUsage(
                            nodeClassificationTrainingPipeline,
                            assumedClassCount,
                            assumedFeatureCount,
                            splitConfig::trainSetSize,
                            splitConfig::testSetSize
                        ),
                        "best model evaluation"
                    )
                ))
            );

        boolean hasRandomForestCandidates = !nodeClassificationTrainingPipeline
            .trainingParameterSpace()
            .get(TrainingMethod.RandomForestClassification)
            .isEmpty();
        if (hasRandomForestCandidates) {
            // Range between feature vectors of length 10 and of the assumed maximum length.
            builder.perGraphDimension("cached feature vectors", (dimensions, ignored) -> MemoryRange.of(
                HugeObjectArray.memoryEstimation(dimensions.nodeCount(), MemoryUsage.sizeOfDoubleArray(10L)),
                HugeObjectArray.memoryEstimation(dimensions.nodeCount(), MemoryUsage.sizeOfDoubleArray(assumedFeatureCount))
            ));
        }
        return builder.build();
    }

    /**
     * Creates the progress task tree for a training run.
     *
     * <p>The tree contains, in order: the node property steps, the cross-validation tasks, training of
     * the best model, evaluation on train data, evaluation on test data and the final retraining.
     * Both training tasks are sized with a factor of 5 on their respective example counts.
     *
     * @param nodeClassificationTrainingPipeline the pipeline providing steps, split config and trial count
     * @param j the node count the task volumes are derived from
     * @return the root task named "Node Classification Train Pipeline"
     */
    public static Task progressTask(NodeClassificationTrainingPipeline nodeClassificationTrainingPipeline, long j) {
        NodePropertyPredictionSplitConfig splitConfig = nodeClassificationTrainingPipeline.splitConfig();
        long trainSetSize = splitConfig.trainSetSize(j);
        long testSetSize = splitConfig.testSetSize(j);
        int validationFolds = splitConfig.validationFolds();

        // Typed list instead of the raw ArrayList so that only Task instances can be added.
        List<Task> tasks = new ArrayList<>();
        tasks.add(NodePropertyStepExecutor.tasks(nodeClassificationTrainingPipeline.nodePropertySteps(), j));
        tasks.addAll(CrossValidation.progressTasks(validationFolds, nodeClassificationTrainingPipeline.numberOfModelSelectionTrials(), trainSetSize));
        tasks.add(ClassifierTrainer.progressTask("Train best model", 5 * trainSetSize));
        tasks.add(Tasks.leaf("Evaluate on train data", trainSetSize));
        tasks.add(Tasks.leaf("Evaluate on test data", testSetSize));
        tasks.add(ClassifierTrainer.progressTask("Retrain best model", 5 * j));
        return Tasks.task("Node Classification Train Pipeline", tasks);
    }

    /**
     * Estimates the memory for training and evaluating any candidate of the training parameter space.
     *
     * <p>For every corner-case configuration of every candidate, the larger of the training and the
     * evaluation estimation is taken; the result is the maximum over all of those.
     *
     * @param nodeClassificationTrainingPipeline the pipeline providing the training parameter space
     * @param i class count bound; capped by the node count of the graph dimensions
     * @param i2 feature count passed to the trainer and evaluation estimations
     * @param longUnaryOperator maps a node count to the training set size
     * @param longUnaryOperator2 maps a node count to the test set size
     * @return the "model selection" memory estimation
     */
    @NotNull
    private static MemoryEstimation modelTrainAndEvaluateMemoryUsage(NodeClassificationTrainingPipeline nodeClassificationTrainingPipeline, int i, int i2, LongUnaryOperator longUnaryOperator, LongUnaryOperator longUnaryOperator2) {
        // Collected into a typed list; the decompiled code used a raw (List) cast here.
        List<MemoryEstimation> candidateEstimations = nodeClassificationTrainingPipeline
            .trainingParameterSpace()
            .values()
            .stream()
            .flatMap(candidates -> candidates.stream())
            .flatMap(candidate -> candidate.streamCornerCaseConfigs())
            .map(trainerConfig -> MemoryEstimations.setup("max of training and evaluation", graphDimensions -> {
                int classCount = (int) Math.min(i, graphDimensions.nodeCount());

                MemoryEstimation training = ClassifierTrainerFactory.memoryEstimation(
                    trainerConfig,
                    longUnaryOperator,
                    classCount,
                    MemoryRange.of(i2),
                    false
                );

                // Only logistic regression configs carry a batch size; every other config passes 0.
                int batchSize = trainerConfig instanceof LogisticRegressionTrainConfig
                    ? ((LogisticRegressionTrainConfig) trainerConfig).batchSize()
                    : 0;
                MemoryEstimation evaluation = ClassificationMetricComputer.estimateEvaluation(
                    trainerConfig,
                    (int) Math.min(batchSize, graphDimensions.nodeCount()),
                    longUnaryOperator,
                    longUnaryOperator2,
                    classCount,
                    i2,
                    false
                );

                return MemoryEstimations.maxEstimation(List.of(training, evaluation));
            }))
            .collect(Collectors.toList());

        return MemoryEstimations.builder("model selection").max(candidateEstimations).build();
    }

    /**
     * Creates a trainer for the given graph store, pipeline and configuration.
     *
     * <p>The graph is loaded for the configured node labels and checked against the minimum split set
     * sizes. Labels and class counts are then extracted from the target property, and the class id
     * map is built from the sorted class count keys.
     *
     * @param graphStore the graph store to train on
     * @param nodeClassificationTrainingPipeline the training pipeline
     * @param nodeClassificationPipelineTrainConfig the train configuration
     * @param nodeFeatureProducer producer of the node feature vectors
     * @param progressTracker tracker receiving the progress of the training run
     * @return a ready-to-run trainer
     */
    public static NodeClassificationTrain create(GraphStore graphStore, NodeClassificationTrainingPipeline nodeClassificationTrainingPipeline, NodeClassificationPipelineTrainConfig nodeClassificationPipelineTrainConfig, NodeFeatureProducer<NodeClassificationPipelineTrainConfig> nodeFeatureProducer, ProgressTracker progressTracker) {
        Graph graph = graphStore.getGraph(nodeClassificationPipelineTrainConfig.nodeLabelIdentifiers(graphStore));
        nodeClassificationTrainingPipeline.splitConfig().validateMinNumNodesInSplitSets(graph);

        LabelsAndClassCountsExtractor.LabelsAndClassCounts labelsAndClassCounts = LabelsAndClassCountsExtractor.extractLabelsAndClassCounts(
            graph.nodeProperties(nodeClassificationPipelineTrainConfig.targetProperty()),
            graph.nodeCount()
        );
        LongMultiSet classCounts = labelsAndClassCounts.classCounts();
        LocalIdMap classIdMap = LocalIdMap.ofSorted(classCounts.keys());
        HugeIntArray labels = labelsAndClassCounts.labels();

        // Resolve the metrics once: the classification metric list is then derived from exactly the
        // same Metric instances that the trainer tracks, and the work is not done twice.
        List<Metric> metrics = nodeClassificationPipelineTrainConfig.metrics(classIdMap, classCounts);
        List<ClassificationMetric> classificationMetrics = NodeClassificationPipelineTrainConfig.classificationMetrics(metrics);

        return new NodeClassificationTrain(
            nodeClassificationTrainingPipeline,
            nodeClassificationPipelineTrainConfig,
            labels,
            classIdMap,
            graph,
            metrics,
            classificationMetrics,
            classCounts,
            nodeFeatureProducer,
            progressTracker
        );
    }

    /**
     * Stores the collaborators of a training run; instances are obtained through {@code create}.
     *
     * @param nodeClassificationTrainingPipeline the training pipeline
     * @param nodeClassificationPipelineTrainConfig the train configuration
     * @param hugeIntArray the label array used as training targets
     * @param localIdMap the class id map
     * @param idMap the node id mapping of the training graph
     * @param list the metrics to track
     * @param list2 the classification metrics to score on evaluation sets
     * @param longMultiSet the class counts
     * @param nodeFeatureProducer producer of the node feature vectors
     * @param progressTracker tracker receiving the progress of the training run
     */
    private NodeClassificationTrain(NodeClassificationTrainingPipeline nodeClassificationTrainingPipeline, NodeClassificationPipelineTrainConfig nodeClassificationPipelineTrainConfig, HugeIntArray hugeIntArray, LocalIdMap localIdMap, IdMap idMap, List<Metric> list, List<ClassificationMetric> list2, LongMultiSet longMultiSet, NodeFeatureProducer<NodeClassificationPipelineTrainConfig> nodeFeatureProducer, ProgressTracker progressTracker) {
        // Pipeline and configuration.
        this.pipeline = nodeClassificationTrainingPipeline;
        this.trainConfig = nodeClassificationPipelineTrainConfig;

        // Targets and class bookkeeping.
        this.targets = hugeIntArray;
        this.classIdMap = localIdMap;
        this.classCounts = longMultiSet;

        // Graph id mapping and metrics.
        this.nodeIdMap = idMap;
        this.metrics = list;
        this.classificationMetrics = list2;

        // Feature production and progress reporting.
        this.nodeFeatureProducer = nodeFeatureProducer;
        this.progressTracker = progressTracker;
    }

    /**
     * Replaces the termination flag that is handed to cross-validation, the metric computer and the
     * classifier trainers created by this instance.
     *
     * @param terminationFlag the flag to use from now on
     */
    @Override // org.neo4j.gds.ml.pipeline.PipelineTrainer
    public void setTerminationFlag(TerminationFlag terminationFlag) {
        this.terminationFlag = terminationFlag;
    }

    /**
     * Runs the complete training: splits the nodes, selects the best model candidate, evaluates it
     * on the outer split and finally retrains it on all training examples.
     *
     * <p>Implements {@code PipelineTrainer#run}. The whole run is wrapped in one sub task of the
     * progress tracker.
     *
     * @return the retrained classifier together with the training statistics, class id map and class counts
     */
    @Override
    public NodeClassificationTrainResult run() {
        progressTracker.beginSubTask();

        NodePropertyPredictionSplitConfig splitConfig = pipeline.splitConfig();
        NodeSplitter nodeSplitter = new NodeSplitter(
            trainConfig.concurrency(),
            nodeIdMap.nodeCount(),
            progressTracker,
            nodeIdMap::toOriginalNodeId,
            nodeIdMap::toMappedNodeId
        );
        NodeSplitter.NodeSplits nodeSplits = nodeSplitter.split(
            splitConfig.testFraction(),
            splitConfig.validationFolds(),
            trainConfig.randomSeed()
        );

        TrainingStatistics statistics = new TrainingStatistics(metrics);
        Features features = nodeFeatureProducer.procedureFeatures(pipeline);

        findBestModelCandidate(nodeSplits.outerSplit().trainSet(), features, statistics);
        evaluateBestModel(nodeSplits.outerSplit(), features, statistics);
        Classifier retrainedClassifier = retrainBestModel(nodeSplits.allTrainingExamples(), features, statistics.bestParameters());

        progressTracker.endSubTask();
        return ImmutableNodeClassificationTrainResult.of(retrainedClassifier, statistics, classIdMap, classCounts);
    }

    /**
     * Runs cross-validated model selection over the candidates drawn from the training parameter space.
     *
     * <p>Candidates are trained with {@code trainModel} and scored with {@code registerMetricScores}
     * (using the null progress tracker for scoring). Results are recorded in the given statistics.
     *
     * @param readOnlyHugeLongArray the node ids of the outer training set
     * @param features the feature vectors of all nodes
     * @param trainingStatistics the statistics object that receives the selection results
     */
    private void findBestModelCandidate(ReadOnlyHugeLongArray readOnlyHugeLongArray, Features features, TrainingStatistics trainingStatistics) {
        CrossValidation crossValidation = new CrossValidation(
            this.progressTracker,
            this.terminationFlag,
            this.metrics,
            this.pipeline.splitConfig().validationFolds(),
            this.trainConfig.randomSeed(),
            (trainSet, trainerConfig, metricsHandler, logLevel) -> trainModel(trainSet, trainerConfig, features, logLevel, metricsHandler),
            (evaluationSet, classifier, scoreConsumer) -> registerMetricScores(evaluationSet, classifier, features, scoreConsumer, ProgressTracker.NULL_TRACKER)
        );
        RandomSearch randomSearch = new RandomSearch(
            this.pipeline.trainingParameterSpace(),
            this.pipeline.numberOfModelSelectionTrials(),
            this.trainConfig.randomSeed()
        );

        // Sorted, de-duplicated class values; typed so that nothing but Long can end up in the set.
        TreeSet<Long> distinctClasses = new TreeSet<>();
        for (long classValue : this.classCounts.keys()) {
            distinctClasses.add(classValue);
        }

        crossValidation.selectModel(readOnlyHugeLongArray, this.targets::get, distinctClasses, trainingStatistics, randomSearch);
    }

    /**
     * Scores the classifier on the given evaluation set and reports one score per classification metric.
     *
     * @param readOnlyHugeLongArray the node ids of the evaluation set
     * @param classifier the classifier to evaluate
     * @param features the feature vectors of all nodes
     * @param metricConsumer receives each metric together with its score
     * @param progressTracker the tracker handed to the metric computer
     */
    private void registerMetricScores(ReadOnlyHugeLongArray readOnlyHugeLongArray, Classifier classifier, Features features, MetricConsumer metricConsumer, ProgressTracker progressTracker) {
        ClassificationMetricComputer metricComputer = ClassificationMetricComputer.forEvaluationSet(
            features,
            this.targets,
            readOnlyHugeLongArray,
            classifier,
            this.trainConfig.concurrency(),
            this.terminationFlag,
            progressTracker
        );
        for (ClassificationMetric metric : this.classificationMetrics) {
            metricConsumer.consume(metric, metricComputer.score(metric));
        }
    }

    /**
     * Trains the best candidate on the outer training set and records its scores on the outer
     * train and test sets.
     *
     * <p>The work is reported as three consecutive sub tasks: "Train best model",
     * "Evaluate on train data" and "Evaluate on test data". Model-specific metrics emitted during
     * training are recorded as test scores on the statistics.
     *
     * @param trainingExamplesSplit the outer train/test split
     * @param features the feature vectors of all nodes
     * @param trainingStatistics the statistics holding the best candidate and receiving the scores
     */
    private void evaluateBestModel(TrainingExamplesSplit trainingExamplesSplit, Features features, TrainingStatistics trainingStatistics) {
        this.progressTracker.beginSubTask("Train best model");
        ModelCandidateStats bestCandidate = trainingStatistics.bestCandidate();
        // The decompiled lambda referenced an undefined register here; its receiver is the
        // statistics object, so the handler forwards model-specific scores to addTestScore.
        Classifier bestClassifier = trainModel(
            trainingExamplesSplit.trainSet(),
            bestCandidate.trainerConfig(),
            features,
            LogLevel.INFO,
            ModelSpecificMetricsHandler.of(this.metrics, trainingStatistics::addTestScore)
        );
        this.progressTracker.endSubTask("Train best model");

        this.progressTracker.beginSubTask("Evaluate on train data");
        this.progressTracker.setSteps(trainingExamplesSplit.trainSet().size());
        registerMetricScores(trainingExamplesSplit.trainSet(), bestClassifier, features, trainingStatistics::addOuterTrainScore, this.progressTracker);
        this.progressTracker.logInfo(StringFormatting.formatWithLocale("Final model metrics on full train set: %s", trainingStatistics.winningModelOuterTrainMetrics()));
        this.progressTracker.endSubTask("Evaluate on train data");

        this.progressTracker.beginSubTask("Evaluate on test data");
        this.progressTracker.setSteps(trainingExamplesSplit.testSet().size());
        registerMetricScores(trainingExamplesSplit.testSet(), bestClassifier, features, trainingStatistics::addTestScore, this.progressTracker);
        this.progressTracker.logInfo(StringFormatting.formatWithLocale("Final model metrics on test set: %s", trainingStatistics.winningModelTestMetrics()));
        this.progressTracker.endSubTask("Evaluate on test data");
    }

    /**
     * Trains the final classifier with the winning configuration inside the "Retrain best model" sub task.
     *
     * <p>Model-specific metrics are discarded during this run (the no-op handler is used).
     *
     * @param readOnlyHugeLongArray the node ids to train on
     * @param features the feature vectors of all nodes
     * @param trainerConfig the configuration of the best candidate
     * @return the retrained classifier
     */
    private Classifier retrainBestModel(ReadOnlyHugeLongArray readOnlyHugeLongArray, Features features, TrainerConfig trainerConfig) {
        final String taskName = "Retrain best model";

        this.progressTracker.beginSubTask(taskName);
        Classifier retrained = trainModel(
            readOnlyHugeLongArray,
            trainerConfig,
            features,
            LogLevel.INFO,
            ModelSpecificMetricsHandler.NOOP
        );
        this.progressTracker.endSubTask(taskName);

        return retrained;
    }

    /**
     * Creates a trainer for the given configuration and trains it on the given node ids.
     *
     * <p>The trainer receives the size of the class id map as its class count, together with this
     * instance's termination flag, progress tracker, concurrency and random seed.
     *
     * @param readOnlyHugeLongArray the node ids to train on
     * @param trainerConfig the trainer configuration to instantiate
     * @param features the feature vectors of all nodes
     * @param logLevel the log level passed to the trainer
     * @param modelSpecificMetricsHandler receives model-specific metrics emitted during training
     * @return the trained classifier
     */
    private Classifier trainModel(ReadOnlyHugeLongArray readOnlyHugeLongArray, TrainerConfig trainerConfig, Features features, LogLevel logLevel, ModelSpecificMetricsHandler modelSpecificMetricsHandler) {
        int numberOfClasses = this.classIdMap.size();
        int concurrency = this.trainConfig.concurrency();

        ClassifierTrainer trainer = ClassifierTrainerFactory.create(
            trainerConfig,
            numberOfClasses,
            this.terminationFlag,
            this.progressTracker,
            logLevel,
            concurrency,
            this.trainConfig.randomSeed(),
            false,
            modelSpecificMetricsHandler
        );
        return trainer.train(features, this.targets, readOnlyHugeLongArray);
    }

    // Decompiler reconstruction of the synthetic hook that re-creates a serialized lambda of this class.
    // It recognises a single lambda shape: a bound reference to HugeIntArray#get exposed through the
    // functional method LongToLongFunction#valueOf; anything else is rejected with an
    // IllegalArgumentException.
    // NOTE(review): as decompiled this is not valid Java source - a boolean is assigned -1 and used as a
    // switch selector, and a method reference is returned as Object without a target type.
    // NOTE(review): presumably javac regenerates this hook for the serializable method reference in
    // findBestModelCandidate, in which case this method should be dropped from restored source - confirm.
    private static /* synthetic */ Object $deserializeLambda$(SerializedLambda serializedLambda) {
        String implMethodName = serializedLambda.getImplMethodName();
        // Stands in for an integer case index: -1 means "no match", 0 (shown as false) means "get".
        boolean z = -1;
        switch (implMethodName.hashCode()) {
            case 102230:
                // Hash bucket for the name compared below; equals() guards against hash collisions.
                if (implMethodName.equals("get")) {
                    z = false;
                    break;
                }
                break;
        }
        switch (z) {
            case false:
                // Kind 5 is MethodHandleInfo.REF_invokeVirtual; every remaining check pins the functional
                // interface, its method and signature, and the implementing class and signature.
                if (serializedLambda.getImplMethodKind() == 5 && serializedLambda.getFunctionalInterfaceClass().equals("org/eclipse/collections/api/block/function/primitive/LongToLongFunction") && serializedLambda.getFunctionalInterfaceMethodName().equals("valueOf") && serializedLambda.getFunctionalInterfaceMethodSignature().equals("(J)J") && serializedLambda.getImplClass().equals("org/neo4j/gds/core/utils/paged/HugeIntArray") && serializedLambda.getImplMethodSignature().equals("(J)I")) {
                    // The only captured argument is the receiver of the bound method reference.
                    HugeIntArray hugeIntArray = (HugeIntArray) serializedLambda.getCapturedArg(0);
                    return hugeIntArray::get;
                }
                break;
        }
        throw new IllegalArgumentException("Invalid lambda deserialization");
    }
}
