package org.neo4j.gds.ml.pipeline.nodePipeline.classification.train;

import java.lang.invoke.SerializedLambda;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.TreeSet;
import java.util.function.BiConsumer;
import java.util.function.LongUnaryOperator;
import java.util.stream.Collectors;
import org.jetbrains.annotations.NotNull;
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.core.model.Model;
import org.neo4j.gds.core.utils.TerminationFlag;
import org.neo4j.gds.core.utils.mem.MemoryEstimation;
import org.neo4j.gds.core.utils.mem.MemoryEstimations;
import org.neo4j.gds.core.utils.mem.MemoryRange;
import org.neo4j.gds.core.utils.paged.HugeLongArray;
import org.neo4j.gds.core.utils.paged.HugeObjectArray;
import org.neo4j.gds.core.utils.paged.ReadOnlyHugeLongArray;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.core.utils.progress.tasks.Task;
import org.neo4j.gds.core.utils.progress.tasks.Tasks;
import org.neo4j.gds.mem.MemoryUsage;
import org.neo4j.gds.ml.core.subgraph.LocalIdMap;
import org.neo4j.gds.ml.metrics.Metric;
import org.neo4j.gds.ml.metrics.ModelStatsBuilder;
import org.neo4j.gds.ml.metrics.StatsMap;
import org.neo4j.gds.ml.metrics.classification.ClassificationMetric;
import org.neo4j.gds.ml.metrics.classification.ClassificationMetricSpecification;
import org.neo4j.gds.ml.models.Classifier;
import org.neo4j.gds.ml.models.ClassifierTrainer;
import org.neo4j.gds.ml.models.ClassifierTrainerFactory;
import org.neo4j.gds.ml.models.Features;
import org.neo4j.gds.ml.models.FeaturesFactory;
import org.neo4j.gds.ml.models.TrainerConfig;
import org.neo4j.gds.ml.models.TrainingMethod;
import org.neo4j.gds.ml.models.automl.RandomSearch;
import org.neo4j.gds.ml.models.logisticregression.LogisticRegressionTrainConfig;
import org.neo4j.gds.ml.nodeClassification.ClassificationMetricComputer;
import org.neo4j.gds.ml.nodePropertyPrediction.NodeSplitter;
import org.neo4j.gds.ml.pipeline.TrainingStatistics;
import org.neo4j.gds.ml.pipeline.nodePipeline.NodeClassificationPredictPipeline;
import org.neo4j.gds.ml.pipeline.nodePipeline.NodePropertyPredictionSplitConfig;
import org.neo4j.gds.ml.pipeline.nodePipeline.classification.NodeClassificationTrainingPipeline;
import org.neo4j.gds.ml.pipeline.nodePipeline.classification.train.LabelsAndClassCountsExtractor;
import org.neo4j.gds.ml.splitting.FractionSplitter;
import org.neo4j.gds.ml.splitting.StratifiedKFoldSplitter;
import org.neo4j.gds.ml.splitting.TrainingExamplesSplit;
import org.neo4j.gds.utils.StringFormatting;
import org.openjdk.jol.util.Multiset;

/* loaded from: input_file:org/neo4j/gds/ml/pipeline/nodePipeline/classification/train/NodeClassificationTrain.class */
/**
 * Trains a node classification model for a {@link NodeClassificationTrainingPipeline}.
 *
 * <p>{@link #compute()} runs the following phases, each reported as a sub task on the
 * {@link ProgressTracker} (see {@link #progressTasks(int, int)}):
 * <ol>
 *     <li>Shuffle and split the labelled nodes into an outer train/test split and inner train/validation
 *     splits.</li>
 *     <li>Model selection: every trainer config produced by {@link RandomSearch} is trained on each inner
 *     train set and scored on the matching validation set as well as on the inner train set.</li>
 *     <li>The best config is trained on the outer train set and scored on the outer test and train sets.</li>
 *     <li>The best config is retrained on all training examples and wrapped into a {@link Model}.</li>
 * </ol>
 */
public final class NodeClassificationTrain {

