package org.neo4j.gds.ml.models.mlp;

import org.neo4j.gds.ml.core.ComputationContext;
import org.neo4j.gds.ml.core.Variable;
import org.neo4j.gds.ml.core.batch.Batch;
import org.neo4j.gds.ml.core.functions.Constant;
import org.neo4j.gds.ml.core.functions.MatrixMultiplyWithTransposedSecondOperand;
import org.neo4j.gds.ml.core.functions.MatrixVectorSum;
import org.neo4j.gds.ml.core.functions.Relu;
import org.neo4j.gds.ml.core.functions.Softmax;
import org.neo4j.gds.ml.core.tensor.Matrix;
import org.neo4j.gds.ml.gradientdescent.Objective;
import org.neo4j.gds.ml.models.Classifier;
import org.neo4j.gds.ml.models.Features;
import org.neo4j.gds.ml.negativeSampling.NegativeSampler;

/* loaded from: input_file:org/neo4j/gds/ml/models/mlp/MLPClassifier.class */
/**
 * {@link Classifier} implementation that turns feature matrices into class
 * probabilities using the weights and biases held by an {@link MLPClassifierData}.
 *
 * <p>Every prediction builds a fresh computation graph via
 * {@link #predictionsVariable(Constant)} and evaluates it in a new
 * {@link ComputationContext}.
 */
public final class MLPClassifier implements Classifier {
    /** Model parameters (depth, weights, biases) read on every prediction. */
    private final MLPClassifierData data;

    /**
     * Creates a classifier over the given model data.
     *
     * @param mLPClassifierData the model parameters; stored as-is, not copied
     */
    public MLPClassifier(MLPClassifierData mLPClassifierData) {
        this.data = mLPClassifierData;
    }

    /**
     * Predicts class probabilities for a single feature array.
     *
     * @param dArr the feature values of one example
     * @return the raw data array of the evaluated predictions matrix
     */
    @Override
    public double[] predictProbabilities(double[] dArr) {
        // The array is wrapped as a matrix constant with dimensions (1, dArr.length).
        Constant<Matrix> singleExample = Constant.matrix(dArr, 1, dArr.length);
        ComputationContext ctx = new ComputationContext();
        Matrix predictions = ctx.forward(predictionsVariable(singleExample));
        return predictions.data();
    }

    /**
     * Predicts class probabilities for a batch of examples.
     *
     * @param batch    the batch selecting which examples to predict for
     * @param features the feature source the batch refers to
     * @return the evaluated predictions matrix for the batch
     */
    @Override
    public Matrix predictProbabilities(Batch batch, Features features) {
        ComputationContext ctx = new ComputationContext();
        Constant<Matrix> batchFeatures = Objective.batchFeatureMatrix(batch, features);
        return ctx.forward(predictionsVariable(batchFeatures));
    }

    /**
     * Returns the model data this classifier was constructed with.
     *
     * @return the underlying {@link MLPClassifierData}
     */
    @Override
    public MLPClassifierData data() {
        return this.data;
    }

    /**
     * Builds the (not yet evaluated) computation graph that maps input features
     * to predictions.
     *
     * <p>For each index {@code i} in {@code [0, depth - 1)} the running value is
     * multiplied with {@code weights.get(i)} (second operand transposed), summed
     * with {@code biases.get(i)}, and passed through {@link Relu}. The final
     * value is wrapped in {@link Softmax}.
     *
     * <p>NOTE(review): a decompiler note on the original marked this method as
     * originally package-private; it is kept {@code public} so existing callers
     * keep compiling.
     *
     * @param constant the input feature matrix
     * @return the variable representing the softmax output of the network
     */
    public Variable<Matrix> predictionsVariable(Constant<Matrix> constant) {
        // Must be typed as Variable, not Constant: after the first iteration it
        // holds a Relu node, which is a computed variable rather than a constant.
        Variable<Matrix> activations = constant;
        int layerCount = this.data.depth() - 1;
        for (int i = 0; i < layerCount; i++) {
            Variable<Matrix> weighted =
                MatrixMultiplyWithTransposedSecondOperand.of(activations, this.data.weights().get(i));
            Variable<Matrix> biased = new MatrixVectorSum(weighted, this.data.biases().get(i));
            // NOTE(review): NegativeSampler.NEGATIVE is used as Relu's second argument;
            // this looks like a decompiler constant substitution for a numeric
            // literal (presumably 0.0) - confirm against the original source.
            activations = new Relu<>(biased, NegativeSampler.NEGATIVE);
        }
        return new Softmax(activations);
    }
}
