package org.jpmml.evaluator.tree;

import com.google.common.collect.BiMap;
import com.google.common.collect.ImmutableBiMap;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import org.dmg.pmml.CompoundPredicate;
import org.dmg.pmml.EmbeddedModel;
import org.dmg.pmml.PMML;
import org.dmg.pmml.PMMLAttributes;
import org.dmg.pmml.Predicate;
import org.dmg.pmml.ScoreDistribution;
import org.dmg.pmml.VisitorAction;
import org.dmg.pmml.tree.Node;
import org.dmg.pmml.tree.TreeModel;
import org.jpmml.evaluator.EntityUtil;
import org.jpmml.evaluator.EvaluationContext;
import org.jpmml.evaluator.EvaluationException;
import org.jpmml.evaluator.PMMLUtil;
import org.jpmml.evaluator.PredicateUtil;
import org.jpmml.evaluator.TargetField;
import org.jpmml.evaluator.TargetUtil;
import org.jpmml.evaluator.UndefinedResultException;
import org.jpmml.evaluator.Value;
import org.jpmml.evaluator.ValueFactory;
import org.jpmml.evaluator.ValueMap;
import org.jpmml.model.InvalidAttributeException;
import org.jpmml.model.UnsupportedAttributeException;
import org.jpmml.model.UnsupportedElementException;
import org.jpmml.model.visitors.AbstractVisitor;

/* loaded from: input_file:WEB-INF/lib/pmml-evaluator-1.6.3.jar:org/jpmml/evaluator/tree/ComplexTreeModelEvaluator.class */
public class ComplexTreeModelEvaluator extends TreeModelEvaluator implements HasNodeRegistry {
    private BiMap<String, Node> entityRegistry;

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:WEB-INF/lib/pmml-evaluator-1.6.3.jar:org/jpmml/evaluator/tree/ComplexTreeModelEvaluator$Trail.class */
    /**
     * Mutable bookkeeping object that is threaded through a single tree traversal.
     *
     * <p>It tracks three things: the node most recently passed to {@link #push(Node)},
     * the node chosen as the traversal result (possibly {@code null}), and a counter that
     * callers bump via {@link #addMissingLevel()}.</p>
     */
    public static class Trail {
        /** Node most recently handed to {@link #push(Node)}; {@code null} until the first push. */
        private Node lastPrediction;

        /** Node selected as the outcome of the traversal; {@code null} when nothing is selected. */
        private Node result;

        /** Number of times {@link #addMissingLevel()} has been invoked. */
        private int missingLevels;

        /**
         * Remembers {@code node} as the last prediction, replacing any earlier one.
         *
         * @param node the node to remember
         */
        public void push(Node node) {
            this.lastPrediction = node;
        }

        /**
         * Clears the result.
         *
         * @return this trail, for chaining
         */
        public Trail selectNull() {
            this.result = null;
            return this;
        }

        /**
         * Marks {@code node} as the result.
         *
         * @param node the node to select
         * @return this trail, for chaining
         */
        public Trail selectNode(Node node) {
            this.result = node;
            return this;
        }

        /**
         * Marks the last pushed node as the result.
         *
         * @return this trail, for chaining
         * @throws EvaluationException if nothing has been pushed yet
         */
        public Trail selectLastPrediction() {
            // Goes through the getter on purpose: it rejects an empty trail.
            this.result = getLastPrediction();
            return this;
        }

        /**
         * @return the selected node, or {@code null} if none is selected
         */
        public Node getResult() {
            return this.result;
        }

        /**
         * @return the node most recently pushed
         * @throws EvaluationException if nothing has been pushed yet
         */
        public Node getLastPrediction() {
            Node last = this.lastPrediction;
            if (last != null) {
                return last;
            }
            throw new EvaluationException("Empty trail");
        }

        /** Increments the missing-level counter by one. */
        public void addMissingLevel() {
            this.missingLevels = this.missingLevels + 1;
        }

        /**
         * @return how many times {@link #addMissingLevel()} has been called
         */
        public int getMissingLevels() {
            return this.missingLevels;
        }
    }