    /**
     * Number of classes assumed by {@link #estimate}, which has no access to the target values and therefore
     * cannot know the real class count.
     */
    private static final int ESTIMATED_MAX_CLASS_COUNT = 1000;

    /** Feature dimension assumed by {@link #estimate}; also the upper bound for cached feature vectors. */
    private static final int ESTIMATED_MAX_FEATURE_COUNT = 500;

    /** Lower bound of the feature dimension assumed for the cached feature vector estimation. */
    private static final int ESTIMATED_MIN_FEATURE_COUNT = 10;

    private final Graph graph;
    private final NodeClassificationPipelineTrainConfig config;
    private final NodeClassificationTrainingPipeline pipeline;
    // Eagerly or lazily extracted node features, see create().
    private final Features features;
    // Target label per node, as extracted from the config's target property.
    private final HugeLongArray targets;
    // Maps original class ids to local ids; built from the sorted class ids.
    private final LocalIdMap classIdMap;
    private final List<ClassificationMetric> metrics;
    // Number of nodes per class id.
    private final Multiset<Long> classCounts;
    private final ProgressTracker progressTracker;
    private final TerminationFlag terminationFlag;

    /**
     * Estimates the memory needed to train the given pipeline.
     *
     * <p>The estimation assumes {@value #ESTIMATED_MAX_CLASS_COUNT} classes and
     * {@value #ESTIMATED_MAX_FEATURE_COUNT} features. For the cached feature vectors, it assumes a feature
     * range of {@value #ESTIMATED_MIN_FEATURE_COUNT}..{@value #ESTIMATED_MAX_FEATURE_COUNT}.
     *
     * @param pipeline the pipeline whose split config and training parameter space drive the estimate
     * @param config   the train config; only the number of its metrics is used
     * @return the memory estimation tree
     */
    public static MemoryEstimation estimate(
        NodeClassificationTrainingPipeline pipeline,
        NodeClassificationPipelineTrainConfig config
    ) {
        NodePropertyPredictionSplitConfig splitConfig = pipeline.splitConfig();
        double trainFraction = 1.0d - splitConfig.testFraction();
        int metricCount = config.metrics().size();
        int trialCount = pipeline.numberOfModelSelectionTrials();

        // compute() runs model selection (inner folds) and best-model evaluation (outer split) one after the
        // other, so only the larger of the two has to be reserved.
        // NOTE(review): retraining on all training examples is not estimated separately; this assumes it never
        // needs more than the best-model evaluation estimate — confirm.
        MemoryEstimation modelSelection = modelTrainAndEvaluateMemoryUsage(
            pipeline,
            ESTIMATED_MAX_CLASS_COUNT,
            ESTIMATED_MAX_FEATURE_COUNT,
            splitConfig::foldTrainSetSize,
            splitConfig::foldTestSetSize
        );
        MemoryEstimation bestModelEvaluation = MemoryEstimations.delegateEstimation(
            modelTrainAndEvaluateMemoryUsage(
                pipeline,
                ESTIMATED_MAX_CLASS_COUNT,
                ESTIMATED_MAX_FEATURE_COUNT,
                splitConfig::trainSetSize,
                splitConfig::testSetSize
            ),
            "best model evaluation"
        );

        MemoryEstimations.Builder builder = MemoryEstimations.builder()
            .perNode("global targets", HugeLongArray::memoryEstimation)
            .rangePerNode(
                "global class counts",
                ignored -> MemoryRange.of(2L * Long.BYTES, (long) ESTIMATED_MAX_CLASS_COUNT * Long.BYTES)
            )
            .add("metrics", ClassificationMetricSpecification.memoryEstimation(ESTIMATED_MAX_CLASS_COUNT))
            .perNode("node IDs", HugeLongArray::memoryEstimation)
            .add("outer split", FractionSplitter.estimate(trainFraction))
            .add(
                "inner split",
                StratifiedKFoldSplitter.memoryEstimationForNodeSet(splitConfig.validationFolds(), trainFraction)
            )
            .add("stats map train", StatsMap.memoryEstimation(metricCount, trialCount))
            .add("stats map validation", StatsMap.memoryEstimation(metricCount, trialCount))
            .add(
                "max of model selection and best model evaluation",
                MemoryEstimations.maxEstimation(List.of(modelSelection, bestModelEvaluation))
            );

        // A RandomForest candidate makes create() extract features eagerly, i.e. one double[] per node is cached.
        if (hasRandomForestCandidates(pipeline)) {
            builder.perGraphDimension("cached feature vectors", (graphDimensions, ignored) -> MemoryRange.of(
                HugeObjectArray.memoryEstimation(
                    graphDimensions.nodeCount(),
                    MemoryUsage.sizeOfDoubleArray(ESTIMATED_MIN_FEATURE_COUNT)
                ),
                HugeObjectArray.memoryEstimation(
                    graphDimensions.nodeCount(),
                    MemoryUsage.sizeOfDoubleArray(ESTIMATED_MAX_FEATURE_COUNT)
                )
            ));
        }

        return builder.build();
    }

