package org.neo4j.gds.ml.pipeline.nodePipeline.train;

import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.TreeSet;
import java.util.function.Function;
import java.util.function.LongUnaryOperator;
import java.util.stream.Collectors;
import org.eclipse.collections.api.tuple.Pair;
import org.eclipse.collections.impl.tuple.Tuples;
import org.immutables.value.Value;
import org.jetbrains.annotations.NotNull;
import org.neo4j.gds.annotation.ValueClass;
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.api.NodeProperties;
import org.neo4j.gds.core.model.Model;
import org.neo4j.gds.core.utils.TerminationFlag;
import org.neo4j.gds.core.utils.mem.MemoryEstimation;
import org.neo4j.gds.core.utils.mem.MemoryEstimations;
import org.neo4j.gds.core.utils.mem.MemoryRange;
import org.neo4j.gds.core.utils.paged.HugeLongArray;
import org.neo4j.gds.core.utils.paged.HugeObjectArray;
import org.neo4j.gds.core.utils.paged.ReadOnlyHugeLongArray;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.core.utils.progress.tasks.Task;
import org.neo4j.gds.core.utils.progress.tasks.Tasks;
import org.neo4j.gds.gradientdescent.GradientDescentConfig;
import org.neo4j.gds.mem.MemoryUsage;
import org.neo4j.gds.ml.core.subgraph.LocalIdMap;
import org.neo4j.gds.ml.nodemodels.BestMetricData;
import org.neo4j.gds.ml.nodemodels.BestModelStats;
import org.neo4j.gds.ml.nodemodels.ClassificationMetricComputer;
import org.neo4j.gds.ml.nodemodels.ImmutableModelStats;
import org.neo4j.gds.ml.nodemodels.Metric;
import org.neo4j.gds.ml.nodemodels.MetricComputer;
import org.neo4j.gds.ml.nodemodels.ModelStats;
import org.neo4j.gds.ml.nodemodels.StatsMap;
import org.neo4j.gds.ml.nodemodels.metrics.MetricSpecification;
import org.neo4j.gds.ml.pipeline.nodePipeline.NodeClassificationPipeline;
import org.neo4j.gds.ml.pipeline.nodePipeline.NodeClassificationSplitConfig;
import org.neo4j.gds.ml.splitting.FractionSplitter;
import org.neo4j.gds.ml.splitting.StratifiedKFoldSplitter;
import org.neo4j.gds.ml.splitting.TrainingExamplesSplit;
import org.neo4j.gds.ml.util.ShuffleUtil;
import org.neo4j.gds.ml.util.TrainingSetWarnings;
import org.neo4j.gds.models.Classifier;
import org.neo4j.gds.models.Features;
import org.neo4j.gds.models.FeaturesFactory;
import org.neo4j.gds.models.Trainer;
import org.neo4j.gds.models.TrainerConfig;
import org.neo4j.gds.models.TrainerFactory;
import org.neo4j.gds.models.TrainingMethod;
import org.neo4j.gds.models.logisticregression.LogisticRegressionClassifier;
import org.neo4j.gds.models.logisticregression.LogisticRegressionTrainer;
import org.openjdk.jol.util.Multiset;