    /**
     * Creates an evaluator whose entity registry is empty.
     *
     * <p>NOTE(review): no caller is visible in this file; presumably reserved for
     * framework/serialization use - confirm.</p>
     */
    private ComplexTreeModelEvaluator() {
        super();
        this.entityRegistry = ImmutableBiMap.<String, Node>of();
    }

    /**
     * Creates an evaluator for the {@link TreeModel} that {@link PMMLUtil#findModel} locates
     * inside {@code pmml}, delegating to the two-argument constructor.
     *
     * @param pmml the PMML document to search for a tree model
     */
    public ComplexTreeModelEvaluator(PMML pmml) {
        this(pmml, (TreeModel) PMMLUtil.findModel(pmml, TreeModel.class));
    }

    /**
     * Creates an evaluator for {@code treeModel} and builds an immutable registry of the
     * nodes that {@code collectNodes} gathers from it.
     *
     * @param pmml the PMML document that owns the model
     * @param treeModel the tree model to evaluate
     */
    public ComplexTreeModelEvaluator(PMML pmml, TreeModel treeModel) {
        super(pmml, treeModel);
        // The registry is assigned exactly once; the former placeholder assignment to an
        // empty map was a dead store that was overwritten immediately.
        // NOTE(review): the raw Map cast looks like a decompiler artifact used to pick the
        // copyOf(Map) overload - confirm EntityUtil.buildBiMap's return type before removing it.
        this.entityRegistry = ImmutableBiMap.copyOf((Map) EntityUtil.buildBiMap(collectNodes(treeModel)));
    }

    /**
     * Exposes the node registry held by this evaluator.
     *
     * @return the bidirectional map stored in the {@code entityRegistry} field
     */
    @Override // org.jpmml.evaluator.HasEntityRegistry
    public BiMap<String, Node> getEntityRegistry() {
        BiMap<String, Node> registry = this.entityRegistry;
        return registry;
    }

    /**
     * Looks up the node registered under {@code str} and returns the path from the model's
     * root node to it.
     *
     * @param str registry key of the target node
     * @return the path computed by {@link #getPath(Node)}
     * @throws IllegalArgumentException if no node is registered under {@code str}
     */
    @Override // org.jpmml.evaluator.tree.HasNodeRegistry
    public List<Node> getPath(String str) {
        Node target = resolveNode(str);
        return getPath(target);
    }

    /**
     * Looks up two registered nodes and returns the path that leads from the first to the second.
     *
     * @param str registry key of the node where the search starts
     * @param str2 registry key of the node being searched for
     * @return the path reported by the path finder
     * @throws IllegalArgumentException if either key is not present in the registry
     */
    @Override // org.jpmml.evaluator.tree.HasNodeRegistry
    public List<Node> getPathBetween(String str, String str2) {
        // Resolve in argument order so that an unknown first key is reported first.
        Node start = resolveNode(str);
        Node target = resolveNode(str2);
        return getPathBetween(start, target);
    }

    /**
     * Evaluates the tree for a regression target.
     *
     * <p>When the traversal selects no node, the target's regression default is returned;
     * otherwise the result is built from the selected node's score.</p>
     *
     * @param <V> numeric type produced by {@code valueFactory}
     * @param valueFactory factory for result values
     * @param evaluationContext context the predicates are evaluated against
     * @return the regression result map
     */
    @Override // org.jpmml.evaluator.ModelEvaluator
    protected <V extends Number> Map<String, ?> evaluateRegression(ValueFactory<V> valueFactory, EvaluationContext evaluationContext) {
        TargetField target = getTargetField();
        Node selected = evaluateTree(new Trail(), evaluationContext);
        if (selected != null) {
            NodeScore<V> score = createNodeScore(valueFactory, target, selected);
            return TargetUtil.evaluateRegression(target, score);
        }
        return TargetUtil.evaluateRegressionDefault(valueFactory, target);
    }

