package org.neo4j.gds.embeddings.graphsage;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.apache.commons.lang3.mutable.MutableInt;
import org.neo4j.gds.embeddings.graphsage.Aggregator;
import org.neo4j.gds.embeddings.graphsage.algo.GraphSageTrainConfig;
import org.neo4j.gds.embeddings.graphsage.ddl4j.Variable;
import org.neo4j.gds.embeddings.graphsage.ddl4j.functions.NormalizeRows;
import org.neo4j.gds.embeddings.graphsage.ddl4j.tensor.Matrix;
import org.neo4j.gds.embeddings.graphsage.subgraph.SubGraph;
import org.neo4j.gds.ml.features.BiasFeature;
import org.neo4j.gds.ml.features.DegreeFeatureExtractor;
import org.neo4j.gds.ml.features.FeatureExtraction;
import org.neo4j.gds.ml.features.FeatureExtractor;
import org.neo4j.gds.ml.features.HugeObjectArrayFeatureConsumer;
import org.neo4j.graphalgo.NodeLabel;
import org.neo4j.graphalgo.api.Graph;
import org.neo4j.graphalgo.api.NodeMapping;
import org.neo4j.graphalgo.api.schema.GraphSchema;
import org.neo4j.graphalgo.core.utils.mem.AllocationTracker;
import org.neo4j.graphalgo.core.utils.mem.MemoryEstimation;
import org.neo4j.graphalgo.core.utils.mem.MemoryEstimations;
import org.neo4j.graphalgo.core.utils.mem.MemoryRange;
import org.neo4j.graphalgo.core.utils.mem.MemoryUsage;
import org.neo4j.graphalgo.core.utils.paged.HugeObjectArray;
import org.neo4j.graphalgo.utils.StringFormatting;

/* loaded from: input_file:org/neo4j/gds/embeddings/graphsage/GraphSageHelper.class */
public final class GraphSageHelper {
    private GraphSageHelper() {
    }

    public static Variable<Matrix> embeddings(Graph graph, boolean z, long[] jArr, HugeObjectArray<double[]> hugeObjectArray, Layer[] layerArr, FeatureFunction featureFunction) {
        List list = (List) Arrays.stream(layerArr).map(layer -> {
            Objects.requireNonNull(layer);
            return layer::neighborhoodFunction;
        }).collect(Collectors.toList());
        Collections.reverse(list);
        List<SubGraph> buildSubGraphs = SubGraph.buildSubGraphs(jArr, list, graph, z);
        Variable<Matrix> apply = featureFunction.apply(graph, buildSubGraphs.get(buildSubGraphs.size() - 1).nextNodes, hugeObjectArray);
        for (int length = layerArr.length - 1; length >= 0; length--) {
            apply = layerArr[(layerArr.length - length) - 1].aggregator().aggregate(apply, buildSubGraphs.get(length));
        }
        return new NormalizeRows(apply);
    }

