package org.jpmml.evaluator.naive_bayes;

import com.google.common.cache.CacheLoader;
import com.google.common.cache.LoadingCache;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import org.apache.commons.math3.util.Precision;
import org.dmg.pmml.ContinuousDistribution;
import org.dmg.pmml.DerivedField;
import org.dmg.pmml.Discretize;
import org.dmg.pmml.Expression;
import org.dmg.pmml.Extension;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.GaussianDistribution;
import org.dmg.pmml.MathContext;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.PMML;
import org.dmg.pmml.PoissonDistribution;
import org.dmg.pmml.naive_bayes.BayesInput;
import org.dmg.pmml.naive_bayes.BayesInputs;
import org.dmg.pmml.naive_bayes.BayesOutput;
import org.dmg.pmml.naive_bayes.NaiveBayesModel;
import org.dmg.pmml.naive_bayes.PairCounts;
import org.dmg.pmml.naive_bayes.TargetValueCount;
import org.dmg.pmml.naive_bayes.TargetValueCounts;
import org.dmg.pmml.naive_bayes.TargetValueStat;
import org.dmg.pmml.naive_bayes.TargetValueStats;
import org.jpmml.evaluator.CacheUtil;
import org.jpmml.evaluator.Classification;
import org.jpmml.evaluator.DiscretizationUtil;
import org.jpmml.evaluator.DistributionUtil;
import org.jpmml.evaluator.EvaluationContext;
import org.jpmml.evaluator.ExpressionUtil;
import org.jpmml.evaluator.FieldValue;
import org.jpmml.evaluator.FieldValueUtil;
import org.jpmml.evaluator.HasParsedValueMapping;
import org.jpmml.evaluator.InvalidAttributeException;
import org.jpmml.evaluator.MisplacedElementException;
import org.jpmml.evaluator.MissingAttributeException;
import org.jpmml.evaluator.MissingElementException;
import org.jpmml.evaluator.ModelEvaluationContext;
import org.jpmml.evaluator.ModelEvaluator;
import org.jpmml.evaluator.OutputUtil;
import org.jpmml.evaluator.PMMLAttributes;
import org.jpmml.evaluator.PMMLElements;
import org.jpmml.evaluator.ProbabilityDistribution;
import org.jpmml.evaluator.TargetField;
import org.jpmml.evaluator.TargetUtil;
import org.jpmml.evaluator.UnsupportedAttributeException;
import org.jpmml.evaluator.ValueFactory;
import org.jpmml.evaluator.ValueUtil;
import org.jpmml.evaluator.VerificationUtil;
import org.jpmml.evaluator.XPathUtil;

/* loaded from: input_file:WEB-INF/lib/pmml-evaluator-1.3.11.jar:org/jpmml/evaluator/naive_bayes/NaiveBayesModelEvaluator.class */
/**
 * Evaluator for PMML {@code NaiveBayesModel} elements.
 *
 * <p>Only the {@code DOUBLE} math context and the {@code classification} mining function are supported.
 * Per-target-value likelihoods are accumulated in log space and normalized with a softmax at the end.
 */
public class NaiveBayesModelEvaluator extends ModelEvaluator<NaiveBayesModel> {

    /** Lazily resolved from {@link #bayesInputCache}; {@code transient}, so it is re-resolved on demand after deserialization. */
    private transient List<BayesInput> bayesInputs;

    /** Lazily resolved from {@link #fieldCountSumCache}; {@code transient}, so it is re-resolved on demand after deserialization. */
    private transient Map<FieldName, Map<String, Double>> fieldCountSums;

    /**
     * Shared per-model cache of all {@link BayesInput} elements, including those found inside
     * {@link Extension} elements of the {@link BayesInputs} container.
     */
    private static final LoadingCache<NaiveBayesModel, List<BayesInput>> bayesInputCache = CacheUtil.buildLoadingCache(new CacheLoader<NaiveBayesModel, List<BayesInput>>() {

        @Override
        public List<BayesInput> load(NaiveBayesModel naiveBayesModel) {
            return ImmutableList.copyOf(parseBayesInputs(naiveBayesModel));
        }
    });