    /* JADX WARN: Multi-variable type inference failed */
    /**
     * Evaluates the tree for a classification target.
     *
     * <ul>
     *   <li>No node selected: the target's classification default is returned.</li>
     *   <li>Selected node has no score distributions: a vote result is returned.</li>
     *   <li>Otherwise: a score distribution is returned, with confidences scaled by the model's
     *       missing value penalty raised to the number of missing levels recorded on the trail
     *       (factor {@code 1.0} when no missing level was recorded).</li>
     * </ul>
     *
     * @param <V> numeric type produced by {@code valueFactory}
     * @param valueFactory factory for result values
     * @param evaluationContext context the predicates are evaluated against
     * @return the classification result map
     */
    @Override // org.jpmml.evaluator.ModelEvaluator
    protected <V extends Number> Map<String, ?> evaluateClassification(ValueFactory<V> valueFactory, EvaluationContext evaluationContext) {
        TreeModel model = (TreeModel) getModel();
        TargetField target = getTargetField();
        Trail path = new Trail();
        Node selected = evaluateTree(path, evaluationContext);
        if (selected == null) {
            return TargetUtil.evaluateClassificationDefault(valueFactory, target);
        }
        if (selected.hasScoreDistributions()) {
            int missing = path.getMissingLevels();
            double penalty;
            if (missing > 0) {
                double base = model.getMissingValuePenalty().doubleValue();
                // A single missing level uses the penalty as-is; more levels compound it.
                penalty = (missing > 1) ? Math.pow(base, missing) : base;
            } else {
                penalty = 1.0d;
            }
            return TargetUtil.evaluateClassification(target, createNodeScoreDistribution(valueFactory, selected, penalty));
        }
        return TargetUtil.evaluateVote(target, createNodeVote(selected));
    }

    /* JADX WARN: Multi-variable type inference failed */
    /**
     * Runs the traversal starting at the model's root node.
     *
     * @param trail bookkeeping object updated during the traversal
     * @param evaluationContext context the predicates are evaluated against
     * @return the node selected by the traversal, or {@code null} when the root predicate is
     *         not {@code true} (false or undetermined) or when the traversal selects nothing
     */
    private Node evaluateTree(Trail trail, EvaluationContext evaluationContext) {
        TreeModel model = (TreeModel) getModel();
        Node root = model.requireNode();
        Boolean status = evaluateNode(trail, root, evaluationContext);
        if (status != null && status.booleanValue()) {
            Trail finished = handleTrue(trail, root, evaluationContext);
            return finished.getResult();
        }
        return null;
    }

    /**
     * Evaluates the predicate of a single node.
     *
     * <p>Compound predicates go through
     * {@code PredicateUtil.evaluateCompoundPredicateInternal}; if that result reports
     * {@code isAlternative()}, one missing level is recorded on the trail.
     * NOTE(review): presumably this flags that a fallback (surrogate) sub-predicate decided
     * the outcome - confirm against PredicateUtil.</p>
     *
     * @param trail bookkeeping object that may receive a missing level
     * @param node node whose predicate is evaluated
     * @param evaluationContext context the predicate is evaluated against
     * @return the predicate outcome; callers treat {@code null} as "undetermined"
     * @throws UnsupportedElementException if the node carries an embedded model
     */
    private Boolean evaluateNode(Trail trail, Node node, EvaluationContext evaluationContext) {
        EmbeddedModel embedded = node.getEmbeddedModel();
        if (embedded != null) {
            throw new UnsupportedElementException(embedded);
        }
        Predicate predicate = node.requirePredicate();
        if (predicate instanceof CompoundPredicate) {
            CompoundPredicate compound = (CompoundPredicate) predicate;
            PredicateUtil.CompoundPredicateResult outcome = PredicateUtil.evaluateCompoundPredicateInternal(compound, evaluationContext);
            if (outcome.isAlternative()) {
                trail.addMissingLevel();
            }
            return outcome.getResult();
        }
        return PredicateUtil.evaluate(predicate, evaluationContext);
    }

