/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.ml.splitting;

import java.util.stream.Stream;
import org.neo4j.gds.MutateComputationResultConsumer;
import org.neo4j.gds.api.GraphStore;
import org.neo4j.gds.core.CypherMapWrapper;
import org.neo4j.gds.core.loading.SingleTypeRelationships;
import org.neo4j.gds.executor.AlgorithmSpec;
import org.neo4j.gds.executor.ComputationResult;
import org.neo4j.gds.executor.ComputationResultConsumer;
import org.neo4j.gds.executor.ExecutionContext;
import org.neo4j.gds.executor.ExecutionMode;
import org.neo4j.gds.executor.GdsCallable;
import org.neo4j.gds.ml.splitting.EdgeSplitter;
import org.neo4j.gds.ml.splitting.MutateResult;
import org.neo4j.gds.ml.splitting.SplitRelationships;
import org.neo4j.gds.ml.splitting.SplitRelationshipsAlgorithmFactory;
import org.neo4j.gds.ml.splitting.SplitRelationshipsMutateConfig;
import org.neo4j.gds.procedures.algorithms.configuration.NewConfigFunction;
import org.neo4j.gds.result.AbstractResultBuilder;

/**
 * Algorithm specification backing the {@code gds.alpha.ml.splitRelationships.mutate} procedure.
 *
 * <p>Runs {@link SplitRelationships}. When a split result is present, it registers both resulting
 * relationship sets on the graph store of the computation: the selected (holdout) set and the
 * remaining set.
 */
@GdsCallable(
    name = "gds.alpha.ml.splitRelationships.mutate",
    description = "Splits a graph into holdout and remaining relationship types and adds them to the graph.",
    executionMode = ExecutionMode.MUTATE_RELATIONSHIP
)
public class SplitRelationshipsMutateSpec
implements AlgorithmSpec<SplitRelationships, EdgeSplitter.SplitResult, SplitRelationshipsMutateConfig, Stream<MutateResult>, SplitRelationshipsAlgorithmFactory> {

    /**
     * Returns the algorithm name.
     *
     * @return the constant {@code "SplitRelationships"}
     */
    public String name() {
        return "SplitRelationships";
    }

    /**
     * Creates a new {@link SplitRelationshipsAlgorithmFactory}.
     *
     * @param executionContext not consulted by this implementation
     * @return a freshly constructed factory
     */
    public SplitRelationshipsAlgorithmFactory algorithmFactory(ExecutionContext executionContext) {
        return new SplitRelationshipsAlgorithmFactory();
    }

    /**
     * Returns a function that parses raw procedure configuration into a
     * {@link SplitRelationshipsMutateConfig}.
     *
     * <p>The first argument of the function is ignored. The second is handed to
     * {@link SplitRelationshipsMutateConfig#of}.
     */
    public NewConfigFunction<SplitRelationshipsMutateConfig> newConfigFunction() {
        return (ignored, rawConfig) -> SplitRelationshipsMutateConfig.of((CypherMapWrapper) rawConfig);
    }

    /**
     * Returns the mutate-mode consumer for this algorithm.
     *
     * <p>Result rows are produced by {@link #resultBuilder}. The graph-store update is delegated to
     * {@link #addSplitToGraphStore}.
     */
    public ComputationResultConsumer<SplitRelationships, EdgeSplitter.SplitResult, SplitRelationshipsMutateConfig, Stream<MutateResult>> computationResultConsumer() {
        return new MutateComputationResultConsumer<SplitRelationships, EdgeSplitter.SplitResult, SplitRelationshipsMutateConfig, MutateResult>(this::resultBuilder) {

            protected void updateGraphStore(
                AbstractResultBuilder<?> resultBuilder,
                ComputationResult<SplitRelationships, EdgeSplitter.SplitResult, SplitRelationshipsMutateConfig> computationResult,
                ExecutionContext executionContext
            ) {
                // When the computation yielded no split, the graph store is left untouched
                // and no relationship count is reported.
                computationResult.result().ifPresent(
                    splitResult -> SplitRelationshipsMutateSpec.addSplitToGraphStore(
                        computationResult.graphStore(),
                        splitResult,
                        resultBuilder
                    )
                );
            }
        };
    }

    /**
     * Materializes both halves of a split and adds them to the graph store as relationship types.
     *
     * <p>The remaining set is registered before the selected (holdout) set. The sum of their
     * topology element counts is reported as relationships written.
     *
     * @param graphStore    the store to mutate
     * @param splitResult   the split produced by the algorithm
     * @param resultBuilder receives the combined relationship count
     */
    private static void addSplitToGraphStore(
        GraphStore graphStore,
        EdgeSplitter.SplitResult splitResult,
        AbstractResultBuilder<?> resultBuilder
    ) {
        SingleTypeRelationships holdout = splitResult.selectedRels().build();
        SingleTypeRelationships remaining = splitResult.remainingRels().build();

        graphStore.addRelationshipType(remaining);
        graphStore.addRelationshipType(holdout);

        resultBuilder.withRelationshipsWritten(
            holdout.topology().elementCount() + remaining.topology().elementCount()
        );
    }

    /**
     * Supplies a new {@link MutateResult.Builder} on every call. Neither argument is read.
     */
    private AbstractResultBuilder<MutateResult> resultBuilder(
        ComputationResult<SplitRelationships, EdgeSplitter.SplitResult, SplitRelationshipsMutateConfig> ignoredResult,
        ExecutionContext ignoredContext
    ) {
        return new MutateResult.Builder();
    }
}

