package org.neo4j.gds.ml.pipeline;

import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import org.neo4j.gds.Algorithm;
import org.neo4j.gds.NodeLabel;
import org.neo4j.gds.RelationshipType;
import org.neo4j.gds.annotation.ValueClass;
import org.neo4j.gds.api.GraphStore;
import org.neo4j.gds.config.AlgoBaseConfig;
import org.neo4j.gds.core.model.ModelCatalog;
import org.neo4j.gds.core.utils.mem.MemoryEstimation;
import org.neo4j.gds.core.utils.mem.MemoryEstimations;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.executor.ExecutionContext;
import org.neo4j.gds.executor.GraphStoreValidation;
import org.neo4j.gds.ml.pipeline.Pipeline;

/* loaded from: input_file:org/neo4j/gds/ml/pipeline/PipelineExecutor.class */
/**
 * Base {@link Algorithm} that runs an ML pipeline against a graph store.
 *
 * <p>{@link #compute()} performs these steps in order:
 * <ol>
 *   <li>splits the dataset;</li>
 *   <li>runs the pipeline's node property steps on the {@link DatasetSplits#FEATURE_INPUT} split;</li>
 *   <li>validates the graph store;</li>
 *   <li>delegates to {@link #execute(Map)}.</li>
 * </ol>
 * Afterwards it removes the intermediate node properties the steps wrote.
 *
 * @param <PIPELINE_CONFIG> configuration type of the concrete pipeline run
 * @param <PIPELINE>        pipeline type being executed
 * @param <RESULT>          result type produced by {@link #execute(Map)}
 */
public abstract class PipelineExecutor<PIPELINE_CONFIG extends AlgoBaseConfig, PIPELINE extends Pipeline<?>, RESULT> extends Algorithm<RESULT> {
    protected final PIPELINE pipeline;
    protected final PIPELINE_CONFIG config;
    protected final ExecutionContext executionContext;
    protected final GraphStore graphStore;
    protected final String graphName;

    /** Named dataset splits that {@link #splitDataset()} may produce. */
    public enum DatasetSplits {
        TRAIN,
        TEST,
        TEST_COMPLEMENT,
        FEATURE_INPUT
    }

    /** Selects a subgraph by node labels and relationship types. */
    @ValueClass
    public interface GraphFilter {
        Collection<NodeLabel> nodeLabels();

        Collection<RelationshipType> relationshipTypes();
    }

    /**
     * @param pipeline         the pipeline to execute
     * @param config           configuration of this run
     * @param executionContext context passed to each node property step
     * @param graphStore       graph store that is validated and cleaned up
     * @param graphName        name of the graph, passed to each node property step
     * @param progressTracker  tracker handed to {@link Algorithm}
     */
    protected PipelineExecutor(
        PIPELINE pipeline,
        PIPELINE_CONFIG config,
        ExecutionContext executionContext,
        GraphStore graphStore,
        String graphName,
        ProgressTracker progressTracker
    ) {
        super(progressTracker);
        this.pipeline = pipeline;
        this.config = config;
        this.executionContext = executionContext;
        this.graphStore = graphStore;
        this.graphName = graphName;
    }

    /**
     * Estimates the memory needed by the given node property steps.
     *
     * <p>The estimate is the maximum of the individual step estimates, not their sum.
     *
     * @param modelCatalog      model catalog passed to each step's estimate
     * @param nodePropertySteps steps to estimate
     * @param nodeLabels        node labels passed to each step's estimate
     * @param relationshipTypes relationship types passed to each step's estimate
     * @return a max-estimation over all step estimations, named "NodeProperty Steps"
     */
    public static MemoryEstimation estimateNodePropertySteps(
        ModelCatalog modelCatalog,
        List<ExecutableNodePropertyStep> nodePropertySteps,
        List<String> nodeLabels,
        List<String> relationshipTypes
    ) {
        List<MemoryEstimation> stepEstimations = nodePropertySteps
            .stream()
            .map(step -> step.estimate(modelCatalog, nodeLabels, relationshipTypes))
            .collect(Collectors.toList());
        return MemoryEstimations.maxEstimation("NodeProperty Steps", stepEstimations);
    }