    /**
     * Lists the progress tasks reported by {@link #compute()}, in execution order.
     *
     * @param validationFolds              volume of each "Trial" task; {@link #compute()} logs one unit of
     *                                     progress per inner split, so this should equal the number of inner
     *                                     splits
     * @param numberOfModelSelectionTrials number of "Trial" iterations of the model selection task
     * @return the task list
     */
    public static List<Task> progressTasks(int validationFolds, int numberOfModelSelectionTrials) {
        return List.of(
            Tasks.leaf("Shuffle and split"),
            Tasks.iterativeFixed(
                "Select best model",
                () -> List.of(Tasks.leaf("Trial", validationFolds)),
                numberOfModelSelectionTrials
            ),
            ClassifierTrainer.progressTask("Train best model"),
            Tasks.leaf("Evaluate on test data"),
            ClassifierTrainer.progressTask("Retrain best model")
        );
    }

    /**
     * Estimates training plus evaluation of a single model. The result is the maximum over all corner-case
     * configs of the pipeline's training parameter space.
     *
     * @param pipeline     pipeline providing the training parameter space
     * @param classCount   assumed number of classes (capped at the node count)
     * @param featureCount assumed feature dimension
     * @param trainSetSize maps the node count to the size of the set a model is trained on
     * @param testSetSize  maps the node count to the size of the set a model is evaluated on
     * @return an estimation named "model selection"
     */
    @NotNull
    private static MemoryEstimation modelTrainAndEvaluateMemoryUsage(
        NodeClassificationTrainingPipeline pipeline,
        int classCount,
        int featureCount,
        LongUnaryOperator trainSetSize,
        LongUnaryOperator testSetSize
    ) {
        List<MemoryEstimation> perConfigEstimations = pipeline.trainingParameterSpace()
            .values()
            .stream()
            .flatMap(tunableConfigs -> tunableConfigs.stream())
            .flatMap(tunableConfig -> tunableConfig.streamCornerCaseConfigs())
            .map(trainerConfig -> trainAndEvaluateEstimation(
                trainerConfig,
                classCount,
                featureCount,
                trainSetSize,
                testSetSize
            ))
            .collect(Collectors.toList());

        return MemoryEstimations.builder("model selection").max(perConfigEstimations).build();
    }

    /** Estimates one trainer config as the maximum of its training memory and its evaluation memory. */
    private static MemoryEstimation trainAndEvaluateEstimation(
        TrainerConfig trainerConfig,
        int classCount,
        int featureCount,
        LongUnaryOperator trainSetSize,
        LongUnaryOperator testSetSize
    ) {
        return MemoryEstimations.setup("max of training and evaluation", graphDimensions -> {
            // There can never be more classes than nodes.
            int cappedClassCount = (int) Math.min(classCount, graphDimensions.nodeCount());
            int cappedBatchSize = (int) Math.min(evaluationBatchSize(trainerConfig), graphDimensions.nodeCount());

            MemoryEstimation training = ClassifierTrainerFactory.memoryEstimation(
                trainerConfig,
                trainSetSize,
                cappedClassCount,
                MemoryRange.of(featureCount),
                false
            );
            MemoryEstimation evaluation = ClassificationMetricComputer.estimateEvaluation(
                trainerConfig,
                cappedBatchSize,
                trainSetSize,
                testSetSize,
                cappedClassCount,
                featureCount,
                false
            );
            return MemoryEstimations.maxEstimation(List.of(training, evaluation));
        });
    }

