package ai.sklearn4j.naive_bayes;

import ai.sklearn4j.core.libraries.numpy.NumpyArray;
import ai.sklearn4j.core.libraries.numpy.NumpyArrayFactory;
import ai.sklearn4j.core.libraries.numpy.wrappers.Dim2DoubleNumpyWrapper;

/* loaded from: input_file:ai/sklearn4j/naive_bayes/GaussianNaiveBayes.class */
/**
 * Gaussian naive Bayes scorer.
 *
 * <p>The model parameters are plain mutable properties populated through the setters below;
 * {@link #jointLogLikelihood(NumpyArray)} reads {@code classCounts}, {@code classPriors},
 * {@code sigma} and {@code theta} without null checks, so those four must be set before scoring.
 */
public class GaussianNaiveBayes extends BaseNaiveBayes {
    /** The 2&pi; factor inside the Gaussian normalisation term {@code log(2π · variance)}. */
    private static final double TWO_PI = 2.0d * Math.PI;

    /** Per-class counts; within this class only its first dimension is used, as the number of classes. */
    private NumpyArray<Double> classCounts;
    /** Per-class prior probabilities; their logarithm is added to each class score. */
    private NumpyArray<Double> classPriors;
    /** Class labels; stored only, not read by the scoring code in this class. */
    private NumpyArray<Long> classes;
    /** Stored only; presumably the user-supplied priors of the original estimator (TODO confirm). */
    private NumpyArray<Double> priors;
    /** Stored only; presumably the feature names seen when the model was trained (TODO confirm). */
    private String[] featureNamesIn;
    /** Stored only; the scoring code derives the feature count from {@code sigma} instead. */
    private int numberOfFeatures;
    /** Two-dimensional array indexed {@code [class][feature]}, used as the variance of each Gaussian. */
    private NumpyArray<Double> sigma;
    /** Two-dimensional array indexed {@code [class][feature]}, used as the mean of each Gaussian. */
    private NumpyArray<Double> theta;

    /**
     * Computes the unnormalised log posterior of every class for every input row.
     *
     * <p>For row {@code s} and class {@code c} the result is
     * {@code -0.5 * (Σ_f log(2π·sigma[c][f]) + Σ_f (x[s][f] - theta[c][f])² / sigma[c][f]) + log(classPriors[c])}.
     *
     * @param samples two-dimensional input, read as {@code samples.get(row, feature)}; must provide
     *                at least as many features as {@code sigma} has columns
     * @return a new array with one row per input row and one column per class
     */
    @Override
    protected NumpyArray<Double> jointLogLikelihood(NumpyArray<Double> samples) {
        int sampleCount = samples.getShape()[0];
        int classCount = this.classCounts.getShape()[0];
        int featureCount = this.sigma.getShape()[1];

        double[][] variances = ((Dim2DoubleNumpyWrapper) this.sigma.getWrapper()).getArray();
        double[][] means = ((Dim2DoubleNumpyWrapper) this.theta.getWrapper()).getArray();

        // Terms that depend only on the class are computed once rather than once per sample.
        double[] logNormalisation = new double[classCount];
        double[] logPrior = new double[classCount];
        for (int c = 0; c < classCount; c++) {
            double sum = 0.0d;
            for (int f = 0; f < featureCount; f++) {
                sum += Math.log(TWO_PI * variances[c][f]);
            }
            logNormalisation[c] = sum;
            logPrior[c] = Math.log(this.classPriors.get(c).doubleValue());
        }

        double[][] scores = new double[sampleCount][classCount];
        double[] row = new double[featureCount];
        for (int s = 0; s < sampleCount; s++) {
            // Unbox each input value once per sample instead of repeatedly for every class.
            for (int f = 0; f < featureCount; f++) {
                row[f] = samples.get(s, f).doubleValue();
            }
            for (int c = 0; c < classCount; c++) {
                double scaledSquaredDistance = 0.0d;
                for (int f = 0; f < featureCount; f++) {
                    double deviation = row[f] - means[c][f];
                    scaledSquaredDistance += (deviation * deviation) / variances[c][f];
                }
                scores[s][c] = ((-0.5d) * (logNormalisation[c] + scaledSquaredDistance)) + logPrior[c];
            }
        }
        return NumpyArrayFactory.from(scores);
    }

    /** @return the per-class counts, or {@code null} if never set */
    public NumpyArray<Double> getClassCounts() {
        return this.classCounts;
    }

    /** @param classCounts the per-class counts; its first dimension defines the number of classes scored */
    public void setClassCounts(NumpyArray<Double> classCounts) {
        this.classCounts = classCounts;
    }

    /** @return the per-class prior probabilities, or {@code null} if never set */
    public NumpyArray<Double> getClassPriors() {
        return this.classPriors;
    }

    /** @param classPriors the per-class prior probabilities used when scoring */
    public void setClassPriors(NumpyArray<Double> classPriors) {
        this.classPriors = classPriors;
    }

    /** @return the class labels, or {@code null} if never set */
    public NumpyArray<Long> getClasses() {
        return this.classes;
    }

    /** @param classes the class labels */
    public void setClasses(NumpyArray<Long> classes) {
        this.classes = classes;
    }

    /** @return the stored priors array, or {@code null} if never set */
    public NumpyArray<Double> getPriors() {
        return this.priors;
    }

    /** @param priors the priors array to store */
    public void setPriors(NumpyArray<Double> priors) {
        this.priors = priors;
    }

    /** @return the stored feature names (the internal array, not a copy), or {@code null} if never set */
    public String[] getFeatureNamesIn() {
        return this.featureNamesIn;
    }

    /** @param featureNamesIn the feature names to store (kept by reference, not copied) */
    public void setFeatureNamesIn(String[] featureNamesIn) {
        this.featureNamesIn = featureNamesIn;
    }

    /** @return the stored number of features ({@code 0} if never set) */
    public int getNumberOfFeatures() {
        return this.numberOfFeatures;
    }

    /** @param numberOfFeatures the number of features to store */
    public void setNumberOfFeatures(int numberOfFeatures) {
        this.numberOfFeatures = numberOfFeatures;
    }

    /** @return the {@code [class][feature]} variance array, or {@code null} if never set */
    public NumpyArray<Double> getSigma() {
        return this.sigma;
    }

    /** @param sigma the {@code [class][feature]} variance array; must be backed by a {@code Dim2DoubleNumpyWrapper} for scoring */
    public void setSigma(NumpyArray<Double> sigma) {
        this.sigma = sigma;
    }

    /** @return the {@code [class][feature]} mean array, or {@code null} if never set */
    public NumpyArray<Double> getTheta() {
        return this.theta;
    }

    /** @param theta the {@code [class][feature]} mean array; must be backed by a {@code Dim2DoubleNumpyWrapper} for scoring */
    public void setTheta(NumpyArray<Double> theta) {
        this.theta = theta;
    }
}