    /**
     * Shared per-model cache that maps each input field to the sum of {@link TargetValueCount} counts
     * (keyed by target value) over all of that field's {@link PairCounts}.
     */
    private static final LoadingCache<NaiveBayesModel, Map<FieldName, Map<String, Double>>> fieldCountSumCache = CacheUtil.buildLoadingCache(new CacheLoader<NaiveBayesModel, Map<FieldName, Map<String, Double>>>() {

        @Override
        public Map<FieldName, Map<String, Double>> load(NaiveBayesModel naiveBayesModel) {
            return ImmutableMap.copyOf(calculateFieldCountSums(naiveBayesModel));
        }
    });

    /**
     * Creates an evaluator for the {@link NaiveBayesModel} selected from the given PMML document.
     *
     * @param pmml the PMML document
     */
    public NaiveBayesModelEvaluator(PMML pmml) {
        this(pmml, (NaiveBayesModel) selectModel(pmml, NaiveBayesModel.class));
    }

    /**
     * Creates an evaluator for the given model.
     *
     * @param pmml the enclosing PMML document
     * @param naiveBayesModel the model to evaluate
     * @throws MissingElementException if {@code BayesInputs} is absent or empty (no inputs and no extensions),
     *         or if {@code BayesOutput} or its {@code TargetValueCounts} is absent or empty
     */
    public NaiveBayesModelEvaluator(PMML pmml, NaiveBayesModel naiveBayesModel) {
        super(pmml, naiveBayesModel);

        this.bayesInputs = null;
        this.fieldCountSums = null;

        BayesInputs bayesInputs = naiveBayesModel.getBayesInputs();
        if (bayesInputs == null) {
            throw new MissingElementException(naiveBayesModel, PMMLElements.NAIVEBAYESMODEL_BAYESINPUTS);
        }
        // BayesInput elements may be supplied directly or through Extension elements (see parseBayesInputs)
        if (!bayesInputs.hasBayesInputs() && !bayesInputs.hasExtensions()) {
            throw new MissingElementException(bayesInputs, PMMLElements.BAYESINPUTS_BAYESINPUTS);
        }

        BayesOutput bayesOutput = naiveBayesModel.getBayesOutput();
        if (bayesOutput == null) {
            throw new MissingElementException(naiveBayesModel, PMMLElements.NAIVEBAYESMODEL_BAYESOUTPUT);
        }

        TargetValueCounts targetValueCounts = bayesOutput.getTargetValueCounts();
        if (targetValueCounts == null) {
            throw new MissingElementException(bayesOutput, PMMLElements.BAYESOUTPUT_TARGETVALUECOUNTS);
        }
        if (!targetValueCounts.hasTargetValueCounts()) {
            throw new MissingElementException(targetValueCounts, PMMLElements.TARGETVALUECOUNTS_TARGETVALUECOUNTS);
        }
    }

    @Override
    public String getSummary() {
        return "Naive Bayes model";
    }

    /**
     * Evaluates the model against the given context.
     *
     * @throws UnsupportedAttributeException if the math context is not {@code DOUBLE}, or the mining function is unknown
     * @throws InvalidAttributeException if the mining function is anything other than {@code classification}
     */
    @Override
    public Map<FieldName, ?> evaluate(ModelEvaluationContext context) {
        NaiveBayesModel naiveBayesModel = ensureScorableModel();

        MathContext mathContext = naiveBayesModel.getMathContext();
        if (mathContext != MathContext.DOUBLE) {
            throw new UnsupportedAttributeException(naiveBayesModel, mathContext);
        }

        // Only MathContext.DOUBLE reaches this point; the cast assumes getValueFactory() then yields
        // Double values (the classification path below is written against ValueFactory<Double>)
        @SuppressWarnings("unchecked")
        ValueFactory<Double> valueFactory = (ValueFactory<Double>) getValueFactory();

        MiningFunction miningFunction = naiveBayesModel.getMiningFunction();
        switch (miningFunction) {
            case CLASSIFICATION:
                return OutputUtil.evaluate(evaluateClassification(valueFactory, context), context);
            case ASSOCIATION_RULES:
            case SEQUENCES:
            case REGRESSION:
            case CLUSTERING:
            case TIME_SERIES:
            case MIXED:
                throw new InvalidAttributeException(naiveBayesModel, miningFunction);
            default:
                throw new UnsupportedAttributeException(naiveBayesModel, miningFunction);
        }
    }

