package hex.naivebayes;

import com.sun.jna.platform.win32.COM.tlb.imp.TlbConst;
import feedzai.jetty9.shaded.org.eclipse.jetty.util.security.Constraint;
import hex.Model;
import hex.ModelMetrics;
import hex.ModelMetricsBinomial;
import hex.ModelMetricsMultinomial;
import hex.genmodel.GenModel;
import hex.schemas.NaiveBayesModelV3;
import hex.util.EffectiveParametersUtils;
import org.apache.commons.math3.geometry.VectorFormat;
import org.apache.commons.math3.optimization.direct.CMAESOptimizer;
import water.H2O;
import water.Key;
import water.api.schemas3.ModelSchemaV3;
import water.codegen.CodeGenerator;
import water.codegen.CodeGeneratorPipeline;
import water.exceptions.JCodeSB;
import water.util.JCodeGen;
import water.util.SBPrintStream;
import water.util.TwoDimTable;

/* loaded from: input_file:hex/naivebayes/NaiveBayesModel.class */
/**
 * Naive Bayes classification model.
 *
 * <p>{@link #score0} starts each class from the log of its prior, then adds the log-likelihood of
 * every non-missing predictor. The first {@code _ncats} predictors are categorical and use a
 * conditional probability table lookup. The remaining predictors are numeric and use a Gaussian
 * density. The per-class log-posteriors are then normalized into class probabilities.
 * {@link #toJavaPredictBody} emits the same computation as POJO source.
 */
public class NaiveBayesModel extends Model<NaiveBayesModel, NaiveBayesParameters, NaiveBayesOutput> {

    /** Training output of a Naive Bayes model. */
    public static class NaiveBayesOutput extends Model.Output {
        /** Table form of the class priors (presumably for display). */
        public TwoDimTable _apriori;
        /** Prior probability of each response class, indexed by class; scoring takes its log. */
        public double[] _apriori_raw;
        /** Table form of the conditional probabilities (presumably one table per predictor). */
        public TwoDimTable[] _pcond;
        /**
         * Conditional statistics indexed {@code [predictor][class][k]}.
         *
         * <p>For a categorical predictor, {@code k} is the level index and the entry is used as
         * P(level | class). For a numeric predictor, {@code k == 0} holds the mean and
         * {@code k == 1} the standard deviation.
         */
        public double[][][] _pcond_raw;
        /**
         * Per-class count used in the Laplace-smoothing denominator.
         *
         * <p>The POJO describes it as "Count of categorical levels in response". It is presumably
         * the number of training rows in each class (TODO confirm).
         */
        public int[] _rescnt;
        /** Response class labels; its length is the number of classes scored. */
        public String[] _levels;
        /** Number of leading predictors treated as categorical; the remaining ones are numeric. */
        public int _ncats;

        public NaiveBayesOutput(NaiveBayes naiveBayes) {
            super(naiveBayes);
        }
    }

    /** Parameters of the Naive Bayes algorithm. */
    public static class NaiveBayesParameters extends Model.Parameters {
        /**
         * Laplace smoothing constant.
         *
         * <p>Used for categorical levels that have no stored conditional probability. With 0, the
         * fallback probability is 0, which is then clamped via {@link #_eps_prob}.
         */
        public double _laplace = 0.0;
        /** Standard deviations at or below this cutoff are replaced by {@link #_min_sdev}. */
        public double _eps_sdev = 0.0;
        /** Replacement standard deviation for values at or below {@link #_eps_sdev}. */
        public double _min_sdev = 0.001d;
        /** Probabilities at or below this cutoff are replaced by {@link #_min_prob} before the log. */
        public double _eps_prob = 0.0;
        /** Replacement probability for values at or below {@link #_eps_prob}. */
        public double _min_prob = 0.001d;
        /** Not read in this file; presumably toggles metric computation during training. */
        public boolean _compute_metrics = true;

        @Override // hex.Model.Parameters
        public String algoName() {
            return "NaiveBayes";
        }

        @Override // hex.Model.Parameters
        public String fullName() {
            return "Naive Bayes";
        }

        @Override // hex.Model.Parameters
        public String javaName() {
            return NaiveBayesModel.class.getName();
        }