/* loaded from: input_file:org/neo4j/gds/ml/pipeline/nodePipeline/train/NodeClassificationTrain.class */
public final class NodeClassificationTrain {
    // Graph being trained on; this field is only read for its schema when the final model is created.
    private final Graph graph;
    // Training configuration; supplies random seed, concurrency, model name and user name.
    private final NodeClassificationPipelineTrainConfig config;
    // Pipeline providing the split configuration and the training parameter space.
    private final NodeClassificationPipeline pipeline;
    // Node features handed to every trainer and to the metric computer.
    private final Features features;
    // Target value per node, read from the configured target property in create().
    private final HugeLongArray targets;
    // Mapping of the distinct target values to local class ids.
    private final LocalIdMap classIdMap;
    // Initialised in create() as 0..nodeCount-1 and shuffled in place by compute().
    private final HugeLongArray nodeIds;
    // Metrics to evaluate; the first entry decides which candidate is picked as best model.
    private final List<Metric> metrics;
    // Per-metric statistics of every candidate, measured on the training part of the splits.
    private final StatsMap trainStats;
    // Per-metric statistics of every candidate, measured on the held-out part of the splits.
    private final StatsMap validationStats;
    // Computes all configured metrics for a classifier on a given node set.
    private final MetricComputer metricComputer;
    // Receives begin/end notifications for the sub tasks of the training run.
    private final ProgressTracker progressTracker;
    // Fixed to RUNNING_TRUE; passed on to trainers and the metric computer.
    private final TerminationFlag terminationFlag = TerminationFlag.RUNNING_TRUE;

    /**
     * Result of the model selection phase: the winning trainer configuration together with
     * the per-metric statistics of all candidates on the training and validation sets.
     */
    @ValueClass
    public interface ModelSelectResult {
        /** Configuration of the candidate that was selected as best model. */
        TrainerConfig bestParameters();

        /** Statistics of every candidate on the training sets, keyed by metric. */
        Map<Metric, List<ModelStats>> trainStats();

        /** Statistics of every candidate on the validation sets, keyed by metric. */
        Map<Metric, List<ModelStats>> validationStats();

        /**
         * Creates a result from the selected configuration and the collected statistics.
         *
         * @param trainerConfig the selected configuration
         * @param statsMap      statistics measured on the training sets
         * @param statsMap2     statistics measured on the validation sets
         */
        static ModelSelectResult of(TrainerConfig trainerConfig, StatsMap statsMap, StatsMap statsMap2) {
            return ImmutableModelSelectResult.of(trainerConfig, statsMap.getMap(), statsMap2.getMap());
        }

        /**
         * Converts this result into a map with the keys {@code bestParameters},
         * {@code trainStats} and {@code validationStats}; the stats are keyed by metric name.
         */
        @Value.Derived
        default Map<String, Object> toMap() {
            // NOTE(review): the per-metric values are unconsumed Streams, not Lists, so each
            // can be traversed only once. Kept as-is to preserve behaviour; confirm that all
            // consumers of this map expect a Stream before changing it.
            Function<Map<Metric, List<ModelStats>>, Map<String, Object>> statsConverter = stats ->
                stats.entrySet().stream().collect(Collectors.toMap(
                    entry -> entry.getKey().name(),
                    entry -> entry.getValue().stream().map(ModelStats::toMap)
                ));
            return Map.of(
                "bestParameters", bestParameters().toMap(),
                "trainStats", statsConverter.apply(trainStats()),
                "validationStats", statsConverter.apply(validationStats())
            );
        }
    }

    /**
     * Accumulates the minimum, maximum and sum of each metric over repeated evaluations of
     * one model candidate and turns them into a {@link ModelStats} per metric.
     */
    private static class ModelStatsBuilder {
        private final Map<Metric, Double> min = new HashMap<>();
        private final Map<Metric, Double> max = new HashMap<>();
        private final Map<Metric, Double> sum = new HashMap<>();
        private final TrainerConfig modelParams;
        private final int numberOfSplits;

        /**
         * @param modelParams    configuration of the candidate the statistics belong to
         * @param numberOfSplits divisor used for the average in {@link #build(Metric)}
         */
        ModelStatsBuilder(TrainerConfig modelParams, int numberOfSplits) {
            this.modelParams = modelParams;
            this.numberOfSplits = numberOfSplits;
        }

        /** Folds one observed value of {@code metric} into the running min, max and sum. */
        void update(Metric metric, double d) {
            this.min.merge(metric, d, Math::min);
            this.max.merge(metric, d, Math::max);
            this.sum.merge(metric, d, Double::sum);
        }