    /**
     * Computes the class probability distribution for the target field.
     *
     * <p>Continuous inputs (those with {@code TargetValueStats}) contribute distribution probabilities;
     * discrete inputs contribute count ratios from their matching {@code PairCounts}. Inputs whose value is
     * missing, or whose discretization produces no value, contribute nothing. Priors from
     * {@code BayesOutput} are multiplied in last, then the log-space accumulator is softmax-normalized.
     */
    private Map<FieldName, ? extends Classification<Double>> evaluateClassification(final ValueFactory<Double> valueFactory, EvaluationContext context) {
        NaiveBayesModel naiveBayesModel = getModel();

        TargetField targetField = getTargetField();

        double threshold = naiveBayesModel.getThreshold();

        Map<FieldName, Map<String, Double>> fieldCountSums = getFieldCountSums();

        // Accumulates log-probabilities per target value; normalized with a softmax at the end
        ProbabilityMap<String, Double> probabilities = new ProbabilityMap<String, Double>() {

            @Override
            public ValueFactory<Double> getValueFactory() {
                return valueFactory;
            }

            @Override
            public void multiply(String targetValue, double probability) {
                // Multiplication is carried out in log space: log(a * b) = log(a) + log(b)
                ensureValue(targetValue).add2(Math.log(probability));
            }
        };

        for (BayesInput bayesInput : getBayesInputs()) {
            FieldName fieldName = bayesInput.getFieldName();
            if (fieldName == null) {
                throw new MissingAttributeException(bayesInput, PMMLAttributes.BAYESINPUT_FIELDNAME);
            }

            FieldValue value = context.evaluate(fieldName);

            // A missing input value does not contribute to any target value's likelihood
            if (value == null) {
                continue;
            }

            TargetValueStats targetValueStats = getTargetValueStats(bayesInput);
            if (targetValueStats != null) {
                calculateContinuousProbabilities(probabilities, targetValueStats, threshold, value);

                continue;
            }

            DerivedField derivedField = bayesInput.getDerivedField();
            if (derivedField != null) {
                value = discretize(derivedField, value);

                // Discretization produced no value: treat the input as missing rather than
                // passing null on to getTargetValueCounts (which would dereference it)
                if (value == null) {
                    continue;
                }
            }

            TargetValueCounts targetValueCounts = getTargetValueCounts(bayesInput, value);
            if (targetValueCounts != null) {
                calculateDiscreteProbabilities(probabilities, targetValueCounts, threshold, fieldCountSums.get(fieldName));
            }
        }

        BayesOutput bayesOutput = naiveBayesModel.getBayesOutput();

        FieldName outputFieldName = bayesOutput.getFieldName();
        if (outputFieldName == null) {
            throw new MissingAttributeException(bayesOutput, PMMLAttributes.BAYESOUTPUT_FIELDNAME);
        }
        if (!Objects.equals(targetField.getName(), outputFieldName)) {
            throw new InvalidAttributeException(bayesOutput, PMMLAttributes.BAYESOUTPUT_FIELDNAME, outputFieldName);
        }

        calculatePriorProbabilities(probabilities, bayesOutput.getTargetValueCounts());

        ValueUtil.normalizeSoftMax(probabilities);

        return TargetUtil.evaluateClassification(targetField, new ProbabilityDistribution(probabilities));
    }

    /**
     * Applies the {@link Discretize} expression of {@code derivedField} to {@code value}.
     *
     * @return the refined discretized value, or {@code null} when discretization yields no value
     * @throws MisplacedElementException if the derived field's expression is not a {@code Discretize}
     */
    private FieldValue discretize(DerivedField derivedField, FieldValue value) {
        Expression expression = ExpressionUtil.ensureExpression(derivedField);
        if (!(expression instanceof Discretize)) {
            throw new MisplacedElementException(expression);
        }

        FieldValue discretized = DiscretizationUtil.discretize((Discretize) expression, value);
        if (discretized == null) {
            return null;
        }

        return FieldValueUtil.refine(derivedField, discretized);
    }