    /**
     * Returns the batch size used to estimate evaluation: the configured batch size for logistic regression,
     * and 0 for any other trainer config.
     * NOTE(review): confirm that 0 is the intended batch size for non-logistic-regression configs.
     */
    private static long evaluationBatchSize(TrainerConfig trainerConfig) {
        return trainerConfig instanceof LogisticRegressionTrainConfig
            ? ((LogisticRegressionTrainConfig) trainerConfig).batchSize()
            : 0;
    }

    /** Whether the pipeline's training parameter space contains at least one RandomForest config. */
    private static boolean hasRandomForestCandidates(NodeClassificationTrainingPipeline pipeline) {
        return !pipeline.trainingParameterSpace().get(TrainingMethod.RandomForest).isEmpty();
    }

    /**
     * Creates a trainer for the given graph. The target labels, class counts and features are extracted up
     * front.
     *
     * <p>Features are extracted eagerly when the parameter space contains a RandomForest config, and lazily
     * otherwise.
     *
     * @param graph           graph holding the target and feature properties
     * @param pipeline        the training pipeline
     * @param config          the train config (target property, metrics, seed, concurrency, ...)
     * @param progressTracker tracker receiving the tasks listed by {@link #progressTasks(int, int)}
     * @param terminationFlag flag passed on to training and evaluation
     * @return a trainer ready to {@link #compute()}
     */
    public static NodeClassificationTrain create(
        Graph graph,
        NodeClassificationTrainingPipeline pipeline,
        NodeClassificationPipelineTrainConfig config,
        ProgressTracker progressTracker,
        TerminationFlag terminationFlag
    ) {
        LabelsAndClassCountsExtractor.LabelsAndClassCounts labelsAndClassCounts =
            LabelsAndClassCountsExtractor.extractLabelsAndClassCounts(
                graph.nodeProperties(config.targetProperty()),
                graph.nodeCount()
            );
        Multiset<Long> classCounts = labelsAndClassCounts.classCounts();

        Features features = hasRandomForestCandidates(pipeline)
            ? FeaturesFactory.extractEagerFeatures(graph, pipeline.featureProperties())
            : FeaturesFactory.extractLazyFeatures(graph, pipeline.featureProperties());

        return new NodeClassificationTrain(
            graph,
            pipeline,
            config,
            features,
            labelsAndClassCounts.labels(),
            LocalIdMap.ofSorted(classCounts.keys()),
            config.metrics(classCounts.keys()),
            classCounts,
            progressTracker,
            terminationFlag
        );
    }

    private NodeClassificationTrain(
        Graph graph,
        NodeClassificationTrainingPipeline pipeline,
        NodeClassificationPipelineTrainConfig config,
        Features features,
        HugeLongArray targets,
        LocalIdMap classIdMap,
        List<ClassificationMetric> metrics,
        Multiset<Long> classCounts,
        ProgressTracker progressTracker,
        TerminationFlag terminationFlag
    ) {
        this.progressTracker = progressTracker;
        this.terminationFlag = terminationFlag;
        this.graph = graph;
        this.pipeline = pipeline;
        this.config = config;
        this.features = features;
        this.targets = targets;
        this.classIdMap = classIdMap;
        this.metrics = metrics;
        this.classCounts = classCounts;
    }