        /**
         * Builds the statistics for {@code metric}: sum divided by the number of splits,
         * minimum and maximum. Throws a NullPointerException if the metric was never updated.
         */
        ModelStats build(Metric metric) {
            double average = this.sum.get(metric) / this.numberOfSplits;
            double minimum = this.min.get(metric);
            double maximum = this.max.get(metric);
            return ImmutableModelStats.of(this.modelParams, average, minimum, maximum);
        }
    }

    /**
     * Assembles the memory estimation for a training run of the given pipeline.
     *
     * <p>The estimation always contains the targets, class counts, metrics, node ids, the
     * outer and inner splits and both stats maps. The model selection / best model evaluation
     * component is only estimated when logistic regression candidates are configured;
     * otherwise a zero-sized placeholder is added. Cached feature vectors are added only when
     * random forest candidates are configured.
     *
     * @param nodeClassificationPipeline            pipeline providing split config and parameter space
     * @param nodeClassificationPipelineTrainConfig train config providing the configured metrics
     * @return the combined memory estimation
     */
    public static MemoryEstimation estimate(NodeClassificationPipeline nodeClassificationPipeline, NodeClassificationPipelineTrainConfig nodeClassificationPipelineTrainConfig) {
        // Fixed placeholder sizes: the real class and feature counts are not available here.
        final int assumedClassCount = 1000;
        final int assumedFeatureCount = 500;

        NodeClassificationSplitConfig splitConfig = nodeClassificationPipeline.splitConfig();
        final double testFraction = splitConfig.testFraction();
        final int validationFolds = splitConfig.validationFolds();
        final double trainFraction = 1.0d - testFraction;

        MemoryEstimations.Builder builder = MemoryEstimations.builder()
            .perNode("global targets", HugeLongArray::memoryEstimation)
            .rangePerNode("global class counts", ignored -> MemoryRange.of(16L, assumedClassCount * 8))
            .add("metrics", MetricSpecification.memoryEstimation(assumedClassCount))
            .perNode("node IDs", HugeLongArray::memoryEstimation)
            .add("outer split", FractionSplitter.estimate(trainFraction))
            .add("inner split", StratifiedKFoldSplitter.memoryEstimationForNodeSet(validationFolds, trainFraction))
            .add("stats map train", StatsMap.memoryEstimation(nodeClassificationPipelineTrainConfig.metrics().size(), nodeClassificationPipeline.numberOfModelCandidates()))
            .add("stats map validation", StatsMap.memoryEstimation(nodeClassificationPipelineTrainConfig.metrics().size(), nodeClassificationPipeline.numberOfModelCandidates()));

        boolean hasLogisticRegression = !nodeClassificationPipeline.trainingParameterSpace().get(TrainingMethod.LogisticRegression).isEmpty();
        if (hasLogisticRegression) {
            // The largest configured batch size bounds the trainer's memory need.
            int maxBatchSize = nodeClassificationPipeline.trainingParameterSpace()
                .get(TrainingMethod.LogisticRegression)
                .stream()
                .mapToInt(candidate -> ((GradientDescentConfig) candidate).batchSize())
                .max()
                .orElseThrow();
            // NOTE(review): the fold size is derived from testFraction, not from the train
            // fraction - confirm this is intended.
            MemoryEstimation modelSelection = modelTrainAndEvaluateMemoryUsage(
                maxBatchSize,
                assumedClassCount,
                assumedFeatureCount,
                nodeCount -> (long) (nodeCount * testFraction * (validationFolds - 1) / validationFolds)
            );
            MemoryEstimation bestModelEvaluation = MemoryEstimations.delegateEstimation(
                modelTrainAndEvaluateMemoryUsage(
                    maxBatchSize,
                    assumedClassCount,
                    assumedFeatureCount,
                    nodeCount -> (long) (nodeCount * testFraction)
                ),
                "best model evaluation"
            );
            builder.add(
                "max of model selection and best model evaluation",
                MemoryEstimations.maxEstimation(List.of(modelSelection, bestModelEvaluation))
            );
        } else {
            builder.add(MemoryEstimations.of("max of model selection and best model evaluation (unknown)", MemoryRange.of(0L)));
        }

        boolean hasRandomForest = !nodeClassificationPipeline.trainingParameterSpace().get(TrainingMethod.RandomForest).isEmpty();
        if (hasRandomForest) {
            // Range between 10 doubles and the assumed feature count of doubles per node.
            builder.perGraphDimension("cached feature vectors", (dimensions, ignored) -> MemoryRange.of(
                HugeObjectArray.memoryEstimation(dimensions.nodeCount(), MemoryUsage.sizeOfDoubleArray(10L)),
                HugeObjectArray.memoryEstimation(dimensions.nodeCount(), MemoryUsage.sizeOfDoubleArray(assumedFeatureCount))
            ));
        }
        return builder.build();
    }

