package sklearn2pmml.tree;

import chaid.Column;
import chaid.Split;
import com.google.common.math.DoubleMath;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Comparator;
import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.TreeMap;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.dmg.pmml.CompoundPredicate;
import org.dmg.pmml.False;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.Predicate;
import org.dmg.pmml.ScoreFrequency;
import org.dmg.pmml.SimplePredicate;
import org.dmg.pmml.True;
import org.dmg.pmml.tree.ClassifierNode;
import org.dmg.pmml.tree.CountingBranchNode;
import org.dmg.pmml.tree.CountingLeafNode;
import org.dmg.pmml.tree.Node;
import org.dmg.pmml.tree.TreeModel;
import org.jpmml.converter.CategoricalFeature;
import org.jpmml.converter.CategoricalLabel;
import org.jpmml.converter.ContinuousLabel;
import org.jpmml.converter.Label;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.PredicateManager;
import org.jpmml.converter.Schema;
import org.jpmml.python.ClassDictUtil;
import treelib.Tree;

/* loaded from: input_file:sklearn2pmml/tree/CHAIDUtil.class */
/**
 * Converts a CHAID decision tree into a PMML {@link TreeModel}.
 *
 * <p>The tree is a {@link Tree} whose {@link treelib.Node} nodes carry {@link chaid.Node} tags
 * holding the records, target values and split definition of each node.
 */
public class CHAIDUtil {

    private CHAIDUtil() {
    }

    /**
     * Encodes a CHAID tree as a PMML tree model.
     *
     * <p>Encoding starts at {@code tree.selectRoot()}, whose predicate is {@link True#INSTANCE}.
     *
     * @param miningFunction the mining function of the resulting model
     * @param tree the tree to encode
     * @param schema the schema supplying the target label and the split features
     * @return the encoded tree model
     */
    public static TreeModel encodeModel(MiningFunction miningFunction, Tree tree, Schema schema) {
        return new TreeModel(miningFunction, ModelUtil.createMiningSchema(schema.getLabel()),
            encodeNode(True.INSTANCE, tree.selectRoot(), tree, new PredicateManager(), schema));
    }