    /**
     * Runs the full training: split, model selection, best-model evaluation and retraining.
     *
     * @return the trained model together with the collected training statistics
     */
    public NodeClassificationTrainResult compute() {
        progressTracker.beginSubTask("Shuffle and split");
        NodePropertyPredictionSplitConfig splitConfig = pipeline.splitConfig();
        NodeSplitter.NodeSplits splits = new NodeSplitter(
            features.size(),
            targets::get,
            new TreeSet<>(classCounts.keys()),
            progressTracker
        ).split(splitConfig.testFraction(), splitConfig.validationFolds(), config.randomSeed());
        progressTracker.endSubTask("Shuffle and split");

        TrainingStatistics trainingStatistics = new TrainingStatistics(List.copyOf(metrics));
        selectBestModel(splits.innerSplits(), trainingStatistics);
        evaluateBestModel(splits.outerSplit(), trainingStatistics);
        Classifier retrainedClassifier = retrainBestModel(splits.allTrainingExamples(), trainingStatistics);

        return ImmutableNodeClassificationTrainResult.of(
            createModel(retrainedClassifier, trainingStatistics),
            trainingStatistics
        );
    }

    /**
     * Runs one trial per config drawn by {@link RandomSearch}. Each trial trains on every inner split and
     * records validation and train scores into {@code trainingStatistics}.
     */
    private void selectBestModel(List<TrainingExamplesSplit> innerSplits, TrainingStatistics trainingStatistics) {
        progressTracker.beginSubTask("Select best model");
        RandomSearch randomSearch = new RandomSearch(
            pipeline.trainingParameterSpace(),
            pipeline.numberOfModelSelectionTrials(),
            config.randomSeed()
        );

        int trialIndex = 0;
        while (randomSearch.hasNext()) {
            progressTracker.beginSubTask("Trial");
            TrainerConfig trainerConfig = randomSearch.next();
            progressTracker.logMessage(StringFormatting.formatWithLocale(
                "Method: %s, Parameters: %s",
                trainerConfig.methodName(),
                trainerConfig.toMap()
            ));

            ModelStatsBuilder validationStatsBuilder = new ModelStatsBuilder(trainerConfig, innerSplits.size());
            ModelStatsBuilder trainStatsBuilder = new ModelStatsBuilder(trainerConfig, innerSplits.size());
            for (TrainingExamplesSplit split : innerSplits) {
                HugeLongArray trainSet = split.trainSet();
                HugeLongArray validationSet = split.testSet();
                // Per-fold training is not tracked in detail; one unit of progress is logged per fold below.
                Classifier classifier = trainModel(trainSet, trainerConfig, ProgressTracker.NULL_TRACKER);
                registerMetricScores(
                    validationSet,
                    classifier,
                    validationStatsBuilder::update,
                    ProgressTracker.NULL_TRACKER
                );
                registerMetricScores(trainSet, classifier, trainStatsBuilder::update, ProgressTracker.NULL_TRACKER);
                progressTracker.logProgress();
            }

            for (ClassificationMetric metric : metrics) {
                trainingStatistics.addValidationStats(metric, validationStatsBuilder.build(metric));
                trainingStatistics.addTrainStats(metric, trainStatsBuilder.build(metric));
            }

            Map<Metric, Double> validationAverages = trainingStatistics.findModelValidationAvg(trialIndex);
            Map<Metric, Double> trainAverages = trainingStatistics.findModelTrainAvg(trialIndex);
            progressTracker.logMessage(StringFormatting.formatWithLocale(
                "Main validation metric (%s): %.4f",
                trainingStatistics.evaluationMetric(),
                trainingStatistics.getMainValidationMetric(trialIndex)
            ));
            progressTracker.logMessage(StringFormatting.formatWithLocale("Validation metrics: %s", validationAverages));
            progressTracker.logMessage(StringFormatting.formatWithLocale("Training metrics: %s", trainAverages));

            trialIndex++;
            progressTracker.endSubTask("Trial");
        }

        progressTracker.logMessage(StringFormatting.formatWithLocale(
            "Best trial was Trial %d with main validation metric %.4f",
            trainingStatistics.getBestTrialIdx() + 1,
            trainingStatistics.getBestTrialScore()
        ));
        progressTracker.endSubTask("Select best model");
    }

