package org.neo4j.gds.ml.pipeline.nodePipeline.regression;

import java.lang.invoke.SerializedLambda;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.TreeSet;
import java.util.function.LongUnaryOperator;
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.api.GraphStore;
import org.neo4j.gds.api.IdMap;
import org.neo4j.gds.api.properties.nodes.NodePropertyValues;
import org.neo4j.gds.core.concurrency.ParallelUtil;
import org.neo4j.gds.core.utils.TerminationFlag;
import org.neo4j.gds.core.utils.paged.HugeDoubleArray;
import org.neo4j.gds.core.utils.paged.ReadOnlyHugeLongArray;
import org.neo4j.gds.core.utils.progress.tasks.LogLevel;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.core.utils.progress.tasks.Task;
import org.neo4j.gds.core.utils.progress.tasks.Tasks;
import org.neo4j.gds.ml.metrics.Metric;
import org.neo4j.gds.ml.metrics.MetricConsumer;
import org.neo4j.gds.ml.models.ClassifierTrainer;
import org.neo4j.gds.ml.models.Features;
import org.neo4j.gds.ml.models.RegressionTrainerFactory;
import org.neo4j.gds.ml.models.Regressor;
import org.neo4j.gds.ml.models.TrainerConfig;
import org.neo4j.gds.ml.models.automl.RandomSearch;
import org.neo4j.gds.ml.nodePropertyPrediction.NodeSplitter;
import org.neo4j.gds.ml.pipeline.NodePropertyStepExecutor;
import org.neo4j.gds.ml.pipeline.PipelineTrainer;
import org.neo4j.gds.ml.pipeline.nodePipeline.NodeFeatureProducer;
import org.neo4j.gds.ml.pipeline.nodePipeline.NodePropertyPredictionSplitConfig;
import org.neo4j.gds.ml.pipeline.nodePipeline.NodePropertyTrainingPipeline;
import org.neo4j.gds.ml.splitting.TrainingExamplesSplit;
import org.neo4j.gds.ml.training.CrossValidation;
import org.neo4j.gds.ml.training.TrainingStatistics;
import org.neo4j.gds.utils.StringFormatting;

/* loaded from: input_file:org/neo4j/gds/ml/pipeline/nodePipeline/regression/NodeRegressionTrain.class */
/**
 * Trains a node regression model for a {@link NodeRegressionTrainingPipeline}.
 *
 * <p>A run performs the following steps in order (see {@link #run()}):
 * <ol>
 *   <li>split the nodes into an outer train/test split with validation folds,</li>
 *   <li>select the best model candidate via cross-validation on the outer train set,</li>
 *   <li>train the best candidate on the outer train set and score it on the train and test sets,</li>
 *   <li>retrain the best candidate on all training examples and return it with the statistics.</li>
 * </ol>
 *
 * <p>Instances are obtained through {@link #create}. The termination flag defaults to
 * {@link TerminationFlag#RUNNING_TRUE} and can be replaced via {@link #setTerminationFlag}.
 */
public final class NodeRegressionTrain implements PipelineTrainer<NodeRegressionTrainResult> {
    // Sub-task names. They are shared between progressTask(...) and the begin/endSubTask calls
    // so that the declared task tree and the tasks actually started cannot drift apart.
    private static final String TRAIN_BEST_MODEL_TASK = "Train best model";
    private static final String EVALUATE_ON_TEST_DATA_TASK = "Evaluate on test data";
    private static final String RETRAIN_BEST_MODEL_TASK = "Retrain best model";

    private final HugeDoubleArray targets;
    private final IdMap nodeIdMap;
    private final NodeRegressionTrainingPipeline pipeline;
    private final NodeRegressionPipelineTrainConfig trainConfig;
    private final NodeFeatureProducer<NodeRegressionPipelineTrainConfig> nodeFeatureProducer;
    private final ProgressTracker progressTracker;
    private TerminationFlag terminationFlag = TerminationFlag.RUNNING_TRUE;

    /**
     * Builds the progress task tree for a node regression training run.
     *
     * <p>The tree consists of the node property step tasks, the cross-validation tasks, and the
     * "Train best model", "Evaluate on test data" and "Retrain best model" tasks, all grouped
     * under "Node Regression Train Pipeline".
     *
     * @param pipeline  the pipeline whose split config, node property steps and number of model
     *                  selection trials determine the tasks
     * @param nodeCount the number of nodes used to size the tasks
     * @return the root task of the training run
     */
    public static Task progressTask(NodePropertyTrainingPipeline pipeline, long nodeCount) {
        NodePropertyPredictionSplitConfig splitConfig = pipeline.splitConfig();
        long trainSetSize = splitConfig.trainSetSize(nodeCount);
        long testSetSize = splitConfig.testSetSize(nodeCount);
        int validationFolds = splitConfig.validationFolds();

        List<Task> tasks = new ArrayList<>();
        tasks.add(NodePropertyStepExecutor.tasks(pipeline.nodePropertySteps(), nodeCount));
        tasks.addAll(CrossValidation.progressTasks(
            validationFolds,
            pipeline.numberOfModelSelectionTrials(),
            trainSetSize
        ));
        // NOTE(review): the factor 5 is taken over unchanged; presumably a per-example work
        // estimate expected by ClassifierTrainer.progressTask - confirm against that helper.
        tasks.add(ClassifierTrainer.progressTask(TRAIN_BEST_MODEL_TASK, 5 * trainSetSize));
        tasks.add(Tasks.leaf(EVALUATE_ON_TEST_DATA_TASK, testSetSize));
        tasks.add(ClassifierTrainer.progressTask(RETRAIN_BEST_MODEL_TASK, 5 * nodeCount));

        return Tasks.task("Node Regression Train Pipeline", tasks);
    }