    /**
     * Name of the root progress task of node classification training.
     *
     * @return the constant task name {@code "NCTrain"}
     */
    public static String taskName() {
        return "NCTrain";
    }

    /**
     * Builds the progress task tree of a training run: shuffle and split, model selection,
     * training on the remainder, evaluation of the selected model and the final retraining.
     *
     * @param i  iteration count of each "Model Candidate" task (presumably the number of validation splits)
     * @param i2 iteration count of the "SelectBestModel" task (presumably the number of model candidates)
     * @return the root task named by {@link #taskName()}
     */
    public static Task progressTask(int i, int i2) {
        Task shuffleAndSplit = Tasks.leaf("ShuffleAndSplit");
        // Nested tasks are created inside the suppliers so every invocation yields fresh instances.
        Task selectBestModel = Tasks.iterativeFixed(
            "SelectBestModel",
            () -> List.of(Tasks.iterativeFixed(
                "Model Candidate",
                () -> List.of(Tasks.task("Split", Trainer.progressTask("Training"), new Task[]{Tasks.leaf("Evaluate")})),
                i
            )),
            i2
        );
        Task trainOnRemainder = Trainer.progressTask("TrainSelectedOnRemainder");
        Task evaluateSelected = Tasks.leaf("EvaluateSelectedModel");
        Task retrainSelected = Trainer.progressTask("RetrainSelectedModel");
        return Tasks.task(
            taskName(),
            shuffleAndSplit,
            new Task[]{selectBestModel, trainOnRemainder, evaluateSelected, retrainSelected}
        );
    }

    /**
     * Estimates the memory for training and evaluating one model as the maximum of the
     * logistic regression trainer's need and the need of the metric computation.
     *
     * @param batchSize   batch size handed to the trainer estimation
     * @param classCount  number of classes assumed for the estimation
     * @param featureCount number of features assumed for the estimation
     * @param subsetSize  maps the node count to the size of the evaluated node subset
     * @return the "model selection" estimation
     */
    @NotNull
    private static MemoryEstimation modelTrainAndEvaluateMemoryUsage(int batchSize, int classCount, int featureCount, LongUnaryOperator subsetSize) {
        MemoryEstimation training = LogisticRegressionTrainer.memoryEstimation(classCount, featureCount, batchSize);

        // Size of one long array covering the evaluated subset; used for targets and predictions.
        LongUnaryOperator subsetArrayBytes = subsetSize.andThen(HugeLongArray::memoryEstimation);
        MemoryEstimation metricComputation = MemoryEstimations.builder("computing metrics")
            .perNode("local targets", subsetArrayBytes)
            .perNode("predicted classes", subsetArrayBytes)
            .fixed("probabilities", MemoryUsage.sizeOfDoubleArray(classCount))
            .fixed("computation graph", LogisticRegressionClassifier.sizeOfPredictionsVariableInBytes(100, featureCount, classCount))
            .build();

        return MemoryEstimations.builder("model selection")
            .max(List.of(training, metricComputation))
            .build();
    }

