package org.neo4j.gds.embeddings.graphsage;

import com.carrotsearch.hppc.LongHashSet;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.OptionalLong;
import java.util.Random;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.LongStream;
import java.util.stream.Stream;
import org.immutables.value.Value;
import org.neo4j.gds.annotation.ValueClass;
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.api.ImmutableRelationshipCursor;
import org.neo4j.gds.core.concurrency.ParallelUtil;
import org.neo4j.gds.core.model.Model;
import org.neo4j.gds.core.utils.paged.HugeObjectArray;
import org.neo4j.gds.core.utils.partition.Partition;
import org.neo4j.gds.core.utils.partition.PartitionUtils;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.embeddings.graphsage.algo.GraphSageTrainConfig;
import org.neo4j.gds.ml.core.ComputationContext;
import org.neo4j.gds.ml.core.RelationshipWeights;
import org.neo4j.gds.ml.core.Variable;
import org.neo4j.gds.ml.core.batch.WeightedUniformSampler;
import org.neo4j.gds.ml.core.features.FeatureExtraction;
import org.neo4j.gds.ml.core.functions.PassthroughVariable;
import org.neo4j.gds.ml.core.functions.Weights;
import org.neo4j.gds.ml.core.optimizer.AdamOptimizer;
import org.neo4j.gds.ml.core.tensor.Matrix;
import org.neo4j.gds.ml.core.tensor.Scalar;
import org.neo4j.gds.ml.core.tensor.Tensor;
import org.neo4j.gds.ml.core.tensor.TensorFunctions;

/* loaded from: input_file:org/neo4j/gds/embeddings/graphsage/GraphSageModelTrainer.class */
public class GraphSageModelTrainer {
    // Base seed for per-batch sampling: taken from the train config, or drawn from
    // ThreadLocalRandom in the constructor when the config carries none.
    private final long randomSeed;
    // Layers under training; (re)created from layerConfigsFunction at the start of train().
    private Layer[] layers;
    // True when the train config declares a relationship weight property.
    private final boolean useWeights;
    // Learning rate handed to the AdamOptimizer created for each epoch.
    private final double learningRate;
    // Threshold used both for per-batch convergence (absolute loss change) and
    // for per-epoch convergence (relative loss change).
    private final double tolerance;
    // Passed through unchanged to GraphSageLoss.
    private final int negativeSampleWeight;
    // Concurrency used when running the batch tasks of one iteration.
    private final int concurrency;
    // Upper bound on the number of epochs run by train().
    private final int epochs;
    // Upper bound on the number of iterations per epoch.
    private final int maxIterations;
    // Upper bound (inclusive) on the randomly drawn walk depth in neighborBatch().
    private final int maxSearchDepth;
    // Derives the layer configurations for a given graph from the train config.
    private final Function<Graph, List<LayerConfig>> layerConfigsFunction;
    // Forwarded to GraphSageHelper.embeddings when the loss is built.
    private final FeatureFunction featureFunction;
    // Extra trainable weights that are optimised together with the layer weights.
    private final Collection<Weights<? extends Tensor<?>>> labelProjectionWeights;
    // Executor on which the batch tasks are run.
    private final ExecutorService executor;
    // Receives begin/end notifications for training, each epoch and each iteration.
    private final ProgressTracker progressTracker;
    // Number of nodes per batch partition.
    private final int batchSize;

    /* JADX INFO: Access modifiers changed from: package-private */
    /* loaded from: input_file:org/neo4j/gds/embeddings/graphsage/GraphSageModelTrainer$BatchTask.class */
    /**
     * One forward/backward pass of the loss for a single batch.
     *
     * <p>Each {@link #run()} evaluates the loss, records whether it moved by less than
     * {@code tolerance} compared to the previous pass, and stores the gradient of every
     * weight variable. Once a task has flagged itself as converged, further calls to
     * {@link #run()} return immediately, so {@link #loss()} and {@link #weightGradients()}
     * keep reporting the values of the last executed pass.
     */
    public static class BatchTask implements Runnable {
        private final Variable<Scalar> lossFunction;
        private final List<Weights<? extends Tensor<?>>> weightVariables;
        private List<? extends Tensor<?>> weightGradients;
        private final double tolerance;
        private boolean converged;
        // Starts at 0.0, so the very first pass is compared against a loss of zero.
        private double prevLoss;