        /** Number of progress units reported for a training job (fixed at 6). */
        @Override // hex.Model.Parameters
        public long progressUnits() {
            return 6L;
        }
    }

    public NaiveBayesModel(Key key, NaiveBayesParameters naiveBayesParameters, NaiveBayesOutput naiveBayesOutput) {
        super(key, naiveBayesParameters, naiveBayesOutput);
    }

    /** Resolves effective parameter values, including the fold-assignment scheme. */
    @Override // hex.Model
    public void initActualParamValues() {
        super.initActualParamValues();
        EffectiveParametersUtils.initFoldAssignment(this._parms);
    }

    public ModelSchemaV3 schema() {
        return new NaiveBayesModelV3();
    }

    /**
     * Creates a metric builder that matches the model category.
     *
     * @param domain response class labels
     * @return a binomial or multinomial metric builder
     * @throws RuntimeException from {@code H2O.unimpl()} for any other model category
     */
    @Override // hex.Model
    public ModelMetrics.MetricBuilder makeMetricBuilder(String[] domain) {
        NaiveBayesOutput output = (NaiveBayesOutput) this._output;
        switch (output.getModelCategory()) {
            case Binomial:
                return new ModelMetricsBinomial.MetricBuilderBinomial(domain);
            case Multinomial:
                return new ModelMetricsMultinomial.MetricBuilderMultinomial(
                        domain.length, domain, ((NaiveBayesParameters) this._parms)._auc_type);
            default:
                throw H2O.unimpl();
        }
    }

    /**
     * Scores one row.
     *
     * @param data predictor values; the first {@code _ncats} entries hold categorical level
     *     indices, the rest numeric values. {@code NaN} entries are skipped as missing.
     * @param preds output buffer of length at least {@code nclasses + 1}. {@code preds[0]}
     *     receives the predicted label and {@code preds[1..nclasses]} the class probabilities.
     * @return {@code preds}
     */
    @Override // hex.Model
    protected double[] score0(double[] data, double[] preds) {
        final NaiveBayesOutput output = (NaiveBayesOutput) this._output;
        final int nclasses = output._levels.length;
        final double[] logPost = new double[nclasses];
        assert preds.length >= nclasses + 1;

        // Unnormalized log-posterior of each class: log P(class) + sum over predictors of
        // log P(x_j | class). Accumulation order matches the POJO so results agree bit-for-bit.
        for (int c = 0; c < nclasses; c++) {
            double acc = Math.log(output._apriori_raw[c]);
            for (int j = 0; j < output._ncats; j++) {
                if (!Double.isNaN(data[j])) {
                    acc += logProbClamped(categoricalLikelihood(output, j, c, (int) data[j]));
                }
            }
            for (int j = output._ncats; j < data.length; j++) {
                if (!Double.isNaN(data[j])) {
                    acc += logProbClamped(gaussianLikelihood(output, j, c, data[j]));
                }
            }
            logPost[c] = acc;
        }

        // p_c = 1 / sum_k exp(logPost[k] - logPost[c]). This only exponentiates differences,
        // never the (typically very negative) log-posteriors themselves.
        for (int c = 0; c < nclasses; c++) {
            double denom = 0.0;
            for (double lp : logPost) {
                denom += Math.exp(lp - logPost[c]);
            }
            preds[c + 1] = 1.0d / denom;
        }
        preds[0] = GenModel.getPrediction(preds, output._priorClassDist, data, defaultThreshold());
        return preds;
    }

    /**
     * Returns {@code log(p)}, substituting {@code _min_prob} when {@code p <= _eps_prob}.
     * This guards against {@code log(0)} for zero or tiny probabilities.
     */
    private double logProbClamped(double p) {
        NaiveBayesParameters parms = (NaiveBayesParameters) this._parms;
        return Math.log(p <= parms._eps_prob ? parms._min_prob : p);
    }

    /**
     * Returns P(level | class) for a categorical predictor.
     *
     * <p>Levels beyond the stored table fall back to the Laplace-smoothed estimate
     * {@code laplace / (rescnt[class] + laplace * domainSize)}.
     */
    private double categoricalLikelihood(NaiveBayesOutput output, int feature, int cls, int level) {
        double[] table = output._pcond_raw[feature][cls];
        if (level < table.length) {
            return table[level];
        }
        NaiveBayesParameters parms = (NaiveBayesParameters) this._parms;
        return parms._laplace / (output._rescnt[cls] + parms._laplace * output._domains[feature].length);
    }

