package org.deeplearning4j.spark.impl.common.score;

import org.apache.spark.broadcast.Broadcast;
import org.deeplearning4j.nn.layers.variational.VariationalAutoencoder;
import org.nd4j.linalg.api.ndarray.INDArray;

/* loaded from: input_file:org/deeplearning4j/spark/impl/common/score/BaseVaeReconstructionProbWithKeyFunction.class */
/**
 * Abstract {@link BaseVaeScoreWithKeyFunction} whose score is the reconstruction probability
 * (or log probability) reported by a {@link VariationalAutoencoder}.
 *
 * @param <K> key type, passed through to {@link BaseVaeScoreWithKeyFunction}
 */
public abstract class BaseVaeReconstructionProbWithKeyFunction<K> extends BaseVaeScoreWithKeyFunction<K> {
    /** Selects reconstructionLogProbability (true) or reconstructionProbability (false). */
    private final boolean useLogProbability;
    /** Sample count handed to the VAE's reconstruction-probability methods. */
    private final int numSamples;

    /**
     * Creates the scoring function.
     *
     * @param params broadcast array forwarded unchanged to the superclass
     *        (presumably the network parameters; confirm against BaseVaeScoreWithKeyFunction)
     * @param jsonConfig broadcast string forwarded unchanged to the superclass
     *        (presumably the serialized network configuration; confirm)
     * @param useLogProbability if true, {@link #computeScore} returns log probabilities;
     *        otherwise it returns plain probabilities
     * @param batchSize int forwarded unchanged to the superclass (presumably a batch size; confirm)
     * @param numSamples number of samples passed to the VAE's reconstruction-probability call
     */
    public BaseVaeReconstructionProbWithKeyFunction(Broadcast<INDArray> params, Broadcast<String> jsonConfig,
            boolean useLogProbability, int batchSize, int numSamples) {
        super(params, jsonConfig, batchSize);
        this.useLogProbability = useLogProbability;
        this.numSamples = numSamples;
    }

    /**
     * Scores {@code toScore} with the given autoencoder.
     *
     * @param vae the autoencoder layer used for scoring
     * @param toScore input passed to the VAE's reconstruction-probability method
     * @return the array produced by {@code reconstructionLogProbability} when log probabilities
     *         were requested, otherwise by {@code reconstructionProbability}
     */
    @Override
    public INDArray computeScore(VariationalAutoencoder vae, INDArray toScore) {
        if (useLogProbability) {
            return vae.reconstructionLogProbability(toScore, numSamples);
        }
        return vae.reconstructionProbability(toScore, numSamples);
    }
}