    /**
     * Prepares everything a training run needs: targets and class counts, the class id
     * mapping, the metrics, the node id array, the features and empty stats maps.
     *
     * <p>Features are extracted eagerly when the parameter space contains random forest
     * candidates and lazily otherwise.
     *
     * @param graph                                 graph to train on
     * @param nodeClassificationPipeline            pipeline with feature properties and parameter space
     * @param nodeClassificationPipelineTrainConfig train configuration (target property, metrics)
     * @param progressTracker                       tracker receiving progress notifications
     * @return a ready-to-run trainer
     */
    public static NodeClassificationTrain create(Graph graph, NodeClassificationPipeline nodeClassificationPipeline, NodeClassificationPipelineTrainConfig nodeClassificationPipelineTrainConfig, ProgressTracker progressTracker) {
        NodeProperties targetProperties = graph.nodeProperties(nodeClassificationPipelineTrainConfig.targetProperty());
        Pair<HugeLongArray, Multiset<Long>> targetsAndClassCounts = computeGlobalTargetsAndClasses(targetProperties, graph.nodeCount());
        HugeLongArray targets = targetsAndClassCounts.getOne();
        LocalIdMap classIdMap = makeClassIdMap(targets);
        Multiset<Long> classCounts = targetsAndClassCounts.getTwo();
        List<Metric> metrics = createMetrics(nodeClassificationPipelineTrainConfig, classCounts);

        // Identity array 0..nodeCount-1; it is shuffled later when the run starts.
        HugeLongArray nodeIds = HugeLongArray.newArray(graph.nodeCount());
        nodeIds.setAll(nodeId -> nodeId);

        Features features;
        if (nodeClassificationPipeline.trainingParameterSpace().get(TrainingMethod.RandomForest).isEmpty()) {
            features = FeaturesFactory.extractLazyFeatures(graph, nodeClassificationPipeline.featureProperties());
        } else {
            features = FeaturesFactory.extractEagerFeatures(graph, nodeClassificationPipeline.featureProperties());
        }

        StatsMap trainStats = StatsMap.create(metrics);
        StatsMap validationStats = StatsMap.create(metrics);
        return new NodeClassificationTrain(
            graph,
            nodeClassificationPipeline,
            nodeClassificationPipelineTrainConfig,
            features,
            targets,
            classIdMap,
            classCounts,
            metrics,
            nodeIds,
            trainStats,
            validationStats,
            progressTracker
        );
    }

    /**
     * Creates the class id mapping for the given targets.
     *
     * <p>The distinct target values are collected in ascending order and registered in that
     * order via {@code LocalIdMap.toMapped} (presumably assigning consecutive local ids on
     * first sight, so the mapping does not depend on node order).
     *
     * @param hugeLongArray target value per node
     * @return mapping containing every distinct target value
     */
    public static LocalIdMap makeClassIdMap(HugeLongArray hugeLongArray) {
        TreeSet<Long> sortedClasses = new TreeSet<>();
        for (long index = 0; index < hugeLongArray.size(); index++) {
            sortedClasses.add(hugeLongArray.get(index));
        }
        LocalIdMap classIdMap = new LocalIdMap();
        sortedClasses.forEach(classIdMap::toMapped);
        return classIdMap;
    }

    /**
     * Reads the target value of every node and counts how often each value occurs.
     *
     * @param nodeProperties property holding the target value per node id
     * @param nodeCount      number of nodes; ids {@code 0..nodeCount-1} are read
     * @return pair of the per-node targets and the multiset of target values
     */
    private static Pair<HugeLongArray, Multiset<Long>> computeGlobalTargetsAndClasses(NodeProperties nodeProperties, long nodeCount) {
        Multiset<Long> classCounts = new Multiset<>();
        HugeLongArray targets = HugeLongArray.newArray(nodeCount);
        for (long nodeId = 0; nodeId < nodeCount; nodeId++) {
            // Read the property once and reuse it for both the array and the counts.
            long target = nodeProperties.longValue(nodeId);
            targets.set(nodeId, target);
            classCounts.add(target);
        }
        return Tuples.pair(targets, classCounts);
    }