    /**
     * Recursively encodes a tree node together with all of its successors.
     *
     * <p>For a categorical label every node is a {@link ClassifierNode}; otherwise a node without
     * successors becomes a {@link CountingLeafNode} and a node with successors becomes a
     * {@link CountingBranchNode}. Each successor gets a predicate built from this node's split
     * groups. Feature categories that no split group claims are routed to the successor holding
     * the most records.
     *
     * @param predicate the predicate that routes records into this node
     * @param node the tree node to encode
     * @param tree the tree that {@code node} belongs to
     * @param predicateManager the factory used to create (and share) predicate instances
     * @param schema the schema supplying the label and the split features
     * @return the encoded PMML node
     * @throws IllegalArgumentException if the label is neither continuous nor categorical, or if a
     *     split value is not a category of the split feature
     */
    private static Node encodeNode(Predicate predicate, treelib.Node node, Tree tree, PredicateManager predicateManager, Schema schema) {
        // Must be the general Label type: the encoding below dispatches on the concrete subtype
        Label label = schema.getLabel();

        chaid.Node chaidNode = (chaid.Node) node.getTag(chaid.Node.class);

        List<?> children = node.selectSuccessors(tree);

        Column depV = chaidNode.getDepV();
        List<Integer> indices = chaidNode.getIndices();
        Split split = chaidNode.getSplit();

        // Target values of the records at this node; one per record index
        List<? extends Number> targetValues = depV.getArr();
        ClassDictUtil.checkSize(new Collection<?>[]{targetValues, indices});

        Integer columnId = split.getColumnId();
        List<List<Integer>> splits = split.getSplits();
        List<List<Object>> splitMap = split.getSplitMap();
        // One split group (and one split map entry) per child
        ClassDictUtil.checkSize(new Collection<?>[]{children, splits, splitMap});

        boolean categorical = (label instanceof CategoricalLabel);

        Node result;
        if (children.isEmpty()) {
            if (categorical) {
                result = new ClassifierNode((Object) null, predicate);
            } else {
                result = new CountingLeafNode((Object) null, predicate);
            }
        } else {
            if (categorical) {
                result = new ClassifierNode((Object) null, predicate);
            } else {
                result = new CountingBranchNode((Object) null, predicate);
            }

            CategoricalFeature feature = (CategoricalFeature) schema.getFeature(columnId.intValue());
            List<?> values = feature.getValues();

            // Feature categories that are not claimed by any split group
            Collection<Object> unassigned = new LinkedHashSet<>(values);
            for (int i = 0; i < children.size(); i++) {
                List<Integer> splitIndices = splits.get(i);
                List<Object> splitValues = splitMap.get(i);
                ClassDictUtil.checkSize(new Collection<?>[]{splitIndices, splitValues});

                for (int j = 0; j < splitIndices.size(); j++) {
                    Object splitValue = splitValues.get(j);
                    if (!isMissing(splitIndices.get(j), splitValue)) {
                        removeCategory(unassigned, splitValue);
                    }
                }
            }

            // Unclaimed categories go to the child holding the most records
            treelib.Node defaultChild = unassigned.isEmpty() ? null : selectLargestChild(children);

            for (int i = 0; i < children.size(); i++) {
                treelib.Node child = (treelib.Node) children.get(i);
                List<Integer> splitIndices = splits.get(i);
                List<Object> splitValues = splitMap.get(i);

                List<Object> categories = new ArrayList<>();
                boolean withMissing = false;
                for (int j = 0; j < splitIndices.size(); j++) {
                    Object splitValue = splitValues.get(j);
                    if (isMissing(splitIndices.get(j), splitValue)) {
                        withMissing = true;
                    } else {
                        // Use the feature's own value object; the split value may be a different numeric type
                        categories.add(selectCategory(values, splitValue));
                    }
                }
                if (Objects.equals(child, defaultChild)) {
                    categories.addAll(unassigned);
                }

                Predicate childPredicate = encodeSplitPredicate(feature, categories, withMissing, predicateManager);
                result.addNodes(encodeNode(childPredicate, child, tree, predicateManager, schema));
            }
        }

        result.setId(node.getIdentifier());
        result.setRecordCount(Integer.valueOf(targetValues.size()));

        encodeScore(result, label, targetValues);

        return result;
    }

    /**
     * Builds the predicate that routes records to the child of one split group.
     *
     * @param feature the split feature
     * @param categories the feature categories routed to the child (possibly empty)
     * @param withMissing whether missing values are routed to the child as well
     * @param predicateManager the predicate factory
     * @return {@link False#INSTANCE} when nothing is routed to the child, an isMissing predicate when
     *     only missing values are, otherwise a category predicate (combined with isMissing as a
     *     SURROGATE when {@code withMissing} is set)
     */
    private static Predicate encodeSplitPredicate(CategoricalFeature feature, List<Object> categories, boolean withMissing, PredicateManager predicateManager) {
        if (categories.isEmpty()) {
            if (withMissing) {
                return predicateManager.createSimplePredicate(feature, SimplePredicate.Operator.IS_MISSING, (Object) null);
            }
            return False.INSTANCE;
        }

        Predicate result = predicateManager.createPredicate(feature, categories);
        if (withMissing) {
            // PMML SURROGATE: the isMissing check only takes over when the category test is UNKNOWN
            result = predicateManager.createCompoundPredicate(CompoundPredicate.BooleanOperator.SURROGATE,
                new Predicate[]{result, predicateManager.createSimplePredicate(feature, SimplePredicate.Operator.IS_MISSING, (Object) null)});
        }
        return result;
    }

    /**
     * Selects the child with the largest record index list. On a tie, the later child wins.
     *
     * @param children the tree nodes to choose from
     * @return the selected child, or {@code null} when {@code children} is empty
     */
    private static treelib.Node selectLargestChild(List<?> children) {
        treelib.Node result = null;
        int resultSize = 0;

        for (Object child : children) {
            treelib.Node candidate = (treelib.Node) child;
            int size = countRecords(candidate);

            if (result == null || size >= resultSize) {
                result = candidate;
                resultSize = size;
            }
        }

        return result;
    }

