package org.neo4j.gds.ml.features;

import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import java.util.stream.Collectors;
import org.neo4j.gds.embeddings.EmbeddingUtils;
import org.neo4j.gds.embeddings.graphsage.ddl4j.functions.MatrixConstant;
import org.neo4j.gds.ml.batch.Batch;
import org.neo4j.graphalgo.api.Graph;
import org.neo4j.graphalgo.api.nodeproperties.ValueType;
import org.neo4j.graphalgo.core.utils.paged.HugeObjectArray;
import org.neo4j.graphalgo.utils.StringFormatting;

/* loaded from: input_file:org/neo4j/gds/ml/features/FeatureExtraction.class */
/**
 * Static helpers that turn the values produced by a list of {@link FeatureExtractor}s into flat
 * feature vectors.
 *
 * <p>Extractors are laid out back to back in list order. Each one occupies
 * {@code dimension()} consecutive slots, starting at the sum of the dimensions of the extractors
 * before it.
 */
public final class FeatureExtraction {

    private FeatureExtraction() {
        // static utility class; not instantiable
    }

    /**
     * Runs every extractor for one node and forwards each result to {@code featureConsumer}.
     *
     * <p>Results from a {@link ScalarFeatureExtractor} are passed to
     * {@code FeatureConsumer.acceptScalar}. Results from an {@link ArrayFeatureExtractor} are
     * passed to {@code FeatureConsumer.acceptArray}. The {@code int} argument given to the
     * consumer is the running feature offset, i.e. the sum of the dimensions of all preceding
     * extractors.
     *
     * @param nodeId          node id passed to each extractor's {@code extract} method
     * @param nodeOffset      value forwarded unchanged as the consumer's first argument (callers
     *                        use it, for example, as a row index); it is not interpreted here
     * @param list            the extractors, in feature order
     * @param featureConsumer receives the extracted values
     * @throws IllegalStateException if an extractor is neither a {@code ScalarFeatureExtractor}
     *                               nor an {@code ArrayFeatureExtractor}. Extractors earlier in
     *                               the list have already been forwarded to the consumer by then.
     */
    public static void extract(long nodeId, long nodeOffset, List<FeatureExtractor> list, FeatureConsumer featureConsumer) {
        int featureOffset = 0;
        for (FeatureExtractor featureExtractor : list) {
            if (featureExtractor instanceof ScalarFeatureExtractor) {
                double value = ((ScalarFeatureExtractor) featureExtractor).extract(nodeId);
                featureConsumer.acceptScalar(nodeOffset, featureOffset, value);
            } else if (featureExtractor instanceof ArrayFeatureExtractor) {
                double[] values = ((ArrayFeatureExtractor) featureExtractor).extract(nodeId);
                featureConsumer.acceptArray(nodeOffset, featureOffset, values);
            } else {
                throw new IllegalStateException("Only ScalarFeatureExtractor and ArrayFeatureExtractor are handled");
            }
            featureOffset += featureExtractor.dimension();
        }
    }

    /**
     * Extracts features for every node of a batch into a dense, row-major matrix.
     *
     * <p>The matrix has {@code batch.size()} rows and {@link #featureCount} columns. Row {@code r}
     * holds the features of the {@code r}-th id yielded by {@code batch.nodeIds()}.
     *
     * @param batch the nodes to extract features for
     * @param list  the extractors, in feature order
     * @return a matrix backed by a freshly allocated {@code double[]}
     * @throws ArithmeticException if {@code batch.size() * featureCount} does not fit in an int
     */
    public static MatrixConstant extract(Batch batch, List<FeatureExtractor> list) {
        final int rows = batch.size();
        final int featureCount = featureCount(list);
        // multiplyExact: fail loudly instead of allocating a wrapped-around (wrong-sized) buffer
        final double[] data = new double[Math.multiplyExact(rows, featureCount)];
        FeatureConsumer featureConsumer = new FeatureConsumer() {
            @Override
            public void acceptScalar(long row, int offset, double value) {
                data[((int) row) * featureCount + offset] = value;
            }

            @Override
            public void acceptArray(long row, int offset, double[] values) {
                System.arraycopy(values, 0, data, ((int) row) * featureCount + offset, values.length);
            }
        };
        int rowIndex = 0;
        Iterator<Long> nodeIds = batch.nodeIds().iterator();
        while (nodeIds.hasNext()) {
            // the batch-local row index, not the node id, selects the matrix row
            extract(nodeIds.next().longValue(), rowIndex, list, featureConsumer);
            rowIndex++;
        }
        return new MatrixConstant(data, rows, featureCount);
    }

