/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.ml.pipeline.node.regression.configure;

import java.util.Collection;
import java.util.Map;
import java.util.stream.Stream;
import org.neo4j.gds.BaseProc;
import org.neo4j.gds.core.ConfigKeyValidation;
import org.neo4j.gds.ml.api.TrainingMethod;
import org.neo4j.gds.ml.models.automl.TunableTrainerConfig;
import org.neo4j.gds.ml.models.linearregression.LinearRegressionTrainConfig;
import org.neo4j.gds.ml.models.randomforest.RandomForestRegressorTrainerConfig;
import org.neo4j.gds.ml.pipeline.PipelineCatalog;
import org.neo4j.gds.ml.pipeline.node.NodePipelineInfoResult;
import org.neo4j.gds.ml.pipeline.nodePipeline.NodePropertyTrainingPipeline;
import org.neo4j.gds.ml.pipeline.nodePipeline.regression.NodeRegressionTrainingPipeline;
import org.neo4j.procedure.Description;
import org.neo4j.procedure.Mode;
import org.neo4j.procedure.Name;
import org.neo4j.procedure.Procedure;

public class NodeRegressionPipelineAddTrainerMethodProcs
extends BaseProc {
    @Procedure(name="gds.alpha.pipeline.nodeRegression.addLinearRegression", mode=Mode.READ)
    @Description(value="Add a linear regression model candidate to a node regression pipeline.")
    public Stream<NodePipelineInfoResult> addLogisticRegression(@Name(value="pipelineName") String pipelineName, @Name(value="configuration", defaultValue="{}") Map<String, Object> configuration) {
        NodeRegressionTrainingPipeline pipeline = (NodeRegressionTrainingPipeline)PipelineCatalog.getTyped((String)this.username(), (String)pipelineName, NodeRegressionTrainingPipeline.class);
        Collection allowedKeys = LinearRegressionTrainConfig.DEFAULT.configKeys();
        ConfigKeyValidation.requireOnlyKeysFrom((Collection)allowedKeys, configuration.keySet());
        pipeline.addTrainerConfig(TunableTrainerConfig.of(configuration, (TrainingMethod)TrainingMethod.LinearRegression));
        return Stream.of(new NodePipelineInfoResult(pipelineName, (NodePropertyTrainingPipeline)pipeline));
    }

    @Procedure(name="gds.alpha.pipeline.nodeRegression.addRandomForest", mode=Mode.READ)
    @Description(value="Add a random forest model candidate to a node regression pipeline.")
    public Stream<NodePipelineInfoResult> addRandomForest(@Name(value="pipelineName") String pipelineName, @Name(value="configuration") Map<String, Object> configuration) {
        NodeRegressionTrainingPipeline pipeline = (NodeRegressionTrainingPipeline)PipelineCatalog.getTyped((String)this.username(), (String)pipelineName, NodeRegressionTrainingPipeline.class);
        Collection allowedKeys = RandomForestRegressorTrainerConfig.DEFAULT.configKeys();
        ConfigKeyValidation.requireOnlyKeysFrom((Collection)allowedKeys, configuration.keySet());
        pipeline.addTrainerConfig(TunableTrainerConfig.of(configuration, (TrainingMethod)TrainingMethod.RandomForestRegression));
        return Stream.of(new NodePipelineInfoResult(pipelineName, (NodePropertyTrainingPipeline)pipeline));
    }
}