        /**
         * @param lossFunction    loss variable evaluated on every pass
         * @param weightVariables weights whose gradients are collected after the backward pass
         * @param tolerance       absolute loss change below which the task counts as converged
         */
        BatchTask(Variable<Scalar> lossFunction, List<Weights<? extends Tensor<?>>> weightVariables, double tolerance) {
            this.lossFunction = lossFunction;
            this.weightVariables = weightVariables;
            this.tolerance = tolerance;
        }

        /**
         * Runs one forward and backward pass unless the task has already converged.
         */
        @Override
        public void run() {
            if (converged) {
                return;
            }
            // A fresh context per pass keeps intermediate results local to this task.
            ComputationContext localContext = new ComputationContext();
            double loss = localContext.forward(lossFunction).value();
            converged = Math.abs(prevLoss - loss) < tolerance;
            prevLoss = loss;
            localContext.backward(lossFunction);
            // Gradients are collected in the same order as weightVariables.
            weightGradients = weightVariables.stream()
                .map(localContext::gradient)
                .collect(Collectors.toList());
        }

        /** @return whether the last executed pass changed the loss by less than the tolerance */
        public boolean converged() {
            return converged;
        }

        /** @return the loss computed by the last executed pass (0.0 before the first pass) */
        public double loss() {
            return prevLoss;
        }

        /** @return gradients from the last executed pass, or {@code null} before the first pass */
        List<? extends Tensor<?>> weightGradients() {
            return weightGradients;
        }
    }

    /**
     * Training metrics: the loss recorded after each epoch and whether training converged.
     */
    @ValueClass
    public interface GraphSageTrainMetrics extends Model.Mappable {
        /** @return metrics with no recorded epochs and {@code didConverge == false} */
        static GraphSageTrainMetrics empty() {
            List<Double> noLosses = List.of();
            return ImmutableGraphSageTrainMetrics.of(noLosses, false);
        }

        /** @return the loss of every epoch that was run, in order */
        List<Double> epochLosses();

        /** @return whether training stopped because the convergence criterion was met */
        boolean didConverge();

        /** @return the number of recorded epoch losses (0 when none were recorded) */
        @Value.Derived
        default int ranEpochs() {
            // An empty list already has size 0, so no special case is needed.
            return epochLosses().size();
        }

        /** @return all metrics nested under the single key {@code "metrics"} */
        @Value.Auxiliary
        @Value.Derived
        default Map<String, Object> toMap() {
            Map<String, Object> details = Map.of(
                "epochLosses", epochLosses(),
                "didConverge", didConverge(),
                "ranEpochs", ranEpochs()
            );
            return Map.of("metrics", details);
        }
    }

    /**
     * Outcome of a training run: the collected metrics together with the trained layers.
     */
    @ValueClass
    public interface ModelTrainResult {
        /** @return the metrics gathered while training */
        GraphSageTrainMetrics metrics();

        /** @return the layers as they are after training */
        Layer[] layers();

        /**
         * Builds a result from raw training output.
         *
         * @param list     loss of every epoch that was run
         * @param z        whether training converged
         * @param layerArr the trained layers
         */
        static ModelTrainResult of(List<Double> list, boolean z, Layer[] layerArr) {
            GraphSageTrainMetrics trainMetrics = ImmutableGraphSageTrainMetrics.of(list, z);
            return ImmutableModelTrainResult.builder()
                .layers(layerArr)
                .metrics(trainMetrics)
                .build();
        }
    }