    /**
     * Expands the configured metric specifications into concrete metrics, passing the
     * distinct class values to each specification.
     *
     * @param config      train configuration holding the metric specifications
     * @param classCounts multiset whose keys are the class values
     * @return all metrics in specification order
     */
    private static List<Metric> createMetrics(NodeClassificationPipelineTrainConfig config, Multiset<Long> classCounts) {
        return config.metrics()
            .stream()
            .flatMap(specification -> specification.createMetrics(classCounts.keys()))
            .collect(Collectors.toList());
    }

    /**
     * Stores the prepared training inputs and creates the metric computer from the metrics,
     * class counts, features, targets and the configured concurrency.
     */
    private NodeClassificationTrain(Graph graph, NodeClassificationPipeline pipeline, NodeClassificationPipelineTrainConfig config, Features features, HugeLongArray targets, LocalIdMap classIdMap, Multiset<Long> classCounts, List<Metric> metrics, HugeLongArray nodeIds, StatsMap trainStats, StatsMap validationStats, ProgressTracker progressTracker) {
        this.graph = graph;
        this.pipeline = pipeline;
        this.config = config;
        this.features = features;
        this.targets = targets;
        this.classIdMap = classIdMap;
        this.nodeIds = nodeIds;
        this.metrics = metrics;
        this.trainStats = trainStats;
        this.validationStats = validationStats;
        this.progressTracker = progressTracker;
        this.metricComputer = new ClassificationMetricComputer(
            metrics,
            classCounts,
            features,
            targets,
            config.concurrency(),
            progressTracker,
            this.terminationFlag
        );
    }

    /**
     * Runs the complete training: shuffles the node ids, splits them into an outer train/test
     * split and stratified folds, selects the best candidate on the folds, evaluates it on
     * the outer split, retrains it on all nodes and wraps the result into a model.
     *
     * @return the trained model together with the model selection statistics
     */
    public NodeClassificationTrainResult compute() {
        // Two tasks are opened here; presumably the root task and its "ShuffleAndSplit" leaf.
        this.progressTracker.beginSubTask();
        this.progressTracker.beginSubTask();

        ShuffleUtil.shuffleHugeLongArray(this.nodeIds, ShuffleUtil.createRandomDataGenerator(this.config.randomSeed()));
        NodeClassificationSplitConfig splitConfig = this.pipeline.splitConfig();
        double trainFraction = 1.0d - splitConfig.testFraction();
        TrainingExamplesSplit outerSplit = new FractionSplitter().split(this.nodeIds, trainFraction);

        // The folds are drawn from the outer train set only, stratified by target value.
        StratifiedKFoldSplitter foldSplitter = new StratifiedKFoldSplitter(
            splitConfig.validationFolds(),
            ReadOnlyHugeLongArray.of(outerSplit.trainSet()),
            ReadOnlyHugeLongArray.of(this.targets),
            this.config.randomSeed()
        );
        List<TrainingExamplesSplit> folds = foldSplitter.splits();
        TrainingSetWarnings.warnForSmallNodeSets(
            outerSplit.trainSet().size(),
            outerSplit.testSet().size(),
            splitConfig.validationFolds(),
            this.progressTracker
        );
        this.progressTracker.endSubTask();

        ModelSelectResult selection = selectBestModel(folds);
        TrainerConfig bestParameters = selection.bestParameters();
        Map<Metric, BestMetricData> bestModelMetrics = evaluateBestModel(outerSplit, selection, bestParameters);
        Classifier finalClassifier = retrainBestModel(bestParameters);
        this.progressTracker.endSubTask();

        return ImmutableNodeClassificationTrainResult.of(
            createModel(finalClassifier, selection, bestModelMetrics),
            selection
        );
    }