    /**
     * Continues the traversal below a node whose predicate evaluated to {@code true}.
     *
     * <p>A node without children is selected directly. Otherwise the node is pushed onto the
     * trail and its children are tried in order: the first child evaluating to {@code true}
     * is descended into; a child with an undetermined outcome is passed to
     * {@code handleMissingValue}, whose non-{@code null} answer ends the traversal. When no
     * child settles the matter, {@code handleNoTrueChild} decides.</p>
     *
     * @param trail bookkeeping object updated during the traversal
     * @param node the node that evaluated to {@code true}
     * @param evaluationContext context the predicates are evaluated against
     * @return the trail with its result set
     */
    private Trail handleTrue(Trail trail, Node node, EvaluationContext evaluationContext) {
        if (node.hasNodes()) {
            trail.push(node);
            for (Node child : node.getNodes()) {
                Boolean status = evaluateNode(trail, child, evaluationContext);
                if (status == null) {
                    Trail resolved = handleMissingValue(trail, node, child, evaluationContext);
                    if (resolved != null) {
                        return resolved;
                    }
                    // Strategy declined to decide: move on to the next sibling.
                    continue;
                }
                if (status.booleanValue()) {
                    return handleTrue(trail, child, evaluationContext);
                }
            }
            return handleNoTrueChild(trail);
        }
        return trail.selectNode(node);
    }

    /**
     * Descends into the default child of {@code node}, counting one missing level.
     *
     * @param trail bookkeeping object updated during the traversal
     * @param node parent node whose default child is followed
     * @param evaluationContext context the predicates are evaluated against
     * @return the trail with its result set
     */
    private Trail handleDefaultChild(Trail trail, Node node, EvaluationContext evaluationContext) {
        // Look the child up first; the missing level is only counted once that succeeded.
        Node fallback = findDefaultChild(node);
        trail.addMissingLevel();
        return handleTrue(trail, fallback, evaluationContext);
    }

    /* JADX WARN: Multi-variable type inference failed */
    /**
     * Applies the model's no-true-child strategy after none of a node's children matched.
     *
     * <ul>
     *   <li>{@code RETURN_NULL_PREDICTION}: no node is selected.</li>
     *   <li>{@code RETURN_LAST_PREDICTION}: the last pushed node is selected if it has a score,
     *       otherwise no node is selected.</li>
     * </ul>
     *
     * @param trail bookkeeping object whose result is set
     * @return the trail with its result set
     * @throws UnsupportedAttributeException for any other strategy value
     */
    private Trail handleNoTrueChild(Trail trail) {
        TreeModel model = (TreeModel) getModel();
        TreeModel.NoTrueChildStrategy strategy = model.getNoTrueChildStrategy();
        switch (strategy) {
            case RETURN_LAST_PREDICTION:
                if (trail.getLastPrediction().hasScore()) {
                    return trail.selectLastPrediction();
                }
                return trail.selectNull();
            case RETURN_NULL_PREDICTION:
                return trail.selectNull();
            default:
                throw new UnsupportedAttributeException(model, strategy);
        }
    }

    /* JADX WARN: Multi-variable type inference failed */
    /**
     * Applies the model's missing-value strategy after a child predicate could not be determined.
     *
     * <ul>
     *   <li>{@code NONE}: returns {@code null}.</li>
     *   <li>{@code DEFAULT_CHILD}: continues in the default child of {@code node}.</li>
     *   <li>{@code LAST_PREDICTION}: selects the last pushed node.</li>
     *   <li>{@code NULL_PREDICTION}: selects no node.</li>
     * </ul>
     *
     * @param trail bookkeeping object updated during the traversal
     * @param node parent node being processed
     * @param node2 child whose predicate was undetermined (not read by this method)
     * @param evaluationContext context the predicates are evaluated against
     * @return the trail with its result set, or {@code null} for the {@code NONE} strategy
     * @throws UnsupportedAttributeException for any other strategy value
     */
    private Trail handleMissingValue(Trail trail, Node node, Node node2, EvaluationContext evaluationContext) {
        TreeModel model = (TreeModel) getModel();
        TreeModel.MissingValueStrategy strategy = model.getMissingValueStrategy();
        switch (strategy) {
            case NONE:
                return null;
            case DEFAULT_CHILD:
                return handleDefaultChild(trail, node, evaluationContext);
            case LAST_PREDICTION:
                return trail.selectLastPrediction();
            case NULL_PREDICTION:
                return trail.selectNull();
            default:
                throw new UnsupportedAttributeException(model, strategy);
        }
    }

