package org.neo4j.gds.embeddings.graphsage;

import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.neo4j.gds.ml.core.ComputationContext;
import org.neo4j.gds.ml.core.Variable;
import org.neo4j.gds.ml.core.functions.Weights;
import org.neo4j.gds.ml.core.tensor.Matrix;
import org.neo4j.gds.ml.core.tensor.Scalar;
import org.neo4j.gds.ml.core.tensor.Tensor;
import org.neo4j.gds.ml.core.tensor.Vector;
import org.neo4j.graphalgo.core.utils.mem.MemoryUsage;

/* loaded from: input_file:org/neo4j/gds/embeddings/graphsage/AdamOptimizer.class */
/**
 * Adam optimizer operating on a fixed list of {@link Weights}.
 *
 * <p>Each call to {@link #update(ComputationContext)} clips the gradients held by the
 * supplied context element-wise to {@code [CLIP_MIN, CLIP_MAX]}, refreshes the running
 * first-moment ("momentum") and second-moment ("velocity") estimates, applies bias
 * correction and finally adds the resulting step to the weight data in place.
 *
 * <p>{@code update} is {@code synchronized}; the constructors and the static estimator
 * perform no synchronization of their own.
 */
public class AdamOptimizer {
    private static final double CLIP_MAX = 5.0d;
    private static final double CLIP_MIN = -5.0d;
    private static final double DEFAULT_ALPHA = 0.001d;

    // The hyper-parameters are deliberately kept as instance fields: sizeInBytes() measures
    // the instance via MemoryUsage.sizeOfInstance(AdamOptimizer.class).
    /** Learning rate (step size). */
    private final double alpha;
    /** Decay rate of the first-moment (momentum) estimate. */
    private final double beta_1 = 0.9d;
    /** Decay rate of the second-moment (velocity) estimate. */
    private final double beta_2 = 0.999d;
    /** Small constant added to the denominator to avoid division by zero. */
    private final double epsilon = 1.0E-8d;

    /** The weights being optimized; their data is modified in place by {@link #update}. */
    private final List<Weights<? extends Tensor<?>>> variables;
    /** Running first-moment estimates, index-aligned with {@link #variables}. */
    private List<? extends Tensor<?>> momentumTerms;
    /** Running second-moment estimates, index-aligned with {@link #variables}. */
    private List<? extends Tensor<?>> velocityTerms;
    /** Number of completed or in-progress {@link #update} calls; drives the bias correction. */
    private int iteration;

    /**
     * Estimates the memory used by an optimizer instance.
     *
     * <p>The per-term size is {@code Weights.sizeInBytes(firstDimension, secondDimension)}
     * multiplied by {@code numberOfWeights}. The result is the instance size plus two such
     * terms (the momentum and velocity lists) plus four more.
     * NOTE(review): the factor 4 presumably accounts for the temporary lists created inside
     * {@link #update} (new momentum, new velocity and the two bias-corrected copies) — confirm.
     *
     * @param firstDimension  first argument forwarded to {@code Weights.sizeInBytes}
     *                        (presumably the row count — confirm)
     * @param secondDimension second argument forwarded to {@code Weights.sizeInBytes}
     *                        (presumably the column count — confirm)
     * @param numberOfWeights multiplier applied to the size of a single weights object
     * @return the estimated size in bytes
     */
    public static long sizeInBytes(int firstDimension, int secondDimension, int numberOfWeights) {
        long termSize = Weights.sizeInBytes(firstDimension, secondDimension) * numberOfWeights;
        return MemoryUsage.sizeOfInstance(AdamOptimizer.class)
            + (2 * termSize)
            + (4 * termSize);
    }

    /**
     * Creates an optimizer with the default learning rate ({@code 0.001}).
     *
     * @param list the weights to optimize
     */
    public AdamOptimizer(List<Weights<? extends Tensor<?>>> list) {
        this(list, DEFAULT_ALPHA);
    }

    /**
     * Creates an optimizer with an explicit learning rate.
     *
     * <p>The list is stored by reference (no defensive copy). Momentum and velocity
     * estimates start out as zero tensors obtained from each weight's data.
     *
     * @param list the weights to optimize
     * @param d    the learning rate (alpha)
     */
    public AdamOptimizer(List<Weights<? extends Tensor<?>>> list, double d) {
        this.alpha = d;
        this.variables = list;
        this.momentumTerms = list.stream()
            .<Tensor<?>>map(weights -> weights.data().zeros())
            .collect(Collectors.toList());
        // Both lists initially reference the same zero tensors. update() replaces each list
        // with freshly computed tensors rather than mutating the elements.
        this.velocityTerms = List.copyOf(this.momentumTerms);
    }

