package com.gengoai.apollo.ml.model;

import com.gengoai.LogUtils;
import com.gengoai.MultithreadedStopwatch;
import com.gengoai.ParamMap;
import com.gengoai.ParameterDef;
import com.gengoai.Validation;
import com.gengoai.apollo.math.linalg.NDArray;
import com.gengoai.apollo.math.linalg.NDArrayFactory;
import com.gengoai.apollo.ml.DataSet;
import com.gengoai.apollo.ml.Datum;
import com.gengoai.apollo.ml.observation.Observation;
import com.gengoai.function.Functional;
import com.gengoai.math.Math2;
import java.util.Arrays;
import java.util.Iterator;
import java.util.Objects;
import java.util.function.Consumer;
import java.util.logging.Level;
import java.util.logging.Logger;
import lombok.NonNull;
import org.apache.commons.math3.util.FastMath;

/* loaded from: input_file:com/gengoai/apollo/ml/model/NaiveBayes.class */
/**
 * Naive Bayes classifier over a single input source and a single output source.
 *
 * <p>Training ({@link #estimate(DataSet)}) counts label occurrences and per-label feature
 * occurrences, then stores smoothed <em>log</em> priors and <em>log</em> conditionals. Prediction
 * ({@link #transform(Observation)}) delegates to the configured {@link ModelType}, which combines
 * the log priors and log conditionals and exponentiates the result.
 */
public class NaiveBayes extends SingleSourceModel<NaiveBayes.Parameters, NaiveBayes> {
    private static final long serialVersionUID = 1;
    /** Log conditionals indexed as {@code [feature][label]}; populated by {@link #estimate(DataSet)}. */
    protected double[][] conditionals;
    /** Log priors indexed by label; populated by {@link #estimate(DataSet)}. */
    protected double[] priors;
    /** Event model selected at training time; {@code null} until {@link #estimate(DataSet)} runs. */
    protected ModelType modelType;
    private static final Logger log = Logger.getLogger(NaiveBayes.class.getName());
    /** Definition of the {@code modelType} fit parameter. */
    public static final ParameterDef<ModelType> modelTypeParam = ParameterDef.param("modelType", ModelType.class);

    /**
     * Event model used for counting, smoothing and scoring.
     *
     * <p>The enum-level methods implement the multinomial behavior; {@link #Bernoulli} overrides
     * all three, while {@link #Complementary} inherits them (the complement counts are assembled
     * by {@link NaiveBayes#estimate(DataSet)} before {@code normalize} is called).
     */
    public enum ModelType {
        /** Uses raw feature values as counts. */
        Multinomial,
        /** Treats every feature as present/absent. */
        Bernoulli {
            /** Collapses any positive value to 1 and everything else to 0. */
            @Override
            double convertValue(double value) {
                return value > 0.0d ? 1.0d : 0.0d;
            }

            /**
             * Smoothed presence probability: {@code (featureCount + 1) / (labelCount + 2)}.
             * The last two arguments are ignored by this model.
             */
            @Override
            double normalize(double featureCount, double labelCount, double totalFeatureMass, double numFeatures) {
                return (featureCount + 1.0d) / (labelCount + 2.0d);
            }

            /**
             * Scores every label by adding, to its log prior, the log probability of each feature
             * being present (non-zero in {@code observation}) or absent, then exponentiates.
             *
             * <p>{@code conditionals} already holds log probabilities (see
             * {@link NaiveBayes#estimate(DataSet)}), so a present feature contributes the stored
             * value as-is and an absent feature contributes {@code log(1 - exp(stored))}. Earlier
             * code took the log of the stored value a second time, which is not a probability.
             */
            @Override
            NDArray distribution(NDArray observation, double[] priors, double[][] conditionals) {
                NDArray scores = NDArrayFactory.ND.columnVector(priors);
                for (int label = 0; label < priors.length; label++) {
                    for (int feature = 0; feature < conditionals.length; feature++) {
                        double current = scores.get(label);
                        double logPresent = conditionals[feature][label];
                        if (observation.get(feature) != 0.0d) {
                            scores.set(label, current + logPresent);
                        } else {
                            // Convert back to probability space to form P(absent) = 1 - P(present).
                            scores.set(label, current + Math2.safeLog(1.0d - FastMath.exp(logPresent)));
                        }
                    }
                }
                scores.mapi(FastMath::exp);
                return scores;
            }
        },
        /**
         * Estimates each label's conditionals from the counts of all <em>other</em> labels.
         * NOTE(review): scoring reuses the additive multinomial {@code distribution}; complement
         * models usually subtract the complement likelihood - confirm the intended sign.
         */
        Complementary;