    /**
     * Trains and evaluates every configured model candidate on every split and records the
     * per-metric statistics on both the train and the held-out part of the splits.
     *
     * <p>The winning configuration is the one picked by the validation stats for the first
     * configured metric.
     *
     * @param list the splits every candidate is trained and evaluated on
     * @return the selected configuration together with all collected statistics
     */
    private ModelSelectResult selectBestModel(List<TrainingExamplesSplit> list) {
        this.progressTracker.beginSubTask();
        this.pipeline.trainingParameterSpace().values().stream()
            .flatMap(candidates -> candidates.stream())
            .forEach(trainerConfig -> {
                this.progressTracker.beginSubTask();
                ModelStatsBuilder validationStatsBuilder = new ModelStatsBuilder(trainerConfig, list.size());
                ModelStatsBuilder trainStatsBuilder = new ModelStatsBuilder(trainerConfig, list.size());
                for (TrainingExamplesSplit split : list) {
                    this.progressTracker.beginSubTask();
                    HugeLongArray trainSet = split.trainSet();
                    HugeLongArray validationSet = split.testSet();

                    this.progressTracker.beginSubTask("Training");
                    Classifier classifier = trainModel(trainSet, trainerConfig);
                    this.progressTracker.endSubTask("Training");

                    // The evaluation task covers both the held-out and the train part of the split.
                    this.progressTracker.beginSubTask(validationSet.size() + trainSet.size());
                    Map<Metric, Double> validationMetrics = this.metricComputer.computeMetrics(validationSet, classifier);
                    validationMetrics.forEach(validationStatsBuilder::update);
                    Map<Metric, Double> trainMetrics = this.metricComputer.computeMetrics(trainSet, classifier);
                    trainMetrics.forEach(trainStatsBuilder::update);
                    this.progressTracker.endSubTask();

                    this.progressTracker.endSubTask();
                }
                this.progressTracker.endSubTask();

                this.metrics.forEach(metric -> {
                    this.validationStats.add(metric, validationStatsBuilder.build(metric));
                    this.trainStats.add(metric, trainStatsBuilder.build(metric));
                });
            });
        this.progressTracker.endSubTask();

        // The first configured metric decides which candidate wins.
        TrainerConfig bestParameters = this.validationStats.pickBestModelStats(this.metrics.get(0)).params();
        return ModelSelectResult.of(bestParameters, this.trainStats, this.validationStats);
    }

    /**
     * Trains the selected configuration on the outer train set and measures all metrics on
     * the outer test set and on the outer train set.
     *
     * @param trainingExamplesSplit outer train/test split
     * @param modelSelectResult     statistics collected during model selection
     * @param trainerConfig         configuration of the selected model
     * @return per metric, the selection statistics merged with the outer train and test scores
     */
    private Map<Metric, BestMetricData> evaluateBestModel(TrainingExamplesSplit trainingExamplesSplit, ModelSelectResult modelSelectResult, TrainerConfig trainerConfig) {
        HugeLongArray outerTrainSet = trainingExamplesSplit.trainSet();

        this.progressTracker.beginSubTask("TrainSelectedOnRemainder");
        Classifier classifier = trainModel(outerTrainSet, trainerConfig);
        this.progressTracker.endSubTask("TrainSelectedOnRemainder");

        // The evaluation task covers both the test and the train set.
        long evaluatedNodes = trainingExamplesSplit.testSet().size() + trainingExamplesSplit.trainSet().size();
        this.progressTracker.beginSubTask(evaluatedNodes);
        Map<Metric, Double> testMetrics = this.metricComputer.computeMetrics(trainingExamplesSplit.testSet(), classifier);
        Map<Metric, Double> outerTrainMetrics = this.metricComputer.computeMetrics(trainingExamplesSplit.trainSet(), classifier);
        this.progressTracker.endSubTask();

        return mergeMetricResults(modelSelectResult, outerTrainMetrics, testMetrics);
    }

