package org.neo4j.gds.ml.pipeline.node.regression.predict;

import org.neo4j.gds.GraphStoreAlgorithmFactory;
import org.neo4j.gds.api.GraphStore;
import org.neo4j.gds.core.model.Model;
import org.neo4j.gds.core.model.ModelCatalog;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.core.utils.progress.tasks.Task;
import org.neo4j.gds.executor.ExecutionContext;
import org.neo4j.gds.ml.models.Regressor;
import org.neo4j.gds.ml.models.TrainingMethod;
import org.neo4j.gds.ml.models.linearregression.LinearRegressionData;
import org.neo4j.gds.ml.models.linearregression.LinearRegressor;
import org.neo4j.gds.ml.models.randomforest.RandomForestRegressor;
import org.neo4j.gds.ml.models.randomforest.RandomForestRegressorData;
import org.neo4j.gds.ml.pipeline.node.regression.predict.NodeRegressionPredictPipelineBaseConfig;
import org.neo4j.gds.ml.pipeline.nodePipeline.regression.NodeRegressionPipelineModelInfo;
import org.neo4j.gds.ml.pipeline.nodePipeline.regression.NodeRegressionPipelineTrainConfig;

/* loaded from: input_file:org/neo4j/gds/ml/pipeline/node/regression/predict/NodeRegressionPredictPipelineAlgorithmFactory.class */
/**
 * Factory that creates {@link NodeRegressionPredictPipelineExecutor} instances for a previously
 * trained node regression pipeline model.
 *
 * <p>The trained model is looked up in the supplied {@link ModelCatalog} using the model name and
 * username carried by the predict configuration.
 *
 * @param <CONFIG> the concrete predict configuration type
 */
public class NodeRegressionPredictPipelineAlgorithmFactory<CONFIG extends NodeRegressionPredictPipelineBaseConfig> extends GraphStoreAlgorithmFactory<NodeRegressionPredictPipelineExecutor, CONFIG> {
    private final ModelCatalog modelCatalog;
    private final ExecutionContext executionContext;

    /**
     * Creates a factory bound to the given execution context and model catalog.
     *
     * @param executionContext context handed to every executor this factory builds
     * @param modelCatalog     catalog from which trained pipeline models are fetched
     */
    /* JADX INFO: Access modifiers changed from: package-private */
    public NodeRegressionPredictPipelineAlgorithmFactory(ExecutionContext executionContext, ModelCatalog modelCatalog) {
        this.modelCatalog = modelCatalog;
        this.executionContext = executionContext;
    }

    /**
     * Builds the progress task for a prediction run by delegating to
     * {@link NodeRegressionPredictPipelineExecutor#progressTask} with the pipeline stored in the
     * trained model.
     *
     * @param graphStore graph store the prediction runs against
     * @param config     predict configuration naming the model and the user
     * @return the progress task describing the prediction run
     */
    public Task progressTask(GraphStore graphStore, CONFIG config) {
        Model<Regressor.RegressorData, NodeRegressionPipelineTrainConfig, NodeRegressionPipelineModelInfo> trainedModel =
            getTrainedNRPipelineModel(this.modelCatalog, config.modelName(), config.username());
        return NodeRegressionPredictPipelineExecutor.progressTask(
            taskName(),
            trainedModel.customInfo().pipeline(),
            graphStore
        );
    }

    /**
     * @return the name used for the progress task of this algorithm
     */
    public String taskName() {
        // Previously returned the classification label, which mislabelled regression tasks.
        return "Node Regression Predict Pipeline";
    }

    /**
     * Creates the executor that applies the trained pipeline to the given graph store.
     *
     * @param graphStore      graph store the prediction runs against
     * @param config          predict configuration naming the model, the user and the graph
     * @param progressTracker tracker passed on to the executor
     * @return a new executor wired with the trained pipeline and the matching regressor
     */
    public NodeRegressionPredictPipelineExecutor build(GraphStore graphStore, CONFIG config, ProgressTracker progressTracker) {
        Model<Regressor.RegressorData, NodeRegressionPipelineTrainConfig, NodeRegressionPipelineModelInfo> trainedModel =
            getTrainedNRPipelineModel(this.modelCatalog, config.modelName(), config.username());
        Regressor regressor = regressorFrom((Regressor.RegressorData) trainedModel.data());
        return new NodeRegressionPredictPipelineExecutor(
            trainedModel.customInfo().pipeline(),
            config,
            this.executionContext,
            graphStore,
            config.graphName(),
            progressTracker,
            regressor
        );
    }

    /**
     * Instantiates the regressor implementation that matches the training method recorded in the
     * model data.
     *
     * @param regressorData trained model data
     * @return a regressor wrapping {@code regressorData}
     * @throws IllegalStateException if the training method is neither linear regression nor
     *                               random forest regression
     */
    private static Regressor regressorFrom(Regressor.RegressorData regressorData) {
        TrainingMethod trainerMethod = regressorData.trainerMethod();
        // Switch on the enum itself rather than on a hand-maintained ordinal lookup table.
        switch (trainerMethod) {
            case LinearRegression:
                return new LinearRegressor((LinearRegressionData) regressorData);
            case RandomForestRegression:
                return new RandomForestRegressor((RandomForestRegressorData) regressorData);
            default:
                throw new IllegalStateException("No such regressor: " + trainerMethod.name());
        }
    }

    /**
     * Fetches the trained node regression pipeline model from the catalog.
     *
     * @param modelCatalog catalog to query
     * @param modelName    name of the trained model
     * @param username     user the model is looked up for
     * @return the model as returned by {@link ModelCatalog#get}
     */
    private static Model<Regressor.RegressorData, NodeRegressionPipelineTrainConfig, NodeRegressionPipelineModelInfo> getTrainedNRPipelineModel(ModelCatalog modelCatalog, String modelName, String username) {
        // Note the argument order: the catalog takes the username first, then the model name.
        return modelCatalog.get(
            username,
            modelName,
            Regressor.RegressorData.class,
            NodeRegressionPipelineTrainConfig.class,
            NodeRegressionPipelineModelInfo.class
        );
    }
}