    /**
     * Creates a trainer that uses a {@link SingleLabelFeatureFunction} and no label
     * projection weights.
     *
     * @param graphSageTrainConfig training configuration
     * @param executorService      executor on which batch tasks are run
     * @param progressTracker      receives progress notifications during training
     */
    public GraphSageModelTrainer(
        GraphSageTrainConfig graphSageTrainConfig,
        ExecutorService executorService,
        ProgressTracker progressTracker
    ) {
        this(
            graphSageTrainConfig,
            executorService,
            progressTracker,
            new SingleLabelFeatureFunction(),
            Collections.emptyList()
        );
    }

    public GraphSageModelTrainer(GraphSageTrainConfig graphSageTrainConfig, ExecutorService executorService, ProgressTracker progressTracker, FeatureFunction featureFunction, Collection<Weights<? extends Tensor<?>>> collection) {
        this.layerConfigsFunction = graph -> {
            return graphSageTrainConfig.layerConfigs(firstLayerColumns(graphSageTrainConfig, graph));
        };
        this.batchSize = graphSageTrainConfig.batchSize();
        this.learningRate = graphSageTrainConfig.learningRate();
        this.tolerance = graphSageTrainConfig.tolerance();
        this.negativeSampleWeight = graphSageTrainConfig.negativeSampleWeight();
        this.concurrency = graphSageTrainConfig.concurrency();
        this.epochs = graphSageTrainConfig.epochs();
        this.maxIterations = graphSageTrainConfig.maxIterations();
        this.maxSearchDepth = graphSageTrainConfig.searchDepth();
        this.featureFunction = featureFunction;
        this.labelProjectionWeights = collection;
        this.executor = executorService;
        this.progressTracker = progressTracker;
        this.useWeights = graphSageTrainConfig.hasRelationshipWeightProperty();
        this.randomSeed = ((Long) graphSageTrainConfig.randomSeed().orElse(Long.valueOf(ThreadLocalRandom.current().nextLong()))).longValue();
    }

    /**
     * Trains the layers on the given graph.
     *
     * <p>The layers are created from the configured layer configs, the node range is split
     * into batches of {@code batchSize}, and up to {@code epochs} epochs are run. Training
     * stops early as soon as the relative change between two consecutive epoch losses drops
     * below {@code tolerance}.
     *
     * @param graph           graph to train on
     * @param hugeObjectArray per-node {@code double[]} vectors handed to the embedding
     *                        computation (presumably the input features — confirm with caller)
     * @return the epoch losses, the convergence flag and the trained layers
     */
    public ModelTrainResult train(Graph graph, HugeObjectArray<double[]> hugeObjectArray) {
        progressTracker.beginSubTask();

        layers = layerConfigsFunction.apply(graph)
            .stream()
            .map(LayerFactory::createLayer)
            .toArray(Layer[]::new);

        List<Weights<? extends Tensor<?>>> weights = getWeights();
        List<BatchTask> batchTasks = PartitionUtils.rangePartitionWithBatchSize(
            graph.nodeCount(),
            batchSize,
            partition -> new BatchTask(lossFunction(partition, graph, hugeObjectArray), weights, tolerance)
        );

        List<Double> epochLosses = new ArrayList<>();
        double previousLoss = Double.MAX_VALUE;
        boolean converged = false;

        for (int epoch = 1; epoch <= epochs && !converged; epoch++) {
            progressTracker.beginSubTask();
            double epochLoss = trainEpoch(batchTasks, epoch);
            epochLosses.add(epochLoss);
            progressTracker.endSubTask();

            // Relative change against the previous epoch decides whether to stop.
            converged = Math.abs((epochLoss - previousLoss) / previousLoss) < tolerance;
            previousLoss = epochLoss;
        }

        progressTracker.endSubTask();
        return ModelTrainResult.of(epochLosses, converged, layers);
    }

