package org.neo4j.gds.ml.pipeline.node.classification;

import java.util.Collection;
import java.util.Map;
import java.util.stream.Stream;
import org.neo4j.gds.BaseProc;
import org.neo4j.gds.core.ConfigKeyValidation;
import org.neo4j.gds.ml.api.TrainingMethod;
import org.neo4j.gds.ml.models.automl.TunableTrainerConfig;
import org.neo4j.gds.ml.models.logisticregression.LogisticRegressionTrainConfig;
import org.neo4j.gds.ml.models.mlp.MLPClassifierTrainConfig;
import org.neo4j.gds.ml.models.randomforest.RandomForestClassifierTrainerConfig;
import org.neo4j.gds.ml.pipeline.PipelineCatalog;
import org.neo4j.gds.ml.pipeline.nodePipeline.classification.NodeClassificationTrainingPipeline;
import org.neo4j.gds.procedures.pipelines.NodePipelineInfoResult;
import org.neo4j.procedure.Description;
import org.neo4j.procedure.Internal;
import org.neo4j.procedure.Mode;
import org.neo4j.procedure.Name;
import org.neo4j.procedure.Procedure;

/* loaded from: input_file:org/neo4j/gds/ml/pipeline/node/classification/NodeClassificationPipelineAddTrainerMethodProcs.class */
public class NodeClassificationPipelineAddTrainerMethodProcs extends BaseProc {
    @Procedure(name = "gds.beta.pipeline.nodeClassification.addLogisticRegression", mode = Mode.READ)
    @Description("Add a logistic regression configuration to the parameter space of the node classification train pipeline.")
    public Stream<NodePipelineInfoResult> addLogisticRegression(@Name("pipelineName") String str, @Name(value = "config", defaultValue = "{}") Map<String, Object> map) {
        NodeClassificationTrainingPipeline typed = PipelineCatalog.getTyped(username(), str, NodeClassificationTrainingPipeline.class);
        ConfigKeyValidation.requireOnlyKeysFrom(LogisticRegressionTrainConfig.DEFAULT.configKeys(), map.keySet());
        typed.addTrainerConfig(TunableTrainerConfig.of(map, TrainingMethod.LogisticRegression));
        return Stream.of(new NodePipelineInfoResult(str, typed));
    }

    @Procedure(name = "gds.beta.pipeline.nodeClassification.addRandomForest", mode = Mode.READ)
    @Description("Add a random forest configuration to the parameter space of the node classification train pipeline.")
    public Stream<NodePipelineInfoResult> addRandomForest(@Name("pipelineName") String str, @Name("config") Map<String, Object> map) {
        NodeClassificationTrainingPipeline typed = PipelineCatalog.getTyped(username(), str, NodeClassificationTrainingPipeline.class);
        ConfigKeyValidation.requireOnlyKeysFrom(RandomForestClassifierTrainerConfig.DEFAULT.configKeys(), map.keySet());
        typed.addTrainerConfig(TunableTrainerConfig.of(map, TrainingMethod.RandomForestClassification));
        return Stream.of(new NodePipelineInfoResult(str, typed));
    }

    @Internal
    @Description("Add a random forest configuration to the parameter space of the node classification train pipeline.")
    @Deprecated(forRemoval = true)
    @Procedure(name = "gds.alpha.pipeline.nodeClassification.addRandomForest", mode = Mode.READ, deprecatedBy = "gds.beta.pipeline.nodeClassification.addRandomForest")
    public Stream<NodePipelineInfoResult> addRandomForestAlpha(@Name("pipelineName") String str, @Name("config") Map<String, Object> map) {
        executionContext().metricsFacade().deprecatedProcedures().called("gds.alpha.pipeline.nodeClassification.addRandomForest");
        executionContext().log().warn("Procedure `gds.alpha.pipeline.nodeClassification.addRandomForest` has been deprecated, please use `gds.beta.pipeline.nodeClassification.addRandomForest`.");
        return addRandomForest(str, map);
    }

    @Procedure(name = "gds.alpha.pipeline.nodeClassification.addMLP", mode = Mode.READ)
    @Description("Add a multilayer perceptron configuration to the parameter space of the node classification train pipeline.")
    public Stream<NodePipelineInfoResult> addMLP(@Name("pipelineName") String str, @Name(value = "config", defaultValue = "{}") Map<String, Object> map) {
        NodeClassificationTrainingPipeline typed = PipelineCatalog.getTyped(username(), str, NodeClassificationTrainingPipeline.class);
        ConfigKeyValidation.requireOnlyKeysFrom(MLPClassifierTrainConfig.DEFAULT.configKeys(), map.keySet());
        typed.addTrainerConfig(TunableTrainerConfig.of(map, TrainingMethod.MLPClassification));
        return Stream.of(new NodePipelineInfoResult(str, typed));
    }
}