    public static MemoryEstimation embeddingsEstimation(GraphSageTrainConfig graphSageTrainConfig, long j, long j2, int i, boolean z) {
        boolean isMultiLabel = graphSageTrainConfig.isMultiLabel();
        List<LayerConfig> layerConfigs = graphSageTrainConfig.layerConfigs();
        int size = layerConfigs.size();
        MemoryEstimations.Builder startField = MemoryEstimations.builder("computationGraph").startField("subgraphs");
        ArrayList arrayList = new ArrayList(size + 1);
        ArrayList arrayList2 = new ArrayList(size + 1);
        arrayList.add(Long.valueOf(j));
        arrayList2.add(Long.valueOf(j));
        for (int i2 = 0; i2 < size; i2++) {
            long sampleSize = layerConfigs.get(i2).sampleSize();
            Long l = (Long) arrayList.get(i2);
            Long l2 = (Long) arrayList2.get(i2);
            long min = Math.min(l.longValue(), j2);
            long min2 = Math.min(l2.longValue() * (sampleSize + 1), j2);
            arrayList.add(Long.valueOf(min));
            arrayList2.add(Long.valueOf(min2));
            startField.add(MemoryEstimations.of("subgraph " + (i2 + 1), MemoryRange.of(MemoryUsage.sizeOfIntArray(l.longValue()) + MemoryUsage.sizeOfObjectArray(l.longValue()) + (l.longValue() * MemoryUsage.sizeOfIntArray(0L)) + MemoryUsage.sizeOfLongArray(min), MemoryUsage.sizeOfIntArray(l2.longValue()) + MemoryUsage.sizeOfObjectArray(l2.longValue()) + (l2.longValue() * MemoryUsage.sizeOfIntArray(sampleSize)) + MemoryUsage.sizeOfLongArray(min2))));
        }
        Collections.reverse(arrayList);
        Collections.reverse(arrayList2);
        MemoryEstimations.Builder builder = MemoryEstimations.builder();
        for (int i3 = 0; i3 < size; i3++) {
            LayerConfig layerConfig = layerConfigs.get(i3);
            Long l3 = (Long) arrayList.get(i3);
            Long l4 = (Long) arrayList2.get(i3);
            Long l5 = (Long) arrayList.get(i3 + 1);
            Long l6 = (Long) arrayList2.get(i3 + 1);
            if (i3 == 0) {
                int featuresSize = graphSageTrainConfig.featuresSize();
                MemoryRange of = MemoryRange.of(MemoryUsage.sizeOfDoubleArray(l3.longValue() * featuresSize), MemoryUsage.sizeOfDoubleArray(l4.longValue() * featuresSize));
                if (isMultiLabel) {
                    of = of.add(MemoryRange.of(MemoryUsage.sizeOfDoubleArray(featuresSize)));
                }
                builder.fixed("firstLayer", of);
            }
            Aggregator.AggregatorType aggregatorType = layerConfig.aggregatorType();
            int embeddingDimension = graphSageTrainConfig.embeddingDimension();
            builder.fixed(StringFormatting.formatWithLocale("%s %d", new Object[]{aggregatorType.name(), Integer.valueOf(i3 + 1)}), aggregatorType.memoryEstimation(l5.longValue(), l6.longValue(), l3.longValue(), l4.longValue(), layerConfig.cols(), embeddingDimension));
            if (i3 == size - 1) {
                builder.fixed("normalizeRows", MemoryRange.of(MemoryUsage.sizeOfDoubleArray(l5.longValue() * embeddingDimension), MemoryUsage.sizeOfDoubleArray(l6.longValue() * embeddingDimension)));
            }
        }
        MemoryEstimations.Builder endField = startField.endField();
        if (isMultiLabel) {
            endField.fixed("multiLabelFeatureFunction", MemoryRange.of(MemoryUsage.sizeOfObjectArray(((Long) arrayList.get(0)).longValue()), MemoryUsage.sizeOfObjectArray(((Long) arrayList2.get(0)).longValue())).add(MemoryRange.of(MemoryUsage.sizeOfObjectArray(i))));
        }
        MemoryEstimations.Builder addComponentsOf = endField.startField("forward").addComponentsOf(builder.build());
        if (z) {
            addComponentsOf = addComponentsOf.endField().startField("backward").addComponentsOf(builder.build());
        }
        return addComponentsOf.endField().build();
    }

    public static HugeObjectArray<double[]> initializeFeatures(Graph graph, GraphSageTrainConfig graphSageTrainConfig, AllocationTracker allocationTracker) {
        HugeObjectArray newArray = HugeObjectArray.newArray(double[].class, graph.nodeCount(), allocationTracker);
        return graphSageTrainConfig.isMultiLabel() ? initializeMultiLabelFeatures(graph, graphSageTrainConfig, newArray) : initializeSingleLabelFeatures(graph, graphSageTrainConfig, newArray);
    }