    /**
     * Performs one Adam step using the gradients stored in {@code computationContext}.
     *
     * <p>Side effects: the gradients inside the context are clipped in place, the internal
     * moment estimates are replaced, and the data of every weight is updated in place.
     *
     * @param computationContext context that provides the gradient of each weight
     */
    public synchronized void update(ComputationContext computationContext) {
        this.iteration++;

        // Clip every gradient component into [CLIP_MIN, CLIP_MAX]; this mutates the context.
        for (Weights<? extends Tensor<?>> weights : this.variables) {
            computationContext.gradient(weights).mapInPlace(this::clip);
        }

        // m_t = beta_1 * m_{t-1} + (1 - beta_1) * g_t
        this.momentumTerms = IntStream.range(0, this.variables.size())
            .<Tensor<?>>mapToObj(i -> {
                Tensor<?> decayedMomentum = this.momentumTerms.get(i).scalarMultiply(beta_1);
                Tensor<?> gradient = computationContext.gradient(this.variables.get(i));
                return castAndAdd(decayedMomentum, gradient.scalarMultiply(1.0d - beta_1));
            })
            .collect(Collectors.toList());

        // v_t = beta_2 * v_{t-1} + (1 - beta_2) * (g_t * g_t)
        this.velocityTerms = IntStream.range(0, this.variables.size())
            .<Tensor<?>>mapToObj(i -> {
                Tensor<?> previousVelocity = this.velocityTerms.get(i);
                Tensor<?> gradient = computationContext.gradient(this.variables.get(i));
                Tensor<?> squaredGradient = gradient.elementwiseProduct(gradient);
                return castAndAdd(
                    previousVelocity.scalarMultiply(beta_2),
                    squaredGradient.scalarMultiply(1.0d - beta_2)
                );
            })
            .collect(Collectors.toList());

        // The bias-correction factors depend only on the iteration count, so compute them once.
        double momentumCorrection = 1.0d / (1.0d - Math.pow(beta_1, this.iteration));
        double velocityCorrection = 1.0d / (1.0d - Math.pow(beta_2, this.iteration));

        // mCap_t = m_t / (1 - beta_1^t)
        List<Tensor<?>> correctedMomentum = this.momentumTerms.stream()
            .<Tensor<?>>map(term -> term.scalarMultiply(momentumCorrection))
            .collect(Collectors.toList());

        // vCap_t = v_t / (1 - beta_2^t)
        List<Tensor<?>> correctedVelocity = this.velocityTerms.stream()
            .<Tensor<?>>map(term -> term.scalarMultiply(velocityCorrection))
            .collect(Collectors.toList());

        // theta_t = theta_{t-1} - alpha * mCap_t / (sqrt(vCap_t) + epsilon)
        for (int i = 0; i < this.variables.size(); i++) {
            Tensor<?> inverseRootVelocity = correctedVelocity.get(i)
                .map(value -> 1.0d / (Math.sqrt(value) + epsilon));
            Tensor<?> step = correctedMomentum.get(i)
                .scalarMultiply(-this.alpha)
                .elementwiseProduct(inverseRootVelocity);
            this.variables.get(i).data().addInPlace(step);
        }
    }

    /**
     * Adds two tensors of the same concrete kind ({@link Scalar}, {@link Vector} or
     * {@link Matrix}) by dispatching to the matching typed {@code add} method.
     *
     * @param tensor  left operand
     * @param tensor2 right operand
     * @return the result of {@code tensor.add(tensor2)}
     * @throws IllegalStateException if the operands are not both scalars, both vectors or
     *                               both matrices
     */
    private Tensor<?> castAndAdd(Tensor<?> tensor, Tensor<?> tensor2) {
        if ((tensor instanceof Scalar) && (tensor2 instanceof Scalar)) {
            return ((Scalar) tensor).add((Scalar) tensor2);
        }
        if ((tensor instanceof Vector) && (tensor2 instanceof Vector)) {
            return ((Vector) tensor).add((Vector) tensor2);
        }
        if ((tensor instanceof Matrix) && (tensor2 instanceof Matrix)) {
            return ((Matrix) tensor).add((Matrix) tensor2);
        }
        throw new IllegalStateException("Mismatched tensors! Can only add same types");
    }

    /**
     * Clamps a single gradient component into {@code [CLIP_MIN, CLIP_MAX]}.
     * A NaN input fails both comparisons and is returned unchanged.
     *
     * @param d the value to clamp
     * @return {@code CLIP_MAX} if {@code d} is greater than it, {@code CLIP_MIN} if {@code d}
     *         is smaller than it, otherwise {@code d}
     */
    private double clip(double d) {
        if (d > CLIP_MAX) {
            return CLIP_MAX;
        }
        if (d < CLIP_MIN) {
            return CLIP_MIN;
        }
        return d;
    }
}
