/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.ml.pipeline.node.regression.predict;

import org.neo4j.gds.GraphStoreAlgorithmFactory;
import org.neo4j.gds.api.GraphStore;
import org.neo4j.gds.core.model.Model;
import org.neo4j.gds.core.model.ModelCatalog;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.core.utils.progress.tasks.Task;
import org.neo4j.gds.executor.ExecutionContext;
import org.neo4j.gds.ml.models.Regressor;
import org.neo4j.gds.ml.models.linearregression.LinearRegressionData;
import org.neo4j.gds.ml.models.linearregression.LinearRegressor;
import org.neo4j.gds.ml.models.randomforest.RandomForestRegressor;
import org.neo4j.gds.ml.models.randomforest.RandomForestRegressorData;
import org.neo4j.gds.ml.pipeline.node.regression.predict.NodeRegressionPredictPipelineBaseConfig;
import org.neo4j.gds.ml.pipeline.node.regression.predict.NodeRegressionPredictPipelineExecutor;
import org.neo4j.gds.ml.pipeline.nodePipeline.NodePropertyPredictPipeline;
import org.neo4j.gds.ml.pipeline.nodePipeline.regression.NodeRegressionPipelineModelInfo;
import org.neo4j.gds.ml.pipeline.nodePipeline.regression.NodeRegressionPipelineTrainConfig;

public class NodeRegressionPredictPipelineAlgorithmFactory<CONFIG extends NodeRegressionPredictPipelineBaseConfig>
extends GraphStoreAlgorithmFactory<NodeRegressionPredictPipelineExecutor, CONFIG> {
    private final ModelCatalog modelCatalog;
    private final ExecutionContext executionContext;

    NodeRegressionPredictPipelineAlgorithmFactory(ExecutionContext executionContext) {
        this.modelCatalog = executionContext.modelCatalog();
        this.executionContext = executionContext;
    }

    public Task progressTask(GraphStore graphStore, CONFIG config) {
        NodePropertyPredictPipeline trainingPipeline = ((NodeRegressionPipelineModelInfo)NodeRegressionPredictPipelineAlgorithmFactory.getTrainedNRPipelineModel(this.modelCatalog, config.modelName(), config.username()).customInfo()).pipeline();
        return NodeRegressionPredictPipelineExecutor.progressTask(this.taskName(), trainingPipeline, graphStore);
    }

    public String taskName() {
        return "Node Classification Predict Pipeline";
    }

    public NodeRegressionPredictPipelineExecutor build(GraphStore graphStore, CONFIG configuration, ProgressTracker progressTracker) {
        Model<Regressor.RegressorData, NodeRegressionPipelineTrainConfig, NodeRegressionPipelineModelInfo> model = NodeRegressionPredictPipelineAlgorithmFactory.getTrainedNRPipelineModel(this.modelCatalog, configuration.modelName(), configuration.username());
        return new NodeRegressionPredictPipelineExecutor(((NodeRegressionPipelineModelInfo)model.customInfo()).pipeline(), (NodeRegressionPredictPipelineBaseConfig)configuration, this.executionContext, graphStore, progressTracker, NodeRegressionPredictPipelineAlgorithmFactory.regressorFrom((Regressor.RegressorData)model.data()));
    }

    private static Regressor regressorFrom(Regressor.RegressorData regressorData) {
        switch (regressorData.trainerMethod()) {
            case LinearRegression: {
                return new LinearRegressor((LinearRegressionData)regressorData);
            }
            case RandomForestRegression: {
                return new RandomForestRegressor((RandomForestRegressorData)regressorData);
            }
        }
        throw new IllegalStateException("No such regressor: " + regressorData.trainerMethod().name());
    }

    private static Model<Regressor.RegressorData, NodeRegressionPipelineTrainConfig, NodeRegressionPipelineModelInfo> getTrainedNRPipelineModel(ModelCatalog modelCatalog, String modelName, String username) {
        return modelCatalog.get(username, modelName, Regressor.RegressorData.class, NodeRegressionPipelineTrainConfig.class, NodeRegressionPipelineModelInfo.class);
    }
}