    /**
     * Returns the number of record indices stored in the {@link chaid.Node} tag of a tree node.
     */
    private static int countRecords(treelib.Node node) {
        chaid.Node chaidNode = (chaid.Node) node.getTag(chaid.Node.class);

        return chaidNode.getIndices().size();
    }

    /**
     * Sets the score (and, for classification, the score distributions) of a node.
     *
     * <p>For a continuous label the score is the mean target value. For a categorical label the
     * target values are treated as class indices into the label's values. The node gets one
     * {@link ScoreFrequency} per observed class, in class index order, and its score is the most
     * frequent class (ties go to the lowest class index).
     *
     * @throws IllegalArgumentException if the label is neither continuous nor categorical, or a
     *     categorical target value is not an integer
     */
    private static void encodeScore(Node node, Label label, List<? extends Number> targetValues) {

        if (label instanceof ContinuousLabel) {
            node.setScore(Double.valueOf(DoubleMath.mean(targetValues)));
        } else if (label instanceof CategoricalLabel) {
            CategoricalLabel categoricalLabel = (CategoricalLabel) label;

            Map<Integer, Long> counts = new TreeMap<>();
            for (Number targetValue : targetValues) {
                counts.merge(toClassIndex(targetValue), 1L, Long::sum);
            }

            Long maxCount = null;
            for (Map.Entry<Integer, Long> entry : counts.entrySet()) {
                Object value = categoricalLabel.getValue(entry.getKey());
                Long count = entry.getValue();

                // Strictly greater: the first (lowest index) class keeps a tie
                if (maxCount == null || maxCount.compareTo(count) < 0) {
                    maxCount = count;
                    node.setScore(value);
                }

                node.getScoreDistributions().add(new ScoreFrequency(value, count));
            }
        } else {
            throw new IllegalArgumentException("Expected a continuous or categorical label, got " + label);
        }
    }

    /**
     * Converts a target value to a class index, rejecting non-integral values.
     */
    private static int toClassIndex(Number value) {
        int index = value.intValue();

        if (index != value.doubleValue()) {
            throw new IllegalArgumentException("Expected an integer class index, got " + value);
        }

        return index;
    }

    /**
     * Returns true when a split entry denotes missing values: a split index of -1 or a
     * {@code null} split value.
     */
    private static boolean isMissing(Integer num, Object obj) {
        return num.intValue() == -1 || obj == null;
    }

    /**
     * Removes the first element of {@code collection} that matches {@code obj} under
     * {@link #isSameCategory(Object, Object)}.
     *
     * @throws IllegalArgumentException if no element matches
     */
    private static void removeCategory(Collection<?> collection, Object obj) {

        for (Iterator<?> it = collection.iterator(); it.hasNext(); ) {
            if (isSameCategory(it.next(), obj)) {
                it.remove();

                return;
            }
        }

        throw new IllegalArgumentException("Split value " + obj + " is not a (remaining) category of the split feature");
    }

    /**
     * Returns the first element of {@code collection} that matches {@code obj} under
     * {@link #isSameCategory(Object, Object)}.
     *
     * @throws IllegalArgumentException if no element matches
     */
    private static Object selectCategory(Collection<?> collection, Object obj) {

        for (Object category : collection) {
            if (isSameCategory(category, obj)) {
                return category;
            }
        }

        throw new IllegalArgumentException("Split value " + obj + " is not a category of the split feature");
    }

    /**
     * Compares two category values. Two numbers are compared by their {@code double} value (so
     * {@code 1} matches {@code 1.0}); anything else is compared with {@link Objects#equals}.
     */
    private static boolean isSameCategory(Object left, Object right) {

        if ((left instanceof Number) && (right instanceof Number)) {
            return Double.compare(((Number) left).doubleValue(), ((Number) right).doubleValue()) == 0;
        }

        return Objects.equals(left, right);
    }
}