    /**
     * Multiplies in, per target value, the probability of {@code value} under that target value's
     * Gaussian or Poisson distribution. Values below {@code threshold} are raised to it.
     * Distributions reported as no-ops by {@link DistributionUtil#isNoOp} are skipped.
     */
    private void calculateContinuousProbabilities(ProbabilityMap<String, Double> probabilities, TargetValueStats targetValueStats, double threshold, FieldValue value) {
        Number x = value.asNumber();

        for (Iterator<TargetValueStat> it = targetValueStats.iterator(); it.hasNext(); ) {
            TargetValueStat targetValueStat = it.next();

            String targetValue = targetValueStat.getValue();
            if (targetValue == null) {
                throw new MissingAttributeException(targetValueStat, PMMLAttributes.TARGETVALUESTAT_VALUE);
            }

            ContinuousDistribution distribution = targetValueStat.getContinuousDistribution();
            if (distribution == null) {
                throw new MissingElementException(MissingElementException.formatMessage(XPathUtil.formatElement(targetValueStat.getClass()) + "/<ContinuousDistribution>"), targetValueStat);
            }

            if (!(distribution instanceof GaussianDistribution) && !(distribution instanceof PoissonDistribution)) {
                throw new MisplacedElementException(distribution);
            }

            if (!DistributionUtil.isNoOp(distribution)) {
                probabilities.multiply(targetValue, Math.max(DistributionUtil.probability(distribution, x), threshold));
            }
        }
    }

    /**
     * Multiplies in, per target value, {@code count / countSums[targetValue]} for the matched
     * {@code PairCounts}.
     *
     * @param countSums per-target-value count totals for this input field (see {@link #calculateFieldCountSums})
     */
    private void calculateDiscreteProbabilities(ProbabilityMap<String, Double> probabilities, TargetValueCounts targetValueCounts, double threshold, Map<String, Double> countSums) {
        for (Iterator<TargetValueCount> it = targetValueCounts.iterator(); it.hasNext(); ) {
            TargetValueCount targetValueCount = it.next();

            String targetValue = targetValueCount.getValue();
            if (targetValue == null) {
                throw new MissingAttributeException(targetValueCount, PMMLAttributes.TARGETVALUECOUNT_VALUE);
            }

            double count = targetValueCount.getCount();

            // A zero count is replaced by the threshold instead of yielding a zero probability
            double probability = VerificationUtil.isZero(Double.valueOf(count), Precision.EPSILON) ? threshold : count / countSums.get(targetValue).doubleValue();

            probabilities.multiply(targetValue, probability);
        }
    }

    /** Multiplies in the raw {@code BayesOutput} count of each target value as its prior. */
    private void calculatePriorProbabilities(ProbabilityMap<String, Double> probabilities, TargetValueCounts targetValueCounts) {
        for (Iterator<TargetValueCount> it = targetValueCounts.iterator(); it.hasNext(); ) {
            TargetValueCount targetValueCount = it.next();

            String targetValue = targetValueCount.getValue();
            if (targetValue == null) {
                throw new MissingAttributeException(targetValueCount, PMMLAttributes.TARGETVALUECOUNT_VALUE);
            }

            probabilities.multiply(targetValue, targetValueCount.getCount());
        }
    }

    /**
     * Returns all {@link BayesInput} elements of this evaluator's model, including those contributed
     * through {@link Extension} elements. The list is resolved from the shared per-model cache on first use.
     */
    protected List<BayesInput> getBayesInputs() {
        if (this.bayesInputs == null) {
            this.bayesInputs = getValue(bayesInputCache);
        }

        return this.bayesInputs;
    }

    /**
     * Returns, per input field, the sum of {@link TargetValueCount} counts keyed by target value.
     * The map is resolved from the shared per-model cache on first use.
     */
    protected Map<FieldName, Map<String, Double>> getFieldCountSums() {
        if (this.fieldCountSums == null) {
            this.fieldCountSums = getValue(fieldCountSumCache);
        }

        return this.fieldCountSums;
    }

