package ai.sklearn4j.naive_bayes;

import ai.sklearn4j.core.ScikitLearnCoreException;
import ai.sklearn4j.core.libraries.numpy.Numpy;
import ai.sklearn4j.core.libraries.numpy.NumpyArray;
import ai.sklearn4j.utils.ExtMath;
import ai.sklearn4j.utils.Preprocessing;

/* loaded from: input_file:ai/sklearn4j/naive_bayes/BernoulliNaiveBayes.class */
/**
 * Bernoulli naive Bayes classifier (inference only).
 *
 * <p>The fitted model parameters are not learned here. They are supplied through the setters,
 * presumably exported from a scikit-learn {@code BernoulliNB} model (verify against the loader).
 * Before scoring, inputs are passed through {@link Preprocessing#binarizeInput} using
 * {@link #getBinarizationThreshold()}.
 */
public class BernoulliNaiveBayes extends BaseNaiveBayes {
    /**
     * Log of the per-class feature probabilities.
     *
     * <p>Axis 1 indexes features: it is compared against the input's feature count. Axis 0 lines
     * up with {@link #classLogPrior}.
     */
    private NumpyArray<Double> featureLogProbabilities = null;

    /** Log prior of each class; added once per class to the joint log likelihood. */
    private NumpyArray<Double> classLogPrior = null;

    /** Per-class feature counts. This class stores them but does not use them for inference. */
    private NumpyArray<Double> featureCounts = null;

    /** Threshold handed to {@link Preprocessing#binarizeInput} before likelihoods are computed. */
    private double binarizationThreshold = 0.0d;

    /**
     * Computes the unnormalized joint log likelihood {@code log P(c) + log P(x | c)} for every
     * input row and every class.
     *
     * <p>For a binarized row {@code x} and class {@code c} with feature probabilities {@code p_c},
     * the Bernoulli log likelihood is
     * {@code sum_j [x_j log p_cj + (1 - x_j) log(1 - p_cj)]}. This is rewritten as
     * {@code x . (log p_c - log(1 - p_c)) + sum_j log(1 - p_cj)} so that all rows and classes
     * are evaluated with one matrix product.
     *
     * @param numpyArray input samples; must be 2-D with one column per feature
     * @return a matrix with one row per input row and one column per class
     * @throws ScikitLearnCoreException if the feature log probabilities or class log prior have not
     *     been set, if the (binarized) input is not 2-D, or if its feature count differs from the
     *     model's
     */
    @Override // ai.sklearn4j.naive_bayes.BaseNaiveBayes
    protected NumpyArray<Double> jointLogLikelihood(NumpyArray<Double> numpyArray) {
        // Fail with a descriptive library exception instead of a bare NullPointerException
        // when the model parameters were never loaded.
        if (this.featureLogProbabilities == null || this.classLogPrior == null) {
            throw new ScikitLearnCoreException(
                    "Model parameters are not set: featureLogProbabilities and classLogPrior are required.");
        }

        NumpyArray<Double> binarized = Preprocessing.binarizeInput(numpyArray, this.binarizationThreshold);

        // Without this check a 1-D input would surface as ArrayIndexOutOfBoundsException below.
        int inputRank = binarized.getShape().length;
        if (inputRank != 2) {
            throw new ScikitLearnCoreException(String.format(
                    "Expected 2D input array, got %dD array instead.", inputRank));
        }

        int expectedFeatures = this.featureLogProbabilities.getShape()[1];
        int actualFeatures = binarized.getShape()[1];
        if (expectedFeatures != actualFeatures) {
            throw new ScikitLearnCoreException(String.format(
                    "Expected input with %d features, got %d instead.", expectedFeatures, actualFeatures));
        }

        // log(1 - p) for every (class, feature) entry, derived from the stored log(p).
        // NOTE(review): the raw (NumpyArray) cast is retained from the original; it likely only
        // steers overload resolution of Numpy.add. Confirm before removing.
        NumpyArray<Double> logOneMinusProb = Numpy.log(
                Numpy.add((NumpyArray) Numpy.multiply(Numpy.exp(this.featureLogProbabilities), -1.0d), 1.0d));

        return Numpy.add(
                // Data-dependent term: x . (log p - log(1 - p))^T.
                ExtMath.dot(binarized, Numpy.subtract(this.featureLogProbabilities, logOneMinusProb).transpose()),
                // Per-class constant term: log P(c) + sum_j log(1 - p_cj).
                Numpy.add(this.classLogPrior, Numpy.sum(logOneMinusProb, 1)));
    }

    /**
     * Returns the log feature probabilities used for scoring.
     *
     * @return the stored array, or {@code null} if it has not been set
     */
    public NumpyArray<Double> getFeatureLogProbabilities() {
        return this.featureLogProbabilities;
    }

    /**
     * Sets the log feature probabilities used for scoring.
     *
     * @param numpyArray log probabilities; axis 1 must match the number of input features
     */
    public void setFeatureLogProbabilities(NumpyArray<Double> numpyArray) {
        this.featureLogProbabilities = numpyArray;
    }

    /**
     * Returns the log class priors.
     *
     * @return the stored array, or {@code null} if it has not been set
     */
    public NumpyArray<Double> getClassLogPrior() {
        return this.classLogPrior;
    }

    /**
     * Sets the log class priors added to each class's likelihood.
     *
     * @param numpyArray log prior per class
     */
    public void setClassLogPrior(NumpyArray<Double> numpyArray) {
        this.classLogPrior = numpyArray;
    }

    /**
     * Returns the stored per-class feature counts.
     *
     * @return the stored array, or {@code null} if it has not been set
     */
    public NumpyArray<Double> getFeatureCounts() {
        return this.featureCounts;
    }

    /**
     * Sets the per-class feature counts. This setter is named consistently with
     * {@link #getFeatureCounts()}.
     *
     * @param numpyArray feature counts; stored only, not used for inference
     */
    public void setFeatureCounts(NumpyArray<Double> numpyArray) {
        this.featureCounts = numpyArray;
    }

    /**
     * Sets the per-class feature counts.
     *
     * <p>This is the original setter name, kept so existing callers and loaders keep working. It is
     * equivalent to {@link #setFeatureCounts(NumpyArray)}.
     *
     * @param numpyArray feature counts; stored only, not used for inference
     */
    public void setFeatureCount(NumpyArray<Double> numpyArray) {
        setFeatureCounts(numpyArray);
    }

    /**
     * Returns the threshold used to binarize inputs before scoring.
     *
     * @return the binarization threshold (defaults to {@code 0.0})
     */
    public double getBinarizationThreshold() {
        return this.binarizationThreshold;
    }

    /**
     * Sets the threshold passed to {@link Preprocessing#binarizeInput} before scoring.
     *
     * @param d the new binarization threshold
     */
    public void setBinarizationThreshold(double d) {
        this.binarizationThreshold = d;
    }
}