    /**
     * Returns the Gaussian density of {@code x} under the class-conditional (mean, sdev) pair.
     *
     * <p>A {@code NaN} mean is treated as 0 and a {@code NaN} sdev as 1. An sdev at or below
     * {@code _eps_sdev} is replaced by {@code _min_sdev}.
     */
    private double gaussianLikelihood(NaiveBayesOutput output, int feature, int cls, double x) {
        NaiveBayesParameters parms = (NaiveBayesParameters) this._parms;
        double[] stats = output._pcond_raw[feature][cls];
        double mean = Double.isNaN(stats[0]) ? 0.0 : stats[0];
        double sdev;
        if (Double.isNaN(stats[1])) {
            sdev = 1.0d;
        } else if (stats[1] <= parms._eps_sdev) {
            sdev = parms._min_sdev;
        } else {
            sdev = stats[1];
        }
        double diff = x - mean;
        return Math.exp(-(diff * diff) / (2.0d * sdev * sdev)) / (sdev * Math.sqrt(2.0d * Math.PI));
    }

    /**
     * Emits POJO initialization code.
     *
     * <p>Writes the supervised/feature/class-count accessors. It also registers a generator for the
     * static arrays {@code _RESCNT}, {@code _APRIORI}, {@code _PCOND} and {@code _DOMLEN} that the
     * generated predict body reads.
     */
    @Override // hex.Model, hex.DefaultPojoWriter
    public SBPrintStream toJavaInit(SBPrintStream sb, CodeGeneratorPipeline fileCtx) {
        SBPrintStream initSb = super.toJavaInit(sb, fileCtx);
        NaiveBayesOutput output = (NaiveBayesOutput) this._output;
        initSb.ip("public boolean isSupervised() { return " + isSupervised() + "; }").nl();
        initSb.ip("public int nfeatures() { return " + output.nfeatures() + "; }").nl();
        initSb.ip("public int nclasses() { return " + output.nclasses() + "; }").nl();
        final String mname = JCodeGen.toJavaId(this._key.toString());
        fileCtx.add(new CodeGenerator() {
            @Override // water.codegen.CodeGenerator
            public void generate(JCodeSB classSb) {
                NaiveBayesOutput out = (NaiveBayesOutput) NaiveBayesModel.this._output;
                JCodeGen.toClassWithArray(classSb, (String) null, mname + "_RESCNT", out._rescnt, "Count of categorical levels in response.");
                JCodeGen.toClassWithArray(classSb, (String) null, mname + "_APRIORI", out._apriori_raw, "Apriori class distribution of the response.");
                JCodeGen.toClassWithArray(classSb, (String) null, mname + "_PCOND", out._pcond_raw, "Conditional probability of predictors.");
                // Domain size of each categorical predictor, used by the Laplace fallback; null when
                // there are no categorical predictors.
                double[] domLen = null;
                if (out._ncats > 0) {
                    domLen = new double[out._ncats];
                    for (int j = 0; j < out._ncats; j++) {
                        domLen[j] = out._domains[j].length;
                    }
                }
                JCodeGen.toClassWithArray(classSb, (String) null, mname + "_DOMLEN", domLen, "Number of unique levels for each categorical predictor.");
            }
        });
        return initSb;
    }