    /**
     * Extracts features for every node of {@code graph} into {@code hugeObjectArray}.
     *
     * <p>Every slot of the array is first replaced by a new {@code double[featureCount]}. Then,
     * for each node, the extractors run with the node id used both as the extraction id and as
     * the consumer offset. NOTE(review): this presumes that {@code HugeObjectArrayFeatureConsumer}
     * writes into the array slot given by that offset. Confirm against that class.
     *
     * @param graph           the graph whose nodes are visited
     * @param list            the extractors, in feature order
     * @param hugeObjectArray destination array; it is overwritten and returned
     * @return {@code hugeObjectArray}
     */
    public static HugeObjectArray<double[]> extract(Graph graph, List<FeatureExtractor> list, HugeObjectArray<double[]> hugeObjectArray) {
        final int featureCount = featureCount(list);
        hugeObjectArray.setAll(ignored -> new double[featureCount]);
        HugeObjectArrayFeatureConsumer featureConsumer = new HugeObjectArrayFeatureConsumer(hugeObjectArray);
        graph.forEachNode(nodeId -> {
            extract(nodeId, nodeId, list, featureConsumer);
            return true;
        });
        return hugeObjectArray;
    }

    /**
     * Returns the total feature dimension, i.e. the sum of {@code dimension()} over all extractors.
     *
     * @param collection the extractors; an empty collection yields 0
     * @return the total number of feature slots
     */
    public static int featureCount(Collection<? extends FeatureExtractor> collection) {
        return collection.stream().mapToInt(FeatureExtractor::dimension).sum();
    }

    /**
     * Same as {@link #propertyExtractors(Graph, Collection, long)} with node id {@code 0}.
     */
    public static List<FeatureExtractor> propertyExtractors(Graph graph, Collection<String> collection) {
        return propertyExtractors(graph, collection, 0L);
    }

    /**
     * Builds one extractor per node property, chosen by the property's {@code ValueType}.
     *
     * <ul>
     *   <li>{@code DOUBLE_ARRAY} or {@code FLOAT_ARRAY} produce an {@code ArrayPropertyExtractor}.</li>
     *   <li>{@code LONG_ARRAY} produces a {@code LongArrayPropertyExtractor}.</li>
     *   <li>{@code DOUBLE} or {@code LONG} produce a {@code ScalarPropertyExtractor}.</li>
     * </ul>
     *
     * <p>For array properties, the extractor's dimension is the length of the property value at
     * node {@code nodeId}. NOTE(review): the {@code getChecked...} helpers presumably validate
     * that value. Confirm in {@code EmbeddingUtils}.
     *
     * @param graph      graph providing the node properties
     * @param collection property keys, in feature order
     * @param nodeId     node whose array values determine the array dimensions
     * @return a list of extractors, one per key, in the same order
     * @throws IllegalStateException if a property has any other value type
     */
    public static List<FeatureExtractor> propertyExtractors(Graph graph, Collection<String> collection, long nodeId) {
        return collection.stream()
            .<FeatureExtractor>map(propertyKey -> {
                ValueType valueType = graph.nodeProperties(propertyKey).valueType();
                if (ValueType.DOUBLE_ARRAY == valueType || ValueType.FLOAT_ARRAY == valueType) {
                    int dimension = EmbeddingUtils.getCheckedDoubleArrayNodeProperty(graph, propertyKey, nodeId).length;
                    return new ArrayPropertyExtractor(dimension, graph, propertyKey);
                }
                if (ValueType.LONG_ARRAY == valueType) {
                    int dimension = EmbeddingUtils.getCheckedLongArrayNodeProperty(graph, propertyKey, nodeId).length;
                    return new LongArrayPropertyExtractor(dimension, graph, propertyKey);
                }
                if (ValueType.DOUBLE == valueType || ValueType.LONG == valueType) {
                    return new ScalarPropertyExtractor(graph, propertyKey);
                }
                throw new IllegalStateException(StringFormatting.formatWithLocale("Unknown ValueType %s", new Object[]{valueType}));
            })
            .collect(Collectors.toList());
    }

    /**
     * Returns the feature count of the property extractors for {@code list} plus one
     * {@code BiasFeature}.
     *
     * <p>The list returned by {@link #propertyExtractors(Graph, Collection)} is not mutated.
     * {@code Collectors.toList()} gives no mutability guarantee, so appending to it is unsafe.
     *
     * @param graph graph providing the node properties
     * @param list  property keys
     * @return total feature dimension including the bias feature
     */
    public static int featureCountWithBias(Graph graph, List<String> list) {
        return featureCount(propertyExtractors(graph, list)) + new BiasFeature().dimension();
    }
}