    /**
     * Creates a trainer for the graph selected by the train config's node labels.
     *
     * <p>Validates the minimum number of nodes in the split sets against that graph and copies
     * the target property of every node into a {@link HugeDoubleArray}.
     *
     * @param graphStore          the graph store to read the graph from
     * @param pipeline            the training pipeline
     * @param trainConfig         the training configuration (node labels, target property, ...)
     * @param nodeFeatureProducer produces the features used for training
     * @param progressTracker     tracker used to report progress
     * @return a new trainer instance
     * @throws NullPointerException if the graph has no values for the configured target property
     */
    public static NodeRegressionTrain create(
        GraphStore graphStore,
        NodeRegressionTrainingPipeline pipeline,
        NodeRegressionPipelineTrainConfig trainConfig,
        NodeFeatureProducer<NodeRegressionPipelineTrainConfig> nodeFeatureProducer,
        ProgressTracker progressTracker
    ) {
        Graph graph = graphStore.getGraph(trainConfig.nodeLabelIdentifiers(graphStore));
        pipeline.splitConfig().validateMinNumNodesInSplitSets(graph);

        NodePropertyValues targetValues = Objects.requireNonNull(
            graph.nodeProperties(trainConfig.targetProperty()),
            "No node property values found for target property `" + trainConfig.targetProperty() + "`"
        );
        HugeDoubleArray targets = HugeDoubleArray.newArray(graph.nodeCount());
        targets.setAll(targetValues::doubleValue);

        return new NodeRegressionTrain(pipeline, trainConfig, nodeFeatureProducer, targets, graph, progressTracker);
    }

    private NodeRegressionTrain(
        NodeRegressionTrainingPipeline pipeline,
        NodeRegressionPipelineTrainConfig trainConfig,
        NodeFeatureProducer<NodeRegressionPipelineTrainConfig> nodeFeatureProducer,
        HugeDoubleArray targets,
        IdMap nodeIdMap,
        ProgressTracker progressTracker
    ) {
        this.pipeline = pipeline;
        this.trainConfig = trainConfig;
        this.nodeFeatureProducer = nodeFeatureProducer;
        this.nodeIdMap = nodeIdMap;
        this.progressTracker = progressTracker;
        this.targets = targets;
    }

    /**
     * Replaces the termination flag that is checked during training.
     *
     * @param terminationFlag the flag to use from now on
     */
    @Override
    public void setTerminationFlag(TerminationFlag terminationFlag) {
        this.terminationFlag = terminationFlag;
    }

    /**
     * Executes the complete training run: split, model selection, evaluation and retraining.
     *
     * @return the retrained best model together with the collected training statistics
     */
    @Override
    public NodeRegressionTrainResult run() {
        progressTracker.beginSubTask();

        NodePropertyPredictionSplitConfig splitConfig = pipeline.splitConfig();
        NodeSplitter nodeSplitter = new NodeSplitter(
            trainConfig.concurrency(),
            nodeIdMap.nodeCount(),
            progressTracker,
            nodeIdMap::toOriginalNodeId,
            nodeIdMap::toMappedNodeId
        );
        NodeSplitter.NodeSplits nodeSplits = nodeSplitter.split(
            splitConfig.testFraction(),
            splitConfig.validationFolds(),
            trainConfig.randomSeed()
        );

        terminationFlag.assertRunning();

        List<Metric> metrics = List.copyOf(trainConfig.metrics());
        TrainingStatistics trainingStatistics = new TrainingStatistics(metrics);
        Features features = nodeFeatureProducer.procedureFeatures(pipeline);

        findBestModelCandidate(nodeSplits.outerSplit().trainSet(), metrics, features, trainingStatistics);
        evaluateBestModel(nodeSplits.outerSplit(), features, trainingStatistics);
        Regressor retrainedModel = retrainBestModel(
            nodeSplits.allTrainingExamples(),
            features,
            trainingStatistics.bestParameters()
        );

        progressTracker.endSubTask();
        return ImmutableNodeRegressionTrainResult.of(retrainedModel, trainingStatistics);
    }