    /**
     * Runs one epoch: up to {@code maxIterations} rounds of "run every batch task, average
     * the gradients, apply one optimizer step". A new optimizer is created for each epoch.
     * The epoch ends early, without an optimizer step, once every batch task reports
     * convergence.
     *
     * @param batchTasks the batch tasks to run in every iteration
     * @param epoch      1-based epoch number, used only in the debug log line
     * @return the mean batch loss of the last iteration, or {@code NaN} if no iteration ran
     */
    private double trainEpoch(List<BatchTask> batchTasks, int epoch) {
        AdamOptimizer optimizer = new AdamOptimizer(getWeights(), learningRate);
        double meanLoss = Double.NaN;

        for (int iteration = 1; iteration <= maxIterations; iteration++) {
            progressTracker.beginSubTask();

            ParallelUtil.runWithConcurrency(concurrency, batchTasks, executor);
            meanLoss = batchTasks.stream()
                .mapToDouble(BatchTask::loss)
                .average()
                .orElseThrow();

            boolean everyBatchConverged = batchTasks.stream().allMatch(BatchTask::converged);
            if (everyBatchConverged) {
                progressTracker.endSubTask();
                break;
            }

            var gradientsPerBatch = batchTasks.stream()
                .map(BatchTask::weightGradients)
                .collect(Collectors.toList());
            optimizer.update(TensorFunctions.averageTensors(gradientsPerBatch));

            progressTracker.progressLogger().getLog().debug(
                "Epoch %d LOSS: %.10f at iteration %d",
                epoch,
                meanLoss,
                iteration
            );
            progressTracker.endSubTask();
        }

        return meanLoss;
    }

    /**
     * Builds the loss variable for one batch.
     *
     * <p>The node ids fed into the embedding computation are laid out as three consecutive
     * groups: the batch nodes, one sampled neighbor per batch node, and the negative samples
     * (as many as requested from {@link #negativeBatch}, i.e. the batch node count). Negative
     * samples are drawn so that they exclude the sampled neighbors.
     *
     * @param partition       the node range forming this batch
     * @param graph           graph to sample from
     * @param hugeObjectArray per-node {@code double[]} vectors handed to
     *                        {@code GraphSageHelper.embeddings} (presumably the input
     *                        features — confirm with caller)
     * @return the loss for this batch, wrapped in a {@link PassthroughVariable}
     */
    private Variable<Scalar> lossFunction(Partition partition, Graph graph, HugeObjectArray<double[]> hugeObjectArray) {
        // The batch index must be derived from the batch size. Dividing the start node by the
        // total node count always yields 0, which would give every batch the same seed and
        // therefore the same depth draws and the same negative-sampling sequence.
        long batchLocalSeed = getBatchIndex(partition, this.batchSize) + this.randomSeed;

        long[] neighbors = neighborBatch(graph, partition, batchLocalSeed).toArray();
        LongHashSet neighborSet = new LongHashSet(neighbors.length);
        neighborSet.addAll(neighbors);

        int negativeCount = Math.toIntExact(partition.nodeCount());
        LongStream negatives = negativeBatch(graph, negativeCount, neighborSet, batchLocalSeed);

        // Order matters: batch nodes, then their neighbors, then the negative samples.
        long[] extendedBatch = LongStream.concat(
            partition.stream(),
            LongStream.concat(Arrays.stream(neighbors), negatives)
        ).toArray();

        Variable<Matrix> embeddings = GraphSageHelper.embeddings(
            graph,
            this.useWeights,
            extendedBatch,
            hugeObjectArray,
            this.layers,
            this.featureFunction
        );

        RelationshipWeights relationshipWeights = this.useWeights
            ? graph::relationshipProperty
            : RelationshipWeights.UNWEIGHTED;

        return new PassthroughVariable<>(
            new GraphSageLoss(relationshipWeights, embeddings, extendedBatch, this.negativeSampleWeight)
        );
    }

