/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.ml.pipeline.node.regression.predict;

import java.util.List;
import java.util.Map;
import java.util.Optional;
import org.jetbrains.annotations.NotNull;
import org.neo4j.gds.core.model.Model;
import org.neo4j.gds.core.model.ModelCatalog;
import org.neo4j.gds.executor.ExecutionContext;
import org.neo4j.gds.ml.models.BaseModelData;
import org.neo4j.gds.ml.pipeline.nodePipeline.NodePropertyPipelineBaseTrainConfig;

/**
 * Helpers for preparing the user-supplied configuration of node regression pipeline
 * prediction procedures.
 */
final class NodeRegressionPipelineCompanion {
    private static final String MODEL_NAME_KEY = "modelName";
    private static final String TARGET_NODE_LABELS_KEY = "targetNodeLabels";
    private static final String RELATIONSHIP_TYPES_KEY = "relationshipTypes";

    private NodeRegressionPipelineCompanion() {
    }

    /**
     * Defaults {@code targetNodeLabels} and {@code relationshipTypes} in the user input to the
     * values from the train config of the model referenced by {@code modelName}.
     *
     * <p>Entries the user already supplied with a non-null value are left untouched. If
     * {@code modelName} is missing (or mapped to {@code null}), the input is returned as is and
     * the execution context is not consulted.
     *
     * <p>{@code userInput} is modified in place and the same instance is returned, so it must be
     * mutable whenever {@code modelName} is present.
     *
     * @param userInput        the raw procedure configuration; modified in place
     * @param executionContext supplies the model catalog and the username used for the model lookup
     * @return {@code userInput}, with the train-config defaults filled in where applicable
     */
    @NotNull
    static Map<String, Object> enhanceUserInput(Map<String, Object> userInput, ExecutionContext executionContext) {
        Object modelName = userInput.get(MODEL_NAME_KEY);
        if (modelName == null) {
            return userInput;
        }

        ModelCatalog modelCatalog = executionContext.modelCatalog();
        assert (modelCatalog != null) : "ModelCatalog should have been set in the ExecutionContext by this point!!!";
        Model<BaseModelData, NodePropertyPipelineBaseTrainConfig, Model.CustomInfo> trainedModel = modelCatalog.get(
            executionContext.username(),
            (String) modelName,
            BaseModelData.class,
            NodePropertyPipelineBaseTrainConfig.class,
            Model.CustomInfo.class
        );
        NodePropertyPipelineBaseTrainConfig trainConfig = trainedModel.trainConfig();

        // putIfAbsent also fills keys explicitly mapped to null, matching the previous
        // Optional.ofNullable(...) handling. User-supplied values are stored back unchanged (no cast
        // to List): the map holds Object, so validating their shape is presumably left to
        // config parsing further downstream.
        userInput.putIfAbsent(TARGET_NODE_LABELS_KEY, trainConfig.targetNodeLabels());
        userInput.putIfAbsent(RELATIONSHIP_TYPES_KEY, trainConfig.relationshipTypes());
        return userInput;
    }
}