    /**
     * Wraps the score of {@code node} into a {@link NodeScore} bound to this evaluator.
     *
     * <p>The node's score is turned into a value (from a {@link Number} directly, otherwise
     * from its {@link String} form) and passed through
     * {@code TargetUtil.evaluateRegressionInternal}. The returned object answers registry and
     * decision-path queries by delegating to this evaluator.</p>
     *
     * @param <V> numeric type produced by {@code valueFactory}
     * @param valueFactory factory for the score value
     * @param targetField target field the score belongs to
     * @param node node supplying the score
     * @return the node score
     */
    private <V extends Number> NodeScore<V> createNodeScore(ValueFactory<V> valueFactory, TargetField targetField, Node node) {
        Object score = node.requireScore();
        Value<V> value;
        if (score instanceof Number) {
            value = valueFactory.newValue((Number) score);
        } else {
            value = valueFactory.newValue((String) score);
        }
        return new NodeScore<V>(TargetUtil.evaluateRegressionInternal(targetField, value), node) { // from class: org.jpmml.evaluator.tree.ComplexTreeModelEvaluator.1
            @Override // org.jpmml.evaluator.HasEntityRegistry
            public BiMap<String, Node> getEntityRegistry() {
                return ComplexTreeModelEvaluator.this.getEntityRegistry();
            }

            @Override // org.jpmml.evaluator.tree.HasDecisionPath
            public List<Node> getDecisionPath() {
                return ComplexTreeModelEvaluator.this.getPath(getNode());
            }
        };
    }

    /**
     * Wraps {@code node} into a {@link NodeVote} bound to this evaluator.
     *
     * <p>The returned object answers registry and decision-path queries by delegating to
     * this evaluator.</p>
     *
     * @param node the selected node
     * @return the node vote
     */
    private NodeVote createNodeVote(Node node) {
        NodeVote vote = new NodeVote(node) { // from class: org.jpmml.evaluator.tree.ComplexTreeModelEvaluator.2
            @Override // org.jpmml.evaluator.HasEntityRegistry
            public BiMap<String, Node> getEntityRegistry() {
                return ComplexTreeModelEvaluator.this.getEntityRegistry();
            }

            @Override // org.jpmml.evaluator.tree.HasDecisionPath
            public List<Node> getDecisionPath() {
                return ComplexTreeModelEvaluator.this.getPath(getNode());
            }
        };
        return vote;
    }