    /**
     * Trains the selected configuration once more on all node ids, reported under the
     * "RetrainSelectedModel" task.
     *
     * @param trainerConfig configuration of the selected model
     * @return the classifier trained on the full node set
     */
    private Classifier retrainBestModel(TrainerConfig trainerConfig) {
        final String retrainTask = "RetrainSelectedModel";
        this.progressTracker.beginSubTask(retrainTask);
        Classifier finalClassifier = trainModel(this.nodeIds, trainerConfig);
        this.progressTracker.endSubTask(retrainTask);
        return finalClassifier;
    }

    /**
     * Wraps the trained classifier into a model carrying the graph schema, the train config
     * and the model info (classes, best parameters, metrics and a copy of the pipeline).
     *
     * @param classifier        the final classifier
     * @param modelSelectResult result of the model selection, source of the best parameters
     * @param map               metric data of the selected model
     * @return the model owned by the configured user under the configured model name
     */
    private Model<Classifier.ClassifierData, NodeClassificationPipelineTrainConfig, NodeClassificationPipelineModelInfo> createModel(Classifier classifier, ModelSelectResult modelSelectResult, Map<Metric, BestMetricData> map) {
        NodeClassificationPipelineModelInfo modelInfo = NodeClassificationPipelineModelInfo.builder()
            .classes(this.classIdMap.originalIdsList())
            .bestParameters(modelSelectResult.bestParameters())
            .metrics(map)
            .trainingPipeline(this.pipeline.copy())
            .build();
        return Model.of(
            this.config.username(),
            this.config.modelName(),
            NodeClassificationPipeline.MODEL_TYPE,
            this.graph.schema(),
            classifier.data(),
            this.config,
            modelInfo
        );
    }

    /**
     * Combines, for every metric of the validation stats, the selection statistics of the
     * best parameters with the scores measured on the outer train and test sets.
     *
     * @param modelSelectResult result of the model selection
     * @param outerTrainMetrics scores of the selected model on the outer train set
     * @param testMetrics       scores of the selected model on the test set
     * @return best metric data keyed by metric
     */
    private static Map<Metric, BestMetricData> mergeMetricResults(ModelSelectResult modelSelectResult, Map<Metric, Double> outerTrainMetrics, Map<Metric, Double> testMetrics) {
        TrainerConfig bestParameters = modelSelectResult.bestParameters();
        Map<Metric, List<ModelStats>> trainStats = modelSelectResult.trainStats();
        Map<Metric, List<ModelStats>> validationStats = modelSelectResult.validationStats();
        return validationStats.keySet().stream().collect(Collectors.toMap(
            Function.identity(),
            metric -> BestMetricData.of(
                BestModelStats.findBestModelStats(trainStats.get(metric), bestParameters),
                BestModelStats.findBestModelStats(validationStats.get(metric), bestParameters),
                outerTrainMetrics.get(metric).doubleValue(),
                testMetrics.get(metric).doubleValue()
            )
        ));
    }

    /**
     * Creates a trainer for the given configuration and trains it on the given node ids using
     * the shared features and targets.
     *
     * @param hugeLongArray ids of the nodes to train on
     * @param trainerConfig configuration of the model to train
     * @return the trained classifier
     */
    private Classifier trainModel(HugeLongArray hugeLongArray, TrainerConfig trainerConfig) {
        Trainer trainer = TrainerFactory.create(
            trainerConfig,
            this.classIdMap,
            this.terminationFlag,
            this.progressTracker,
            this.config.concurrency(),
            this.config.randomSeed(),
            false
        );
        return trainer.train(this.features, this.targets, ReadOnlyHugeLongArray.of(hugeLongArray));
    }
}
