package org.neo4j.gds.applications.algorithms.embeddings;

import org.neo4j.gds.api.Graph;
import org.neo4j.gds.compat.GdsVersionInfoProvider;
import org.neo4j.gds.core.concurrency.DefaultPool;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.embeddings.graphsage.TrainConfigTransformer;
import org.neo4j.gds.embeddings.graphsage.algo.GraphSageTrain;
import org.neo4j.gds.embeddings.graphsage.algo.GraphSageTrainConfig;
import org.neo4j.gds.embeddings.graphsage.algo.GraphSageTrainParameters;
import org.neo4j.gds.embeddings.graphsage.algo.MultiLabelGraphSageTrain;
import org.neo4j.gds.embeddings.graphsage.algo.SingleLabelGraphSageTrain;
import org.neo4j.gds.termination.TerminationFlag;

/* JADX INFO: Access modifiers changed from: package-private */
/* loaded from: input_file:org/neo4j/gds/applications/algorithms/embeddings/GraphSageTrainAlgorithmFactory.class */
public class GraphSageTrainAlgorithmFactory {

    /**
     * Builds the GraphSage training algorithm that matches the given configuration.
     *
     * <p>If {@code graphSageTrainConfig.isMultiLabel()} is true, this returns a
     * {@link MultiLabelGraphSageTrain}, which also receives the configured projected feature dimension.
     * Otherwise it returns a {@link SingleLabelGraphSageTrain}. Both variants get:
     * <ul>
     *   <li>the parameters derived via {@link TrainConfigTransformer#toParameters};</li>
     *   <li>{@link DefaultPool#INSTANCE};</li>
     *   <li>the GDS version string reported by {@link GdsVersionInfoProvider};</li>
     *   <li>the original config.</li>
     * </ul>
     *
     * @param graph                the graph passed to the training algorithm
     * @param graphSageTrainConfig the training configuration; decides which variant is built
     * @param progressTracker      progress tracker passed through to the algorithm
     * @param terminationFlag      termination flag passed through to the algorithm
     * @return a new {@link GraphSageTrain} instance (not yet run)
     * @throws java.util.NoSuchElementException presumably (assuming {@code projectedFeatureDimension()}
     *         returns {@code java.util.Optional}), when the config is multi-label but carries no
     *         projected feature dimension
     */
    public GraphSageTrain create(Graph graph, GraphSageTrainConfig graphSageTrainConfig, ProgressTracker progressTracker, TerminationFlag terminationFlag) {
        String version = GdsVersionInfoProvider.GDS_VERSION_INFO.gdsVersion();
        GraphSageTrainParameters trainParameters = TrainConfigTransformer.toParameters(graphSageTrainConfig);

        if (!graphSageTrainConfig.isMultiLabel()) {
            return new SingleLabelGraphSageTrain(
                graph,
                trainParameters,
                DefaultPool.INSTANCE,
                progressTracker,
                terminationFlag,
                version,
                graphSageTrainConfig
            );
        }

        // Only the multi-label variant takes a projected feature dimension; orElseThrow() fails fast
        // when it has not been configured.
        int featureDimension = (Integer) graphSageTrainConfig.projectedFeatureDimension().orElseThrow();
        return new MultiLabelGraphSageTrain(
            graph,
            trainParameters,
            featureDimension,
            DefaultPool.INSTANCE,
            progressTracker,
            terminationFlag,
            version,
            graphSageTrainConfig
        );
    }
}
