package org.neo4j.gds.embeddings.graphsage;

import java.util.List;
import java.util.function.Function;
import org.neo4j.gds.embeddings.graphsage.Aggregator;
import org.neo4j.gds.ml.core.Variable;
import org.neo4j.gds.ml.core.functions.MatrixMultiplyWithTransposedSecondOperand;
import org.neo4j.gds.ml.core.functions.MultiMean;
import org.neo4j.gds.ml.core.functions.WeightedMultiMean;
import org.neo4j.gds.ml.core.functions.Weights;
import org.neo4j.gds.ml.core.subgraph.SubGraph;
import org.neo4j.gds.ml.core.tensor.Matrix;
import org.neo4j.gds.ml.core.tensor.Tensor;

/* loaded from: input_file:org/neo4j/gds/embeddings/graphsage/MeanAggregator.class */
/**
 * GraphSAGE {@link Aggregator} of type {@link Aggregator.AggregatorType#MEAN}.
 *
 * <p>For each call to {@link #aggregate(Variable, SubGraph)} it:
 * <ol>
 *   <li>takes a mean of the input representations over the sub-graph's adjacency. The mean is
 *       weighted ({@link WeightedMultiMean}) when the sub-graph carries a relationship-weights
 *       function, and unweighted ({@link MultiMean}) otherwise;</li>
 *   <li>multiplies that mean with the trainable {@link #weights} via
 *       {@link MatrixMultiplyWithTransposedSecondOperand} (second operand transposed);</li>
 *   <li>applies the configured activation function to the product.</li>
 * </ol>
 */
public class MeanAggregator implements Aggregator {
    /** Trainable projection matrix used as the (transposed) second multiplication operand. */
    private final Weights<Matrix> weights;
    /** Function form of {@link #activation}, resolved once in the constructor. */
    private final Function<Variable<Matrix>, Variable<Matrix>> activationFunction;
    /** The activation as configured; reported back by {@link #activationFunction()}. */
    private final ActivationFunction activation;

    /**
     * Creates a mean aggregator.
     *
     * @param weights            trainable projection matrix
     * @param activationFunction activation applied to the projected mean; its function form is
     *                           obtained eagerly via {@code activationFunction.activationFunction()}
     */
    public MeanAggregator(Weights<Matrix> weights, ActivationFunction activationFunction) {
        this.weights = weights;
        this.activation = activationFunction;
        this.activationFunction = activationFunction.activationFunction();
    }

    /**
     * Aggregates {@code variable} over {@code subGraph}: (weighted) mean, then projection with
     * {@link #weights}, then activation.
     *
     * @param variable input representations to aggregate
     * @param subGraph sub-graph providing the adjacency and, optionally, relationship weights
     * @return the activated, projected mean
     */
    @Override
    public Variable<Matrix> aggregate(Variable<Matrix> variable, SubGraph subGraph) {
        // Use the weighted mean when relationship weights are present, else the plain mean.
        // The cast widens WeightedMultiMean to Variable<Matrix> so that both Optional branches
        // share a single parameterized type. A raw-type cast would disable generic checking.
        Variable<Matrix> means = subGraph.maybeRelationshipWeightsFunction
            .map(relationshipWeights ->
                (Variable<Matrix>) new WeightedMultiMean(variable, relationshipWeights, subGraph))
            .orElseGet(() -> new MultiMean(variable, subGraph.adjacency, subGraph.selfAdjacency));

        return this.activationFunction.apply(
            MatrixMultiplyWithTransposedSecondOperand.of(means, this.weights)
        );
    }

    /**
     * Returns the trainable parameters of this aggregator.
     *
     * @return an immutable single-element list containing {@link #weights}
     */
    @Override
    public List<Weights<? extends Tensor<?>>> weights() {
        return List.of(this.weights);
    }

    /**
     * @return {@link Aggregator.AggregatorType#MEAN}
     */
    @Override
    public Aggregator.AggregatorType type() {
        return Aggregator.AggregatorType.MEAN;
    }

    /**
     * @return the activation function this aggregator was constructed with
     */
    @Override
    public ActivationFunction activationFunction() {
        return this.activation;
    }

    /**
     * @return the current data held by the projection weights ({@code weights.data()})
     */
    public Matrix weightsData() {
        return this.weights.data();
    }
}