        /** Maps an observed feature value to the amount added to the feature/label count. */
        double convertValue(double value) {
            return value;
        }

        /**
         * Scores every label as {@code exp(logPrior + sum(value * logConditional))} over the
         * entries visited by {@code observation.forEachSparse}.
         */
        NDArray distribution(NDArray observation, double[] priors, double[][] conditionals) {
            NDArray scores = NDArrayFactory.ND.columnVector(priors);
            observation.forEachSparse((index, value) -> {
                for (int label = 0; label < priors.length; label++) {
                    scores.set(label, scores.get(label) + (value * conditionals[(int) index][label]));
                }
            });
            return scores.mapi(Math::exp);
        }

        /**
         * Add-one smoothed estimate: {@code (featureCount + 1) / (totalFeatureMass + numFeatures)}.
         *
         * @param featureCount     accumulated count of the feature for the label (or its complement)
         * @param labelCount       number of training examples with the label (unused here)
         * @param totalFeatureMass sum of all feature values seen for the label (or its complement)
         * @param numFeatures      dimension of the input source
         */
        double normalize(double featureCount, double labelCount, double totalFeatureMass, double numFeatures) {
            return (featureCount + 1.0d) / (totalFeatureMass + numFeatures);
        }
    }

    /** Fit parameters for {@link NaiveBayes}; adds the event model choice (default {@code Bernoulli}). */
    public static class Parameters extends SingleSourceFitParameters<Parameters> {
        public final ParamMap<Parameters>.Parameter<ModelType> modelType = parameter(NaiveBayes.modelTypeParam, ModelType.Bernoulli);
    }

    /** Creates a model with default {@link Parameters}. */
    public NaiveBayes() {
        this(new Parameters());
    }

    /**
     * Creates a model with the given parameters.
     *
     * @param parameters the fit parameters
     * @throws NullPointerException if {@code parameters} is null
     */
    public NaiveBayes(@NonNull Parameters parameters) {
        // Validate inside the delegating call so the check runs before the superclass sees the value.
        super(Objects.requireNonNull(parameters, "parameters is marked non-null but is null"));
    }

    /**
     * Creates a model whose default parameters are adjusted by the given updater.
     *
     * @param consumer callback applied to a fresh {@link Parameters} instance
     * @throws NullPointerException if {@code consumer} is null
     */
    public NaiveBayes(@NonNull Consumer<Parameters> consumer) {
        // Validate inside the delegating call so the check runs before the updater is used.
        this((Parameters) Functional.with(new Parameters(),
                Objects.requireNonNull(consumer, "updater is marked non-null but is null")));
    }

