/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.ml.linkmodels.pipeline;

import java.util.Collection;
import java.util.Map;
import java.util.stream.Stream;
import org.neo4j.gds.BaseProc;
import org.neo4j.gds.core.ConfigKeyValidation;
import org.neo4j.gds.ml.api.TrainingMethod;
import org.neo4j.gds.ml.linkmodels.pipeline.PipelineInfoResult;
import org.neo4j.gds.ml.models.automl.TunableTrainerConfig;
import org.neo4j.gds.ml.models.logisticregression.LogisticRegressionTrainConfig;
import org.neo4j.gds.ml.models.mlp.MLPClassifierTrainConfig;
import org.neo4j.gds.ml.models.randomforest.RandomForestClassifierTrainerConfig;
import org.neo4j.gds.ml.pipeline.PipelineCatalog;
import org.neo4j.gds.ml.pipeline.linkPipeline.LinkPredictionTrainingPipeline;
import org.neo4j.procedure.Description;
import org.neo4j.procedure.Internal;
import org.neo4j.procedure.Mode;
import org.neo4j.procedure.Name;
import org.neo4j.procedure.Procedure;

public class LinkPredictionPipelineAddTrainerMethodProcs
extends BaseProc {
    @Procedure(name="gds.beta.pipeline.linkPrediction.addLogisticRegression", mode=Mode.READ)
    @Description(value="Add a logistic regression configuration to the parameter space of the link prediction train pipeline.")
    public Stream<PipelineInfoResult> addLogisticRegression(@Name(value="pipelineName") String pipelineName, @Name(value="config", defaultValue="{}") Map<String, Object> logisticRegressionClassifierConfig) {
        LinkPredictionTrainingPipeline pipeline = (LinkPredictionTrainingPipeline)PipelineCatalog.getTyped((String)this.username(), (String)pipelineName, LinkPredictionTrainingPipeline.class);
        Collection allowedKeys = LogisticRegressionTrainConfig.DEFAULT.configKeys();
        ConfigKeyValidation.requireOnlyKeysFrom((Collection)allowedKeys, logisticRegressionClassifierConfig.keySet());
        TunableTrainerConfig tunableTrainerConfig = TunableTrainerConfig.of(logisticRegressionClassifierConfig, (TrainingMethod)TrainingMethod.LogisticRegression);
        pipeline.addTrainerConfig(tunableTrainerConfig);
        return Stream.of(new PipelineInfoResult(pipelineName, pipeline));
    }

    @Procedure(name="gds.beta.pipeline.linkPrediction.addRandomForest", mode=Mode.READ)
    @Description(value="Add a random forest configuration to the parameter space of the link prediction train pipeline.")
    public Stream<PipelineInfoResult> addRandomForest(@Name(value="pipelineName") String pipelineName, @Name(value="config") Map<String, Object> randomForestClassifierConfig) {
        LinkPredictionTrainingPipeline pipeline = (LinkPredictionTrainingPipeline)PipelineCatalog.getTyped((String)this.username(), (String)pipelineName, LinkPredictionTrainingPipeline.class);
        Collection allowedKeys = RandomForestClassifierTrainerConfig.DEFAULT.configKeys();
        ConfigKeyValidation.requireOnlyKeysFrom((Collection)allowedKeys, randomForestClassifierConfig.keySet());
        TunableTrainerConfig tunableTrainerConfig = TunableTrainerConfig.of(randomForestClassifierConfig, (TrainingMethod)TrainingMethod.RandomForestClassification);
        pipeline.addTrainerConfig(tunableTrainerConfig);
        return Stream.of(new PipelineInfoResult(pipelineName, pipeline));
    }

    @Procedure(name="gds.alpha.pipeline.linkPrediction.addRandomForest", mode=Mode.READ, deprecatedBy="gds.beta.pipeline.linkPrediction.addRandomForest")
    @Description(value="Add a random forest configuration to the parameter space of the link prediction train pipeline.")
    @Internal
    @Deprecated(forRemoval=true)
    public Stream<PipelineInfoResult> addRandomForestAlpha(@Name(value="pipelineName") String pipelineName, @Name(value="config") Map<String, Object> randomForestClassifierConfig) {
        this.executionContext().metricsFacade().deprecatedProcedures().called("gds.alpha.pipeline.linkPrediction.addRandomForest");
        this.executionContext().log().warn("Procedure `gds.alpha.pipeline.linkPrediction.addRandomForest` has been deprecated, please use `gds.beta.pipeline.linkPrediction.addRandomForest`.", new Object[0]);
        return this.addRandomForest(pipelineName, randomForestClassifierConfig);
    }

    @Procedure(name="gds.alpha.pipeline.linkPrediction.addMLP", mode=Mode.READ)
    @Description(value="Add a multilayer perceptron configuration to the parameter space of the link prediction train pipeline.")
    public Stream<PipelineInfoResult> addMLP(@Name(value="pipelineName") String pipelineName, @Name(value="config", defaultValue="{}") Map<String, Object> mlpClassifierConfig) {
        LinkPredictionTrainingPipeline pipeline = (LinkPredictionTrainingPipeline)PipelineCatalog.getTyped((String)this.username(), (String)pipelineName, LinkPredictionTrainingPipeline.class);
        Collection allowedKeys = MLPClassifierTrainConfig.DEFAULT.configKeys();
        ConfigKeyValidation.requireOnlyKeysFrom((Collection)allowedKeys, mlpClassifierConfig.keySet());
        pipeline.addTrainerConfig(TunableTrainerConfig.of(mlpClassifierConfig, (TrainingMethod)TrainingMethod.MLPClassification));
        return Stream.of(new PipelineInfoResult(pipelineName, pipeline));
    }
}