    /**
     * Samples one "positive" node for every node of the batch by taking a short random walk
     * and keeping only the node the walk ends on.
     *
     * <p>The walk length is drawn uniformly from {@code [1, maxSearchDepth]}. Each step
     * samples one neighbor of the <em>current</em> node; the walk stops early when the
     * current node yields no sample. If not even the first step succeeds, the batch node
     * itself is emitted.
     *
     * @param graph     graph to walk on
     * @param partition the node range forming the batch
     * @param j         seed for the random walk lengths of this batch
     * @return one node id per batch node, in batch order
     */
    LongStream neighborBatch(Graph graph, Partition partition, long j) {
        LongStream.Builder sampledNeighbors = LongStream.builder();
        Random depthRandom = new Random(j);

        partition.consume(nodeId -> {
            int remainingSteps = depthRandom.nextInt(this.maxSearchDepth) + 1;
            long currentNode = nodeId;
            while (remainingSteps > 0) {
                // Seed each step from the current position and the remaining depth.
                NeighborhoodSampler stepSampler = new NeighborhoodSampler(currentNode + remainingSteps);
                // Continue from the current node, not the origin, so the walk can actually
                // move further than one hop away.
                OptionalLong nextNode = stepSampler.sampleOne(graph, currentNode);
                if (!nextNode.isPresent()) {
                    break;
                }
                currentNode = nextNode.getAsLong();
                remainingSteps--;
            }
            sampledNeighbors.add(currentNode);
        });

        return sampledNeighbors.build();
    }

    /**
     * Draws negative samples from all nodes of the graph, weighting each node by
     * {@code degree^0.75} and skipping every node contained in {@code longHashSet}.
     *
     * @param graph       graph whose nodes are the sampling candidates
     * @param i           number of samples requested from the sampler
     * @param longHashSet node ids that must not be returned
     * @param j           seed for the sampler
     * @return the sampled node ids
     */
    LongStream negativeBatch(Graph graph, int i, LongHashSet longHashSet, long j) {
        final double degreeExponent = 0.75d;
        long totalNodes = graph.nodeCount();
        WeightedUniformSampler sampler = new WeightedUniformSampler(j);

        return sampler.sample(
            // Each candidate is encoded as a cursor whose target is the node id and whose
            // property value is the sampling weight.
            LongStream.range(0L, totalNodes).mapToObj(
                nodeId -> ImmutableRelationshipCursor.of(0L, nodeId, Math.pow(graph.degree(nodeId), degreeExponent))
            ),
            totalNodes,
            i,
            candidate -> !longHashSet.contains(candidate)
        );
    }

    /**
     * Collects every trainable weight: the label projection weights first, followed by the
     * weights of each layer in layer order.
     *
     * @return a new mutable list of all weights
     */
    private List<Weights<? extends Tensor<?>>> getWeights() {
        List<Weights<? extends Tensor<?>>> allWeights = new ArrayList<>(this.labelProjectionWeights);
        for (Layer layer : this.layers) {
            allWeights.addAll(layer.weights());
        }
        return allWeights;
    }

    /**
     * Floor-divides the partition's start node by {@code j} and narrows the result to an int.
     *
     * <p>NOTE(review): the quotient is only a meaningful batch index when {@code j} is the
     * batch size; with a divisor larger than every start node (e.g. the total node count)
     * it is always 0.
     *
     * @param partition the batch partition
     * @param j         divisor applied to the partition's start node
     * @return the quotient as an int
     * @throws ArithmeticException if the quotient does not fit into an int
     */
    private static int getBatchIndex(Partition partition, long j) {
        long quotient = Math.floorDiv(partition.startNode(), j);
        return Math.toIntExact(quotient);
    }

    /**
     * Determines the input width of the first layer: the configured projected feature
     * dimension when present, otherwise the feature count of the extractors built for the
     * graph and config.
     *
     * @param graphSageTrainConfig training configuration
     * @param graph                graph the feature extractors are built for
     * @return the number of input columns of the first layer
     */
    private int firstLayerColumns(GraphSageTrainConfig graphSageTrainConfig, Graph graph) {
        // The extractor-based count is only computed when no projected dimension is configured.
        return graphSageTrainConfig
            .projectedFeatureDimension()
            .orElseGet(() -> FeatureExtraction.featureCount(
                GraphSageHelper.featureExtractors(graph, graphSageTrainConfig)
            ));
    }
}