    /**
     * Trains the model: counts labels and per-label feature values over the cached data set, then
     * replaces the counts with smoothed log conditionals and log priors.
     *
     * <p>When the {@code verbose} parameter is false the class logger is switched off for the
     * duration of training; its previous level is always restored, even if training fails.
     *
     * @param dataSet the training data; must contain at least one example
     * @throws NullPointerException if {@code dataSet} is null
     */
    @Override // com.gengoai.apollo.ml.model.Model
    public void estimate(@NonNull DataSet dataSet) {
        Objects.requireNonNull(dataSet, "dataset is marked non-null but is null");
        DataSet cache = dataSet.cache();
        Validation.checkArgument(cache.size() > 0, "Must have at least 1 training example.");
        Parameters params = (Parameters) this.parameters;
        String inputName = (String) params.input.value();
        String outputName = (String) params.output.value();
        int numFeatures = (int) cache.getMetadata(inputName).getDimension();
        int numLabels = (int) cache.getMetadata(outputName).getDimension();
        Level previousLevel = log.getLevel();
        if (!((Boolean) params.verbose.value()).booleanValue()) {
            log.setLevel(Level.OFF);
        }
        try {
            MultithreadedStopwatch stopwatch = new MultithreadedStopwatch(getClass().getName());
            stopwatch.start();
            LogUtils.logInfo(log, "Beginning training of Naive Bayes Classifier over {2} examples with {0} features and {1} labels.", new Object[]{Integer.valueOf(numFeatures), Integer.valueOf(numLabels), Long.valueOf(cache.size())});
            this.conditionals = new double[numFeatures][numLabels];
            this.priors = new double[numLabels];
            this.modelType = (ModelType) params.modelType.value();

            // Pass 1: raw counts. labelFeatureMass[l] is the sum of all feature values seen with label l.
            double[] labelFeatureMass = new double[numLabels];
            double numExamples = 0.0d;
            Iterator<Datum> examples = cache.iterator();
            while (examples.hasNext()) {
                Datum example = examples.next();
                numExamples += 1.0d;
                NDArray features = example.get(inputName).asNDArray();
                NDArray target = example.get(outputName).asNDArray();
                // A scalar target holds the label index itself; otherwise take the largest entry.
                int label = (int) (target.shape().isScalar() ? target.get(0L) : target.argmax());
                this.priors[label] += 1.0d;
                features.forEachSparse((index, value) -> {
                    labelFeatureMass[label] += value;
                    this.conditionals[(int) index][label] += this.modelType.convertValue(value);
                });
            }

            // Pass 2: turn counts into smoothed log conditionals. priors still holds raw label counts here.
            for (int feature = 0; feature < this.conditionals.length; feature++) {
                // Snapshot the raw counts: the row is overwritten label by label below.
                double[] rawCounts = Arrays.copyOf(this.conditionals[feature], this.conditionals[feature].length);
                for (int label = 0; label < this.priors.length; label++) {
                    if (this.modelType == ModelType.Complementary) {
                        double complementCount = 0.0d;
                        double complementMass = 0.0d;
                        for (int other = 0; other < this.priors.length; other++) {
                            if (other != label) {
                                complementCount += rawCounts[other];
                                complementMass += labelFeatureMass[other];
                            }
                        }
                        this.conditionals[feature][label] = Math2.safeLog(this.modelType.normalize(complementCount, this.priors[label], complementMass, numFeatures));
                    } else {
                        this.conditionals[feature][label] = Math2.safeLog(this.modelType.normalize(this.conditionals[feature][label], this.priors[label], labelFeatureMass[label], numFeatures));
                    }
                }
            }

            // Pass 3: label counts become log relative frequencies.
            for (int label = 0; label < this.priors.length; label++) {
                this.priors[label] = Math2.safeLog(this.priors[label] / numExamples);
            }
            stopwatch.stop();
            LogUtils.logInfo(log, "Completed training in {0}", new Object[]{stopwatch});
        } finally {
            // The logger is static and shared; never leave it silenced after a failed training run.
            log.setLevel(previousLevel);
        }
    }

    /**
     * Returns the label type of the model's output source.
     *
     * @param str the source name to look up
     * @return a classification label type sized to the number of labels
     * @throws NullPointerException     if {@code str} is null
     * @throws IllegalArgumentException if {@code str} is not this model's output source
     */
    @Override // com.gengoai.apollo.ml.model.Model
    public LabelType getLabelType(@NonNull String str) {
        Objects.requireNonNull(str, "name is marked non-null but is null");
        if (str.equals(((Parameters) this.parameters).output.value())) {
            return LabelType.classificationType(this.priors.length);
        }
        throw new IllegalArgumentException("'" + str + "' is not a valid output for this model");
    }

    /**
     * Scores the observation against every label using the trained event model.
     *
     * @param observation the input observation, read through {@code asNDArray()}
     * @return the per-label scores produced by {@link ModelType}'s {@code distribution}
     * @throws NullPointerException if {@code observation} is null
     */
    @Override // com.gengoai.apollo.ml.model.SingleSourceModel
    protected Observation transform(@NonNull Observation observation) {
        Objects.requireNonNull(observation, "observation is marked non-null but is null");
        return this.modelType.distribution(observation.asNDArray(), this.priors, this.conditionals);
    }

    /**
     * Marks the output source of {@code dataSet} as an {@link NDArray} with one entry per label.
     *
     * @param dataSet the data set whose output metadata is updated
     * @throws NullPointerException if {@code dataSet} is null
     */
    @Override // com.gengoai.apollo.ml.model.SingleSourceModel
    protected void updateMetadata(@NonNull DataSet dataSet) {
        Objects.requireNonNull(dataSet, "data is marked non-null but is null");
        String outputName = (String) ((Parameters) this.parameters).output.value();
        dataSet.updateMetadata(outputName, metadata -> {
            metadata.setType(NDArray.class);
            metadata.setDimension(this.priors.length);
        });
    }
}