    /**
     * Runs cross-validation over randomly searched candidates from the pipeline's training
     * parameter space and records the outcome in {@code trainingStatistics}.
     */
    private void findBestModelCandidate(
        ReadOnlyHugeLongArray trainSet,
        List<Metric> metrics,
        Features features,
        TrainingStatistics trainingStatistics
    ) {
        CrossValidation<Regressor> crossValidation = new CrossValidation<>(
            progressTracker,
            terminationFlag,
            metrics,
            pipeline.splitConfig().validationFolds(),
            trainConfig.randomSeed(),
            (candidateTrainSet, candidateConfig, metricsHandler, messageLogLevel) ->
                trainModel(candidateTrainSet, candidateConfig, features, messageLogLevel),
            (evaluationSet, model, scoreConsumer) ->
                registerMetricScores(evaluationSet, model, features, scoreConsumer)
        );

        RandomSearch modelCandidates = new RandomSearch(
            pipeline.trainingParameterSpace(),
            pipeline.numberOfModelSelectionTrials(),
            trainConfig.randomSeed()
        );

        // Every node is mapped to the single dummy target 0, and 0 is the only distinct target.
        // NOTE(review): presumably these arguments only matter for classification - confirm
        // against CrossValidation.selectModel.
        crossValidation.selectModel(
            trainSet,
            nodeId -> 0L,
            new TreeSet<Long>(List.of(0L)),
            trainingStatistics,
            modelCandidates
        );
    }

    /**
     * Predicts a value for every node in {@code evaluationSet}, collects the matching target
     * values and reports the score of every configured metric to {@code scoreConsumer}.
     */
    private void registerMetricScores(
        ReadOnlyHugeLongArray evaluationSet,
        Regressor regressor,
        Features features,
        MetricConsumer scoreConsumer
    ) {
        HugeDoubleArray predictedTargets = HugeDoubleArray.newArray(evaluationSet.size());
        ParallelUtil.parallelForEachNode(
            evaluationSet.size(),
            trainConfig.concurrency(),
            index -> predictedTargets.set(index, regressor.predict(features.get(evaluationSet.get(index))))
        );

        terminationFlag.assertRunning();

        HugeDoubleArray actualTargets = HugeDoubleArray.newArray(evaluationSet.size());
        ParallelUtil.parallelForEachNode(
            evaluationSet.size(),
            trainConfig.concurrency(),
            index -> actualTargets.set(index, targets.get(evaluationSet.get(index)))
        );

        trainConfig.metrics().forEach(
            metric -> scoreConsumer.consume(metric, metric.compute(actualTargets, predictedTargets))
        );
    }

    /**
     * Trains the best candidate on the outer train set and records its scores on both the
     * outer train set and the test set.
     */
    private void evaluateBestModel(
        TrainingExamplesSplit outerSplit,
        Features features,
        TrainingStatistics trainingStatistics
    ) {
        progressTracker.beginSubTask(TRAIN_BEST_MODEL_TASK);
        Regressor bestModel = trainModel(
            outerSplit.trainSet(),
            trainingStatistics.bestParameters(),
            features,
            LogLevel.INFO
        );
        progressTracker.endSubTask(TRAIN_BEST_MODEL_TASK);

        // The train-set scores are registered under the "Evaluate on test data" task as well.
        progressTracker.beginSubTask(EVALUATE_ON_TEST_DATA_TASK);
        registerMetricScores(outerSplit.trainSet(), bestModel, features, trainingStatistics::addOuterTrainScore);
        progressTracker.logInfo(StringFormatting.formatWithLocale(
            "Final model metrics on full train set: %s",
            trainingStatistics.winningModelOuterTrainMetrics()
        ));

        registerMetricScores(outerSplit.testSet(), bestModel, features, trainingStatistics::addTestScore);
        progressTracker.logInfo(StringFormatting.formatWithLocale(
            "Final model metrics on test set: %s",
            trainingStatistics.winningModelTestMetrics()
        ));
        progressTracker.endSubTask(EVALUATE_ON_TEST_DATA_TASK);
    }

    /**
     * Trains the given configuration on {@code trainSet} inside the "Retrain best model" task.
     *
     * @return the retrained model
     */
    private Regressor retrainBestModel(ReadOnlyHugeLongArray trainSet, Features features, TrainerConfig bestParameters) {
        progressTracker.beginSubTask(RETRAIN_BEST_MODEL_TASK);
        Regressor retrainedModel = trainModel(trainSet, bestParameters, features, LogLevel.INFO);
        progressTracker.endSubTask(RETRAIN_BEST_MODEL_TASK);
        return retrainedModel;
    }

    /**
     * Creates a regression trainer for {@code trainerConfig} and trains it on the nodes in
     * {@code trainSet} using the given features and this instance's target values.
     */
    private Regressor trainModel(
        ReadOnlyHugeLongArray trainSet,
        TrainerConfig trainerConfig,
        Features features,
        LogLevel messageLogLevel
    ) {
        return RegressionTrainerFactory.create(
            trainerConfig,
            terminationFlag,
            progressTracker,
            messageLogLevel,
            trainConfig.concurrency(),
            trainConfig.randomSeed()
        ).train(features, targets, trainSet);
    }
}
