package ai.sklearn4j.naive_bayes;

import ai.sklearn4j.core.libraries.numpy.Numpy;
import ai.sklearn4j.core.libraries.numpy.NumpyArray;
import ai.sklearn4j.core.libraries.numpy.NumpyArrayFactory;
import java.util.List;

/**
 * Categorical Naive Bayes classifier whose input features are category indices.
 *
 * <p>This class only evaluates a model: its fitted parameters are supplied through the setters
 * (and the inherited state of {@link BaseNaiveBayes}), and {@link #jointLogLikelihood(NumpyArray)}
 * combines them. NOTE(review): presumably a port of scikit-learn's {@code CategoricalNB}
 * inference path, given the package name; confirm against the exporter.
 */
public class CategoricalNaiveBayes extends BaseNaiveBayes {
    /**
     * Per-feature log-probability tables. Entry {@code i} is read as {@code [classIndex, categoryIndex]}
     * for feature {@code i}.
     */
    private List<NumpyArray<Double>> featureLogProbabilities = null;

    /** Per-class log prior, added to the summed per-feature log probabilities. */
    private NumpyArray<Double> classLogPrior = null;

    // NOTE(review): categoryCounts and numberOfCategories are never read or written by this class
    // (there are no accessors). They are kept unchanged in case other code depends on them.
    private NumpyArray<Double> categoryCounts = null;
    private NumpyArray<Long> numberOfCategories = null;

    /**
     * Computes the unnormalized joint log-likelihood of every sample under every class.
     *
     * <p>For sample {@code s} and class {@code c}, the value is {@code classLogPrior} (added via
     * {@code Numpy.add}) plus, for each feature {@code f}, the entry
     * {@code featureLogProbabilities.get(f)[c, k]}. Here {@code k} is the sample's value in column
     * {@code f}, truncated toward zero to an {@code int}.
     *
     * @param numpyArray samples, read as a 2-D array {@code [nSamples, nFeatures]} of category indices
     * @return an array of shape {@code [nSamples, nClasses]}
     * @throws IllegalStateException if the feature log-probability tables or the class log prior
     *     have not been set, or there are fewer tables than features
     * @throws IllegalArgumentException if the input has fewer columns than features, or a value is
     *     NaN or is not a valid category index for its feature's table
     */
    @Override
    protected NumpyArray<Double> jointLogLikelihood(NumpyArray<Double> numpyArray) {
        int nFeatures = getNumberOfFeatures();
        if (this.featureLogProbabilities == null || this.classLogPrior == null) {
            throw new IllegalStateException(
                    "CategoricalNaiveBayes is not initialized: featureLogProbabilities and classLogPrior must be set");
        }
        if (this.featureLogProbabilities.size() < nFeatures) {
            throw new IllegalStateException("Expected " + nFeatures + " feature log-probability tables but found "
                    + this.featureLogProbabilities.size());
        }

        int[] inputShape = numpyArray.getShape();
        if (nFeatures > 0 && (inputShape.length < 2 || inputShape[1] < nFeatures)) {
            throw new IllegalArgumentException("Input must have at least " + nFeatures + " feature columns");
        }
        int nSamples = inputShape[0];
        // A single source for the class count keeps the output width and the table reads consistent.
        int nClasses = this.classCounts.getShape()[0];

        // Sum the per-feature contributions in one primitive buffer, in the same order as a
        // zero-initialized running sum. This avoids allocating, transposing and adding a
        // [nClasses, nSamples] array for every feature.
        double[][] jll = new double[nSamples][nClasses];
        for (int feature = 0; feature < nFeatures; feature++) {
            NumpyArray<Double> logProbabilities = this.featureLogProbabilities.get(feature);
            int nCategories = logProbabilities.getShape()[1];
            int[] categories = readCategoryColumn(numpyArray, feature, nCategories);
            for (int sample = 0; sample < nSamples; sample++) {
                int category = categories[sample];
                double[] row = jll[sample];
                for (int c = 0; c < nClasses; c++) {
                    row[c] += logProbabilities.get(c, category).doubleValue();
                }
            }
        }

        // An empty buffer carries no shape information, so build empty results with an explicit shape.
        NumpyArray<Double> summed = (nSamples == 0 || nClasses == 0)
                ? NumpyArrayFactory.arrayOfDoubleWithShape(new int[]{nSamples, nClasses})
                : NumpyArrayFactory.from(jll);
        return Numpy.add(summed, this.classLogPrior);
    }

    /**
     * Reads column {@code feature} of {@code samples} as validated category indices.
     *
     * @param samples input samples, read as {@code [nSamples, nFeatures]}
     * @param feature column to read
     * @param nCategories number of categories in this feature's log-probability table
     * @return one category index per sample, each in {@code [0, nCategories)}
     * @throws IllegalArgumentException if a value is NaN or does not name a known category
     */
    private static int[] readCategoryColumn(NumpyArray<Double> samples, int feature, int nCategories) {
        int[] categories = new int[samples.getShape()[0]];
        for (int sample = 0; sample < categories.length; sample++) {
            double value = samples.get(sample, feature).doubleValue();
            // A plain (int) cast would silently map NaN to category 0.
            if (Double.isNaN(value)) {
                throw new IllegalArgumentException(
                        "Feature " + feature + " of sample " + sample + " is NaN; expected a category index");
            }
            // Truncate toward zero. Infinities saturate and are rejected by the range check below.
            int category = (int) value;
            if (category < 0 || category >= nCategories) {
                throw new IllegalArgumentException("Feature " + feature + " of sample " + sample + " has category "
                        + value + ", but only categories 0.." + (nCategories - 1) + " are known");
            }
            categories[sample] = category;
        }
        return categories;
    }

    /**
     * Returns the per-class log prior.
     *
     * @return the stored array (not a copy), or {@code null} if never set
     */
    public NumpyArray<Double> getClassLogPrior() {
        return this.classLogPrior;
    }

    /**
     * Sets the per-class log prior used by {@link #jointLogLikelihood(NumpyArray)}.
     *
     * @param numpyArray the class log prior; stored by reference
     */
    public void setClassLogPrior(NumpyArray<Double> numpyArray) {
        this.classLogPrior = numpyArray;
    }

    /**
     * Returns the per-feature log-probability tables.
     *
     * @return the stored list (not a copy), or {@code null} if never set
     */
    public List<NumpyArray<Double>> getFeatureLogProbabilities() {
        return this.featureLogProbabilities;
    }

    /**
     * Sets the per-feature log-probability tables. Entry {@code i} is read as
     * {@code [classIndex, categoryIndex]} for feature {@code i}.
     *
     * @param list the tables; stored by reference
     */
    public void setFeatureLogProbabilities(List<NumpyArray<Double>> list) {
        this.featureLogProbabilities = list;
    }
}