    /**
     * Ensures the pipeline has something to train.
     *
     * @param pipeline pipeline to check
     * @throws IllegalArgumentException if the pipeline has zero model candidates
     */
    public static void validateTrainingParameterSpace(Pipeline<?> pipeline) {
        if (pipeline.numberOfModelCandidates() == 0) {
            throw new IllegalArgumentException("Need at least one model candidate for training.");
        }
    }

    /** Produces the graph filters for each dataset split used by this run. */
    public abstract Map<DatasetSplits, GraphFilter> splitDataset();

    /**
     * Runs the pipeline-specific work after node property steps and validation.
     *
     * @param dataSplits the splits returned by {@link #splitDataset()}
     * @return the result of this run
     */
    protected abstract RESULT execute(Map<DatasetSplits, GraphFilter> dataSplits);

    /**
     * Runs the pipeline end to end.
     *
     * <p>Graph-store cleanup runs exactly once, whether execution succeeds or throws. The
     * dataset is split before the guarded region, so a failure in
     * {@link #splitDataset()} skips cleanup.
     */
    @Override
    public RESULT compute() {
        this.progressTracker.beginSubTask();
        Map<DatasetSplits, GraphFilter> dataSplits = splitDataset();
        try {
            this.progressTracker.beginSubTask("execute node property steps");
            executeNodePropertySteps(dataSplits.get(DatasetSplits.FEATURE_INPUT));
            this.progressTracker.endSubTask("execute node property steps");

            validate(this.graphStore, this.config);

            RESULT result = execute(dataSplits);
            this.progressTracker.endSubTask();
            return result;
        } finally {
            // A single finally avoids running cleanup twice if the cleanup itself throws on the success path.
            cleanUpGraphStore(dataSplits);
        }
    }

    /**
     * Validates the pipeline's feature properties against the graph store, then validates the
     * graph store against the configuration.
     */
    protected void validate(GraphStore graphStore, PIPELINE_CONFIG config) {
        this.pipeline.validateFeatureProperties(graphStore, config);
        GraphStoreValidation.validate(graphStore, config);
    }

    /** No resources to release in the base executor. */
    public void release() {
    }

    /**
     * Executes every node property step of the pipeline, restricted to the given filter.
     *
     * <p>Each step is wrapped in its own progress sub-task.
     */
    private void executeNodePropertySteps(GraphFilter graphFilter) {
        for (ExecutableNodePropertyStep step : this.pipeline.nodePropertySteps()) {
            this.progressTracker.beginSubTask();
            step.execute(this.executionContext, this.graphName, graphFilter.nodeLabels(), graphFilter.relationshipTypes());
            this.progressTracker.endSubTask();
        }
    }

    /**
     * Removes the intermediate node properties written by the node property steps.
     *
     * <p>Removal covers the node labels resolved from the configuration. The {@code dataSplits}
     * argument is unused here. It is presumably there for subclasses that override this
     * method (NOTE(review): confirm).
     */
    protected void cleanUpGraphStore(Map<DatasetSplits, GraphFilter> dataSplits) {
        removeNodeProperties(this.graphStore, this.config.nodeLabelIdentifiers(this.graphStore));
    }

    /**
     * For each step whose config holds a String under "mutateProperty", removes that property
     * from every given node label.
     */
    private void removeNodeProperties(GraphStore graphStore, Iterable<NodeLabel> nodeLabels) {
        this.pipeline.nodePropertySteps().forEach(step -> {
            Object mutateProperty = step.config().get("mutateProperty");
            if (mutateProperty instanceof String) {
                String propertyKey = (String) mutateProperty;
                nodeLabels.forEach(label -> graphStore.removeNodeProperty(label, propertyKey));
            }
        });
    }
}