    /**
     * Emits the POJO {@code score0} body, mirroring {@link #score0(double[], double[])}.
     *
     * <p>The categorical lookup is bounded by the per-predictor, per-class table length
     * ({@code _PCOND.VALUES[j][i].length}), exactly as in-memory scoring does. The previous
     * bound was the number of predictors, which could index out of range or skip valid levels.
     */
    @Override // hex.Model, hex.DefaultPojoWriter
    protected void toJavaPredictBody(SBPrintStream bodySb, CodeGeneratorPipeline classCtx, CodeGeneratorPipeline fileCtx, boolean verboseCode) {
        final NaiveBayesOutput output = (NaiveBayesOutput) this._output;
        final NaiveBayesParameters parms = (NaiveBayesParameters) this._parms;
        final String mname = JCodeGen.toJavaId(this._key.toString());
        final int nclasses = output._levels.length;
        final String pcond = mname + "_PCOND.VALUES[j][i]";
        final String clampProb = "nums[i] += Math.log(prob <= " + parms._eps_prob + " ? " + parms._min_prob + " : prob);";

        bodySb.i().p("java.util.Arrays.fill(preds,0);").nl();
        bodySb.i().p("double mean, sdev, prob;").nl();
        bodySb.i().p("double[] nums = new double[" + nclasses + "];").nl();
        bodySb.i().p("for(int i = 0; i < " + nclasses + "; i++) {").nl();
        bodySb.i(1).p("nums[i] = Math.log(").pj(mname + "_APRIORI", "VALUES").p("[i]);").nl();

        // Categorical predictors: table lookup with Laplace fallback for levels outside the table.
        bodySb.i(1).p("for(int j = 0; j < " + output._ncats + "; j++) {").nl();
        bodySb.i(2).p("if(Double.isNaN(data[j])) continue;").nl();
        bodySb.i(2).p("int level = (int)data[j];").nl();
        String laplaceFallback = parms._laplace == 0.0
                ? "0"
                : parms._laplace + "/(" + mname + "_RESCNT.VALUES[i] + " + parms._laplace + "*" + mname + "_DOMLEN.VALUES[j])";
        bodySb.i(2).p("prob = level < " + pcond + ".length ? " + pcond + "[level] : ").p(laplaceFallback).p(";").nl();
        bodySb.i(2).p(clampProb).nl();
        bodySb.i(1).p("}").nl();

        // Numeric predictors: Gaussian density with NaN and small-sdev guards.
        bodySb.i(1).p("for(int j = " + output._ncats + "; j < data.length; j++) {").nl();
        bodySb.i(2).p("if(Double.isNaN(data[j])) continue;").nl();
        bodySb.i(2).p("mean = Double.isNaN(" + pcond + "[0]) ? 0 : " + pcond + "[0];").nl();
        bodySb.i(2).p("sdev = Double.isNaN(" + pcond + "[1]) ? 1 : (" + pcond + "[1] <= " + parms._eps_sdev + " ? " + parms._min_sdev + " : " + pcond + "[1]);").nl();
        bodySb.i(2).p("prob = Math.exp(-((data[j]-mean)*(data[j]-mean))/(2.*sdev*sdev)) / (sdev*Math.sqrt(2.*Math.PI));").nl();
        bodySb.i(2).p(clampProb).nl();
        bodySb.i(1).p("}").nl();
        bodySb.i().p("}").nl();

        // Normalize the log-posteriors into probabilities.
        bodySb.i().p("double sum;").nl();
        bodySb.i().p("for(int i = 0; i < nums.length; i++) {").nl();
        bodySb.i(1).p("sum = 0;").nl();
        bodySb.i(1).p("for(int j = 0; j < nums.length; j++) {").nl();
        bodySb.i(2).p("sum += Math.exp(nums[j]-nums[i]);").nl();
        bodySb.i(1).p("}").nl();
        bodySb.i(1).p("preds[i+1] = 1/sum;").nl();
        bodySb.i().p("}").nl();
        bodySb.i().p("preds[0] = hex.genmodel.GenModel.getPrediction(preds, PRIOR_CLASS_DISTRIB, data, " + defaultThreshold() + ");").nl();
    }

    /**
     * Reports whether a predictor can influence predictions.
     *
     * <p>Returns true when, for some class, the predictor's row in {@code _pcond_raw} is not
     * constant, i.e. some entry differs from the row's first entry.
     *
     * <p>NOTE(review): for numeric predictors the row is {mean, sdev}, so this compares the mean
     * with the sdev. Confirm this is the intended "unused" criterion.
     */
    @Override // hex.Model
    protected boolean isFeatureUsedInPredict(int featureIdx) {
        for (double[] perClass : ((NaiveBayesOutput) this._output)._pcond_raw[featureIdx]) {
            double first = perClass[0];
            for (double value : perClass) {
                if (first != value) {
                    return true;
                }
            }
        }
        return false;
    }
}
