/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.ml.pipeline.node.regression.predict;

import java.util.Collection;
import java.util.List;
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.api.GraphStore;
import org.neo4j.gds.config.AlgoBaseConfig;
import org.neo4j.gds.core.utils.paged.HugeDoubleArray;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.core.utils.progress.tasks.Task;
import org.neo4j.gds.core.utils.progress.tasks.Tasks;
import org.neo4j.gds.executor.ExecutionContext;
import org.neo4j.gds.ml.models.Features;
import org.neo4j.gds.ml.models.FeaturesFactory;
import org.neo4j.gds.ml.models.Regressor;
import org.neo4j.gds.ml.nodePropertyPrediction.regression.NodeRegressionPredict;
import org.neo4j.gds.ml.pipeline.ImmutablePipelineGraphFilter;
import org.neo4j.gds.ml.pipeline.NodePropertyStepExecutor;
import org.neo4j.gds.ml.pipeline.Pipeline;
import org.neo4j.gds.ml.pipeline.PipelineGraphFilter;
import org.neo4j.gds.ml.pipeline.PredictPipelineExecutor;
import org.neo4j.gds.ml.pipeline.node.regression.predict.NodeRegressionPredictPipelineBaseConfig;
import org.neo4j.gds.ml.pipeline.nodePipeline.NodePropertyPredictPipeline;
import org.neo4j.gds.utils.StringFormatting;
import org.neo4j.gds.utils.StringJoining;

public class NodeRegressionPredictPipelineExecutor
extends PredictPipelineExecutor<NodeRegressionPredictPipelineBaseConfig, NodePropertyPredictPipeline, HugeDoubleArray> {
    private final Regressor regressor;
    private final PipelineGraphFilter predictGraphFilter;

    public NodeRegressionPredictPipelineExecutor(NodePropertyPredictPipeline pipeline, NodeRegressionPredictPipelineBaseConfig config, ExecutionContext executionContext, GraphStore graphStore, ProgressTracker progressTracker, Regressor regressor) {
        super((Pipeline)pipeline, (AlgoBaseConfig)config, executionContext, graphStore, progressTracker);
        this.regressor = regressor;
        this.predictGraphFilter = ImmutablePipelineGraphFilter.builder().nodeLabels(config.nodeLabelIdentifiers(graphStore)).relationshipTypes(config.internalRelationshipTypes(graphStore)).build();
    }

    public static Task progressTask(String taskName, NodePropertyPredictPipeline pipeline, GraphStore graphStore) {
        return Tasks.task((String)taskName, (Task)NodePropertyStepExecutor.tasks((List)pipeline.nodePropertySteps(), (long)graphStore.nodeCount()), (Task[])new Task[]{NodeRegressionPredict.progressTask((long)graphStore.nodeCount())});
    }

    protected PipelineGraphFilter nodePropertyStepFilter() {
        return this.predictGraphFilter;
    }

    protected HugeDoubleArray execute() {
        Graph nodesGraph = this.graphStore.getGraph(this.predictGraphFilter.nodeLabels());
        Features features = FeaturesFactory.extractLazyFeatures((Graph)nodesGraph, (List)((NodePropertyPredictPipeline)this.pipeline).featureProperties());
        if (features.featureDimension() != this.regressor.data().featureDimension()) {
            throw new IllegalArgumentException(StringFormatting.formatWithLocale((String)"Model expected features %s to have a total dimension of `%d`, but got `%d`.", (Object[])new Object[]{StringJoining.join((Collection)((NodePropertyPredictPipeline)this.pipeline).featureProperties()), this.regressor.data().featureDimension(), features.featureDimension()}));
        }
        return new NodeRegressionPredict(this.regressor, features, ((NodeRegressionPredictPipelineBaseConfig)this.config).concurrency(), this.progressTracker, this.terminationFlag).compute();
    }
}

