/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.ml.kge;

import java.util.stream.Stream;
import org.neo4j.gds.api.DatabaseId;
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.core.concurrency.Concurrency;
import org.neo4j.gds.core.utils.ProgressTimer;
import org.neo4j.gds.core.write.RelationshipExporterBuilder;
import org.neo4j.gds.executor.AlgorithmSpec;
import org.neo4j.gds.executor.AlgorithmSpecProgressTrackerProvider;
import org.neo4j.gds.executor.ComputationResultConsumer;
import org.neo4j.gds.executor.ExecutionContext;
import org.neo4j.gds.executor.ExecutionMode;
import org.neo4j.gds.executor.GdsCallable;
import org.neo4j.gds.ml.kge.KGEPredictAlgorithmFactory;
import org.neo4j.gds.ml.kge.KGEPredictResult;
import org.neo4j.gds.ml.kge.KGEPredictWriteConfig;
import org.neo4j.gds.ml.kge.KGEWriteResult;
import org.neo4j.gds.ml.kge.TopKMapComputer;
import org.neo4j.gds.procedures.algorithms.configuration.NewConfigFunction;
import org.neo4j.gds.similarity.nodesim.TopKGraph;
import org.neo4j.gds.similarity.nodesim.TopKMap;

/**
 * Algorithm specification behind the {@code gds.ml.kge.predict.write} procedure.
 *
 * <p>Prediction is delegated to a {@link KGEPredictAlgorithmFactory}. The resulting top-k map is
 * then wrapped in a {@link TopKGraph} and exported as new relationships through the execution
 * context's relationship exporter.
 */
@GdsCallable(
    name = "gds.ml.kge.predict.write",
    description = "Predicts new relationships using an existing KGE model",
    executionMode = ExecutionMode.WRITE_RELATIONSHIP
)
public class KGEPredictWriteSpec
implements AlgorithmSpec<TopKMapComputer, KGEPredictResult, KGEPredictWriteConfig, Stream<KGEWriteResult>, KGEPredictAlgorithmFactory<KGEPredictWriteConfig>> {

    /** Name reported by {@link #name()}. */
    private static final String SPEC_NAME = "KGEPredictWrite";

    /**
     * Returns the name of this spec.
     *
     * <p>The write step also uses this name as the progress-tracker task name.
     *
     * @return {@code "KGEPredictWrite"}
     */
    public String name() {
        return SPEC_NAME;
    }

    /**
     * Creates a new prediction algorithm factory for the write config type.
     *
     * @param executionContext not consulted by this implementation
     * @return a fresh {@link KGEPredictAlgorithmFactory}
     */
    public KGEPredictAlgorithmFactory<KGEPredictWriteConfig> algorithmFactory(ExecutionContext executionContext) {
        KGEPredictAlgorithmFactory<KGEPredictWriteConfig> factory = new KGEPredictAlgorithmFactory<>();
        return factory;
    }

    /**
     * Returns a function that parses the user-supplied configuration into a
     * {@link KGEPredictWriteConfig}.
     *
     * <p>The first argument of the returned function is ignored.
     *
     * @return the config-building function
     */
    public NewConfigFunction<KGEPredictWriteConfig> newConfigFunction() {
        return (unusedFirstArgument, rawConfiguration) -> KGEPredictWriteConfig.of(rawConfiguration);
    }

    /**
     * Returns the consumer that writes the predicted relationships and builds the procedure's result row.
     *
     * <p>If the computation produced no result, the consumer emits a single row from an
     * otherwise untouched {@link KGEWriteResult.Builder}. In that case nothing is written and no
     * timings or config are set.
     *
     * <p>Otherwise the consumer wraps the predicted top-k map as a {@link TopKGraph} over the input
     * graph. It exports that graph's relationships using the configured relationship type and
     * property, timing the export. The row then reports:
     * <ul>
     *   <li>write millis;</li>
     *   <li>compute millis;</li>
     *   <li>pre-processing millis;</li>
     *   <li>the number of relationships written;</li>
     *   <li>the config.</li>
     * </ul>
     *
     * @return the computation result consumer
     */
    public ComputationResultConsumer<TopKMapComputer, KGEPredictResult, KGEPredictWriteConfig, Stream<KGEWriteResult>> computationResultConsumer() {
        return (computationResult, executionContext) -> {
            KGEWriteResult.Builder resultRow = new KGEWriteResult.Builder();

            // Without a prediction result there is nothing to write.
            if (computationResult.result().isEmpty()) {
                return Stream.of(resultRow.build());
            }

            Graph inputGraph = computationResult.graph();
            TopKMap predictions = computationResult.result().get().topKMap();
            TopKGraph predictedGraph = new TopKGraph(inputGraph, predictions);
            KGEPredictWriteConfig writeConfig = computationResult.config();

            // Time the export. ProgressTimer reports its measured duration through
            // withWriteMillis (presumably when it is closed at the end of this block).
            try (ProgressTimer ignored = ProgressTimer.start(resultRow::withWriteMillis)) {
                executionContext.relationshipExporterBuilder()
                    .withGraph(predictedGraph)
                    // Map node ids of the TopKGraph back to the original node ids on export.
                    .withIdMappingOperator(predictedGraph::toOriginalNodeId)
                    .withTerminationFlag(computationResult.algorithm().getTerminationFlag())
                    .withProgressTracker(AlgorithmSpecProgressTrackerProvider.createProgressTracker(
                        name(),
                        inputGraph.nodeCount(),
                        RelationshipExporterBuilder.TYPED_DEFAULT_WRITE_CONCURRENCY,
                        executionContext
                    ))
                    .withArrowConnectionInfo(
                        writeConfig.arrowConnectionInfo(),
                        computationResult.graphStore().databaseInfo().remoteDatabaseId().map(DatabaseId::databaseName)
                    )
                    .withResultStore(writeConfig.resolveResultStore(computationResult.resultStore()))
                    .withJobId(writeConfig.jobId())
                    .build()
                    .write(writeConfig.writeRelationshipType(), writeConfig.writeProperty());
            }

            resultRow.withComputeMillis(computationResult.computeMillis());
            resultRow.withPreProcessingMillis(computationResult.preProcessingMillis());
            resultRow.withRelationshipsWritten(predictedGraph.relationshipCount());
            resultRow.withConfig(writeConfig);
            return Stream.of(resultRow.build());
        };
    }
}