    /**
     * Scores {@code classifier} on {@code evaluationSet} for every configured metric and hands each
     * (metric, score) pair to {@code scoreConsumer}.
     */
    private void registerMetricScores(
        HugeLongArray evaluationSet,
        Classifier classifier,
        BiConsumer<Metric, Double> scoreConsumer,
        ProgressTracker evaluationProgressTracker
    ) {
        ClassificationMetricComputer metricComputer = ClassificationMetricComputer.forEvaluationSet(
            features,
            targets,
            classCounts,
            evaluationSet,
            classifier,
            config.concurrency(),
            terminationFlag,
            evaluationProgressTracker
        );
        for (ClassificationMetric metric : metrics) {
            scoreConsumer.accept(metric, metricComputer.score(metric));
        }
    }

    /**
     * Trains the best config on the outer train set. The resulting model is then scored on the outer test set
     * and the outer train set, and the scores are recorded as test and outer-train scores.
     */
    private void evaluateBestModel(TrainingExamplesSplit outerSplit, TrainingStatistics trainingStatistics) {
        progressTracker.beginSubTask("Train best model");
        Classifier bestClassifier = trainModel(
            outerSplit.trainSet(),
            trainingStatistics.bestParameters(),
            progressTracker
        );
        progressTracker.endSubTask("Train best model");

        HugeLongArray testSet = outerSplit.testSet();
        HugeLongArray trainSet = outerSplit.trainSet();
        progressTracker.beginSubTask("Evaluate on test data", testSet.size() + trainSet.size());
        registerMetricScores(testSet, bestClassifier, trainingStatistics::addTestScore, progressTracker);
        registerMetricScores(trainSet, bestClassifier, trainingStatistics::addOuterTrainScore, progressTracker);
        progressTracker.endSubTask("Evaluate on test data");

        progressTracker.logMessage(StringFormatting.formatWithLocale(
            "Final model metrics on test set: %s",
            trainingStatistics.winningModelTestMetrics()
        ));
    }

    /** Retrains the best config on {@code trainSet} (all training examples) and returns the classifier. */
    private Classifier retrainBestModel(HugeLongArray trainSet, TrainingStatistics trainingStatistics) {
        progressTracker.beginSubTask("Retrain best model");
        Classifier classifier = trainModel(trainSet, trainingStatistics.bestParameters(), progressTracker);
        progressTracker.endSubTask("Retrain best model");

        // The logged metrics are the outer-train scores recorded in evaluateBestModel; the retrained model
        // itself is not scored again here.
        progressTracker.logMessage(StringFormatting.formatWithLocale(
            "Final model metrics on full train set: %s",
            trainingStatistics.winningModelOuterTrainMetrics()
        ));
        return classifier;
    }

    /** Wraps the trained classifier into a catalog {@link Model}, including the predict pipeline. */
    private Model<Classifier.ClassifierData, NodeClassificationPipelineTrainConfig, NodeClassificationPipelineModelInfo> createModel(
        Classifier classifier,
        TrainingStatistics trainingStatistics
    ) {
        NodeClassificationPipelineModelInfo modelInfo = NodeClassificationPipelineModelInfo.builder()
            .classes(classIdMap.originalIdsList())
            .bestParameters(trainingStatistics.bestParameters())
            .metrics(trainingStatistics.metricsForWinningModel())
            .pipeline(NodeClassificationPredictPipeline.from(pipeline))
            .build();

        return Model.of(
            config.username(),
            config.modelName(),
            NodeClassificationTrainingPipeline.MODEL_TYPE,
            graph.schema(),
            classifier.data(),
            config,
            modelInfo
        );
    }

    /** Trains a classifier for {@code trainerConfig} on the nodes listed in {@code trainSet}. */
    private Classifier trainModel(
        HugeLongArray trainSet,
        TrainerConfig trainerConfig,
        ProgressTracker trainingProgressTracker
    ) {
        ClassifierTrainer trainer = ClassifierTrainerFactory.create(
            trainerConfig,
            classIdMap,
            terminationFlag,
            trainingProgressTracker,
            config.concurrency(),
            config.randomSeed(),
            false
        );
        return trainer.train(features, targets, ReadOnlyHugeLongArray.of(trainSet));
    }
}
