package org.neo4j.gds.ml.pipeline.linkPipeline.train;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.function.Function;
import java.util.function.ToLongFunction;
import org.apache.commons.lang3.mutable.MutableLong;
import org.neo4j.gds.RelationshipType;
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.core.concurrency.RunWithConcurrency;
import org.neo4j.gds.core.utils.TerminationFlag;
import org.neo4j.gds.core.utils.mem.MemoryEstimation;
import org.neo4j.gds.core.utils.mem.MemoryEstimations;
import org.neo4j.gds.core.utils.mem.MemoryRange;
import org.neo4j.gds.core.utils.paged.HugeIntArray;
import org.neo4j.gds.core.utils.paged.HugeObjectArray;
import org.neo4j.gds.core.utils.partition.DegreePartition;
import org.neo4j.gds.core.utils.partition.PartitionUtils;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.mem.MemoryUsage;
import org.neo4j.gds.ml.models.Features;
import org.neo4j.gds.ml.pipeline.linkPipeline.LinkFeatureExtractor;
import org.neo4j.gds.ml.pipeline.linkPipeline.LinkFeatureStep;
import org.neo4j.gds.utils.StringFormatting;

/* loaded from: input_file:org/neo4j/gds/ml/pipeline/linkPipeline/train/LinkFeaturesAndLabelsExtractor.class */
/**
 * Static helpers that turn a graph into the link features and the integer labels
 * (one label per relationship) used for link-prediction training, plus a memory
 * estimation for those two structures.
 *
 * <p>Not instantiable; all members are static.
 */
final class LinkFeaturesAndLabelsExtractor {
    private LinkFeaturesAndLabelsExtractor() {
    }

    /**
     * Builds a memory estimation made of two components: the relationship features
     * and the relationship targets (labels).
     *
     * @param memoryRange    range that is mapped through {@link MemoryUsage#sizeOfDoubleArray}
     *                       to obtain the size of one feature vector
     *                       (presumably the range of the link feature dimension; confirm with callers)
     * @param toLongFunction derives the number of relationships in the set from the
     *                       per-type relationship counts of the graph dimensions
     * @param str            prefix used in the descriptions of both estimation components
     * @return the combined memory estimation
     */
    public static MemoryEstimation estimate(MemoryRange memoryRange, ToLongFunction<Map<RelationshipType, Long>> toLongFunction, String str) {
        return MemoryEstimations.builder()
            .rangePerGraphDimension(
                str + " relationship features",
                (dimensions, unused) -> memoryRange
                    .apply(MemoryUsage::sizeOfDoubleArray)
                    .times(toLongFunction.applyAsLong(dimensions.relationshipCounts()))
                    .add(MemoryUsage.sizeOfInstance(HugeObjectArray.class))
            )
            .perGraphDimension(
                str + " relationship targets",
                (dimensions, unused) -> MemoryRange.of(
                    HugeIntArray.memoryEstimation(toLongFunction.applyAsLong(dimensions.relationshipCounts()))
                )
            )
            .build();
    }

    /**
     * Extracts the link features and the matching labels for the given graph.
     *
     * <p>The progress tracker is sized to twice the relationship count before the
     * feature extraction and the label extraction run.
     *
     * @param graph            graph whose relationships are turned into features and labels
     * @param list             link feature steps handed to {@link LinkFeatureExtractor}
     * @param i                concurrency used for both extraction phases
     * @param progressTracker  tracker that receives the step count and progress updates
     * @param terminationFlag  flag handed to the concurrent runners
     * @return the extracted features together with their labels
     * @throws IllegalArgumentException if a relationship carries a label other than {@code 0} or {@code 1}
     */
    public static FeaturesAndLabels extractFeaturesAndLabels(Graph graph, List<LinkFeatureStep> list, int i, ProgressTracker progressTracker, TerminationFlag terminationFlag) {
        progressTracker.setSteps(graph.relationshipCount() * 2);
        Features features = LinkFeatureExtractor.extractFeatures(graph, list, i, progressTracker, terminationFlag);
        HugeIntArray labels = extractLabels(graph, features.size(), i, progressTracker, terminationFlag);
        return ImmutableFeaturesAndLabels.of(features, labels);
    }

    /**
     * Reads the relationship property of every relationship as a label and stores it
     * in a {@link HugeIntArray} of the given size.
     *
     * <p>The graph is split into degree partitions, one task per partition. Each task
     * starts writing at the summed {@code totalDegree()} of all partitions before it.
     * NOTE(review): this relies on each task visiting exactly {@code totalDegree()}
     * relationships so that the written ranges do not overlap; confirm against
     * {@code DegreePartition}.
     *
     * @param graph            graph to read the labels from
     * @param numberOfTargets  size of the label array to allocate
     * @param concurrency      concurrency for partitioning and for running the tasks
     * @param progressTracker  receives {@code totalDegree()} steps per finished partition
     * @param terminationFlag  flag handed to the concurrent runner
     * @return the labels, in the order described above
     * @throws IllegalArgumentException if a label is neither {@code 0} nor {@code 1}
     */
    private static HugeIntArray extractLabels(Graph graph, long numberOfTargets, int concurrency, ProgressTracker progressTracker, TerminationFlag terminationFlag) {
        HugeIntArray labels = HugeIntArray.newArray(numberOfTargets);
        List<DegreePartition> partitions = PartitionUtils.degreePartition(graph, concurrency, Function.identity(), Optional.of(100));

        // Explicitly typed as Runnable: a raw list gives the lambdas below no functional-interface target type.
        List<Runnable> tasks = new ArrayList<>(partitions.size());
        long nextOffset = 0L;
        for (DegreePartition partition : partitions) {
            // Snapshot of the running offset; lambdas may only capture effectively final locals.
            final long partitionOffset = nextOffset;
            tasks.add(() -> {
                MutableLong writeIndex = new MutableLong(partitionOffset);
                Graph localGraph = graph.concurrentCopy();
                partition.consume(nodeId -> {
                    // -10.0 is passed as the second argument; presumably the fallback property value,
                    // which would fail the 0/1 check below for relationships without a label.
                    localGraph.forEachRelationship(nodeId, -10.0d, (source, target, label) -> {
                        if (label != 0.0d && label != 1.0d) {
                            throw new IllegalArgumentException(StringFormatting.formatWithLocale("Label should be either `1` or `0`. But got %f for relationship (%d, %d)", label, source, target));
                        }
                        labels.set(writeIndex.getAndIncrement(), (int) label);
                        return true;
                    });
                });
                progressTracker.logSteps(partition.totalDegree());
            });
            nextOffset += partition.totalDegree();
        }

        RunWithConcurrency.builder().concurrency(concurrency).tasks(tasks).terminationFlag(terminationFlag).run();
        return labels;
    }
}