    /**
     * Sums, for every input field, the {@link TargetValueCount} counts over all of the field's
     * {@link PairCounts}, grouped by target value. The per-field maps are unmodifiable because the
     * result is stored in a cache shared by all evaluators of the same model.
     *
     * @throws MissingAttributeException if a {@code BayesInput} has no field name
     * @throws MissingElementException if a {@code PairCounts} has no {@code TargetValueCounts}
     */
    public static Map<FieldName, Map<String, Double>> calculateFieldCountSums(NaiveBayesModel naiveBayesModel) {
        Map<FieldName, Map<String, Double>> result = new LinkedHashMap<>();

        List<BayesInput> bayesInputs = CacheUtil.getValue(naiveBayesModel, bayesInputCache);
        for (BayesInput bayesInput : bayesInputs) {
            FieldName fieldName = bayesInput.getFieldName();
            if (fieldName == null) {
                throw new MissingAttributeException(bayesInput, PMMLAttributes.BAYESINPUT_FIELDNAME);
            }

            Map<String, Double> countSums = new LinkedHashMap<>();

            for (PairCounts pairCounts : bayesInput.getPairCounts()) {
                TargetValueCounts targetValueCounts = pairCounts.getTargetValueCounts();
                if (targetValueCounts == null) {
                    throw new MissingElementException(pairCounts, PMMLElements.PAIRCOUNTS_TARGETVALUECOUNTS);
                }

                for (Iterator<TargetValueCount> it = targetValueCounts.iterator(); it.hasNext(); ) {
                    TargetValueCount targetValueCount = it.next();

                    String targetValue = targetValueCount.getValue();

                    Double sum = countSums.get(targetValue);

                    countSums.put(targetValue, (sum != null ? sum.doubleValue() : 0d) + targetValueCount.getCount());
                }
            }

            result.put(fieldName, Collections.unmodifiableMap(countSums));
        }

        return result;
    }

    /**
     * Collects the model's {@link BayesInput} elements: the directly declared ones first, followed by any
     * {@code BayesInput} found in the content of the container's {@link Extension} elements.
     */
    public static List<BayesInput> parseBayesInputs(NaiveBayesModel naiveBayesModel) {
        BayesInputs bayesInputs = naiveBayesModel.getBayesInputs();

        if (!bayesInputs.hasExtensions()) {
            return bayesInputs.getBayesInputs();
        }

        List<BayesInput> result = new ArrayList<>(bayesInputs.getBayesInputs());

        for (Extension extension : bayesInputs.getExtensions()) {
            for (Object content : extension.getContent()) {
                if (content instanceof BayesInput) {
                    result.add((BayesInput) content);
                }
            }
        }

        return result;
    }

    private static TargetValueStats getTargetValueStats(BayesInput bayesInput) {
        return bayesInput.getTargetValueStats();
    }

    /**
     * Finds the {@code TargetValueCounts} of the {@code PairCounts} whose value matches {@code value}.
     * When the input implements {@link HasParsedValueMapping} (presumably a pre-parsed subclass; confirm),
     * the lookup is delegated to {@link FieldValue#getMapping}.
     *
     * @return the matching counts, or {@code null} when no {@code PairCounts} matches
     */
    @SuppressWarnings({"rawtypes", "unchecked"})
    private static TargetValueCounts getTargetValueCounts(BayesInput bayesInput, FieldValue value) {
        if (bayesInput instanceof HasParsedValueMapping) {
            return (TargetValueCounts) value.getMapping((HasParsedValueMapping) bayesInput);
        }

        for (PairCounts pairCounts : bayesInput.getPairCounts()) {
            String pairValue = pairCounts.getValue();
            if (pairValue == null) {
                throw new MissingAttributeException(pairCounts, PMMLAttributes.PAIRCOUNTS_VALUE);
            }

            if (value.equalsString(pairValue)) {
                TargetValueCounts targetValueCounts = pairCounts.getTargetValueCounts();
                if (targetValueCounts == null) {
                    throw new MissingElementException(pairCounts, PMMLElements.PAIRCOUNTS_TARGETVALUECOUNTS);
                }

                return targetValueCounts;
            }
        }

        return null;
    }
}