    /**
     * Builds the class distribution of {@code node} from its {@link ScoreDistribution} elements.
     *
     * <p>Whether probabilities or record counts are used is decided once, by checking if the
     * <em>first</em> score distribution declares a probability; the same attribute is then
     * required on every element. The collected values are divided by their sum unless that
     * sum is already one. Confidences, where present, are multiplied by {@code d}.</p>
     *
     * <p>NOTE(review): {@code add2}/{@code multiply2} look like decompiler-assigned names for
     * {@code Value} arithmetic methods - confirm against the {@code Value} class in this code base.</p>
     *
     * @param <V> numeric type produced by {@code valueFactory}
     * @param valueFactory factory for the distribution values
     * @param node node supplying the score distributions
     * @param d factor applied to every confidence
     * @return the populated distribution, bound to this evaluator for registry and path queries
     * @throws InvalidAttributeException if a probability lies outside {@code [0, 1]}
     * @throws UndefinedResultException if the values sum to zero
     */
    private <V extends Number> NodeScoreDistribution<V> createNodeScoreDistribution(ValueFactory<V> valueFactory, Node node, double d) {
        List<ScoreDistribution> scoreDistributions = node.getScoreDistributions();
        int size = scoreDistributions.size();
        // Properly parameterized map (was a raw ValueMap) sized at twice the element count.
        NodeScoreDistribution<V> result = new NodeScoreDistribution<V>(new ValueMap<Object, V>(2 * size), node) { // from class: org.jpmml.evaluator.tree.ComplexTreeModelEvaluator.3
            @Override // org.jpmml.evaluator.HasEntityRegistry
            public BiMap<String, Node> getEntityRegistry() {
                return ComplexTreeModelEvaluator.this.getEntityRegistry();
            }

            @Override // org.jpmml.evaluator.tree.HasDecisionPath
            public List<Node> getDecisionPath() {
                return ComplexTreeModelEvaluator.this.getPath(getNode());
            }
        };
        Value<V> sum = valueFactory.newValue();
        boolean useProbabilities = false;
        for (int i = 0; i < size; i++) {
            ScoreDistribution scoreDistribution = scoreDistributions.get(i);
            if (i == 0) {
                // The first element fixes the mode for the whole node.
                useProbabilities = scoreDistribution.getProbability() != null;
            }
            Value<V> value;
            if (useProbabilities) {
                Number probability = scoreDistribution.requireProbability();
                if (probability.doubleValue() < 0.0d || probability.doubleValue() > 1.0d) {
                    throw new InvalidAttributeException(scoreDistribution, PMMLAttributes.SCOREDISTRIBUTION_PROBABILITY, probability);
                }
                sum.add2(probability);
                value = valueFactory.newValue(probability);
            } else {
                Number recordCount = scoreDistribution.requireRecordCount();
                sum.add2(recordCount);
                value = valueFactory.newValue(recordCount);
            }
            Object category = scoreDistribution.requireValue();
            result.put(category, value);
            Number confidence = scoreDistribution.getConfidence();
            if (confidence != null) {
                result.putConfidence(category, valueFactory.newValue(confidence).multiply2(Double.valueOf(d)));
            }
        }
        if (!sum.isOne()) {
            ValueMap<Object, V> values = result.getValues();
            // A zero total cannot be normalized.
            if (sum.isZero()) {
                throw new UndefinedResultException();
            }
            Iterator<Value<V>> it = values.iterator();
            while (it.hasNext()) {
                it.next().divide((Value<? extends Number>) sum);
            }
        }
        return result;
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* JADX WARN: Multi-variable type inference failed */
    /**
     * Returns the path from the model's root node to {@code node}.
     *
     * @param node the node to locate
     * @return the path reported by {@link #getPathBetween(Node, Node)}
     */
    public List<Node> getPath(Node node) {
        TreeModel model = (TreeModel) getModel();
        Node root = model.requireNode();
        return getPathBetween(root, node);
    }

    /**
     * Searches from {@code node} for {@code node2} using a {@code PathFinder}.
     *
     * @param node node the finder is applied to
     * @param node2 node being searched for; candidates are matched with {@link Objects#equals}
     * @return whatever path the finder reports after the search
     */
    private List<Node> getPathBetween(Node node, final Node node2) {
        PathFinder finder = new PathFinder() { // from class: org.jpmml.evaluator.tree.ComplexTreeModelEvaluator.4
            @Override // java.util.function.Predicate
            public boolean test(Node candidate) {
                return Objects.equals(node2, candidate);
            }
        };
        finder.applyTo(node);
        List<Node> path = finder.getPath();
        return path;
    }

    /**
     * Looks up a node in the entity registry.
     *
     * @param str registry key of the node
     * @return the registered node, never {@code null}
     * @throws IllegalArgumentException if the registry has no entry for {@code str}; the key
     *         is used as the exception message
     */
    private Node resolveNode(String str) {
        Node resolved = getEntityRegistry().get(str);
        if (resolved != null) {
            return resolved;
        }
        throw new IllegalArgumentException(str);
    }

    /**
     * Collects every {@link Node} that a visitor encounters while being applied to {@code treeModel}.
     *
     * @param treeModel the model to walk
     * @return the nodes, in the order the visitor reached them
     */
    private static List<Node> collectNodes(TreeModel treeModel) {
        // Typed list instead of a raw ArrayList, so additions are checked at compile time.
        final List<Node> nodes = new ArrayList<>();
        new AbstractVisitor() { // from class: org.jpmml.evaluator.tree.ComplexTreeModelEvaluator.5
            @Override // org.jpmml.model.visitors.AbstractVisitor, org.dmg.pmml.Visitor
            public VisitorAction visit(Node node) {
                nodes.add(node);
                // Delegate to the default handling so the walk continues.
                return super.visit(node);
            }
        }.applyTo(treeModel);
        return nodes;
    }
}