    private static HugeObjectArray<double[]> initializeSingleLabelFeatures(Graph graph, GraphSageTrainConfig graphSageTrainConfig, HugeObjectArray<double[]> hugeObjectArray) {
        List<FeatureExtractor> propertyExtractors = FeatureExtraction.propertyExtractors(graph, graphSageTrainConfig.featureProperties());
        return FeatureExtraction.extract(graph, graphSageTrainConfig.degreeAsProperty() ? (List) Stream.concat(propertyExtractors.stream(), Stream.of(new DegreeFeatureExtractor(graph))).collect(Collectors.toList()) : propertyExtractors, hugeObjectArray);
    }

    private static HugeObjectArray<double[]> initializeMultiLabelFeatures(Graph graph, GraphSageTrainConfig graphSageTrainConfig, HugeObjectArray<double[]> hugeObjectArray) {
        HugeObjectArrayFeatureConsumer hugeObjectArrayFeatureConsumer = new HugeObjectArrayFeatureConsumer(hugeObjectArray);
        HashMap hashMap = new HashMap();
        Map<NodeLabel, Set<String>> filteredPropertyKeysPerNodeLabel = filteredPropertyKeysPerNodeLabel(graph, graphSageTrainConfig);
        HashMap hashMap2 = new HashMap();
        graph.forEachNode(j -> {
            NodeLabel labelOf = labelOf(graph, j);
            List list = (List) hashMap2.computeIfAbsent(labelOf, nodeLabel -> {
                ArrayList arrayList = new ArrayList(FeatureExtraction.propertyExtractors(graph, (Set) filteredPropertyKeysPerNodeLabel.get(nodeLabel), j));
                if (graphSageTrainConfig.degreeAsProperty()) {
                    arrayList.add(new DegreeFeatureExtractor(graph));
                }
                arrayList.add(new BiasFeature());
                return arrayList;
            });
            hugeObjectArray.set(j, new double[((Integer) hashMap.computeIfAbsent(labelOf, nodeLabel2 -> {
                return Integer.valueOf(FeatureExtraction.featureCount((Collection) hashMap2.get(nodeLabel2)));
            })).intValue()]);
            FeatureExtraction.extract(j, j, list, hugeObjectArrayFeatureConsumer);
            return true;
        });
        return hugeObjectArray;
    }

    public static Map<NodeLabel, Set<String>> propertyKeysPerNodeLabel(GraphSchema graphSchema) {
        return (Map) graphSchema.nodeSchema().properties().entrySet().stream().collect(Collectors.toMap((v0) -> {
            return v0.getKey();
        }, entry -> {
            return ((Map) entry.getValue()).keySet();
        }));
    }

    private static Map<NodeLabel, Set<String>> filteredPropertyKeysPerNodeLabel(Graph graph, GraphSageTrainConfig graphSageTrainConfig) {
        return (Map) propertyKeysPerNodeLabel(graph.schema()).entrySet().stream().collect(Collectors.toMap((v0) -> {
            return v0.getKey();
        }, entry -> {
            Stream stream = graphSageTrainConfig.featureProperties().stream();
            Set set = (Set) entry.getValue();
            Objects.requireNonNull(set);
            return (Set) stream.filter((v1) -> {
                return r1.contains(v1);
            }).collect(Collectors.toSet());
        }));
    }

    private static NodeLabel labelOf(NodeMapping nodeMapping, long j) {
        AtomicReference atomicReference = new AtomicReference();
        MutableInt mutableInt = new MutableInt(0);
        nodeMapping.forEachNodeLabel(j, nodeLabel -> {
            atomicReference.set(nodeLabel);
            return mutableInt.getAndIncrement() == 0;
        });
        if (mutableInt.intValue() != 1) {
            throw new IllegalArgumentException(StringFormatting.formatWithLocale("Each node must have exactly one label: nodeId=%d, labels=%s", new Object[]{Long.valueOf(j), nodeMapping.nodeLabels(j)}));
        }
        return (NodeLabel) atomicReference.get();
    }
}
