package org.jpmml.xgboost;

import com.devsmart.ubjson.GsonUtil;
import com.devsmart.ubjson.UBObject;
import com.devsmart.ubjson.UBValueFactory;
import com.google.common.primitives.Ints;
import com.google.gson.JsonObject;
import java.io.IOException;
import java.util.ArrayList;
import java.util.BitSet;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
import org.dmg.pmml.DataType;
import org.dmg.pmml.MathContext;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.Predicate;
import org.dmg.pmml.SimplePredicate;
import org.dmg.pmml.True;
import org.dmg.pmml.tree.BranchNode;
import org.dmg.pmml.tree.LeafNode;
import org.dmg.pmml.tree.TreeModel;
import org.jpmml.converter.BinaryFeature;
import org.jpmml.converter.CategoricalFeature;
import org.jpmml.converter.CategoryManager;
import org.jpmml.converter.ContinuousFeature;
import org.jpmml.converter.Feature;
import org.jpmml.converter.MissingValueFeature;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.PredicateManager;
import org.jpmml.converter.Schema;
import org.jpmml.converter.ThresholdFeature;
import org.jpmml.converter.ThresholdFeatureUtil;
import org.jpmml.converter.ValueUtil;

/**
 * A single XGBoost regression tree.
 *
 * <p>The tree can be populated from XGBoost's legacy binary format ({@link #loadBinary}),
 * from JSON ({@link #loadJSON}) or from UBJSON ({@link #loadUBJSON}), and converted into a
 * PMML {@link TreeModel} via {@link #encodeTreeModel}.
 */
public class RegTree implements BinaryLoadable, JSONLoadable, UBJSONLoadable {
    // Header fields. num_nodes, num_deleted, num_feature and size_leaf_vector mirror the keys
    // of the JSON/UBJSON "tree_param" object; num_roots and max_depth are only set by loadBinary.
    private int num_roots;
    private int num_nodes;
    private int num_deleted;
    private int max_depth;
    private int num_feature;
    private int size_leaf_vector;

    // Nodes indexed by node id; encoding starts from node 0 (the root).
    private Node[] nodes;

    // Per-node statistics; only populated by loadBinary.
    private NodeStat[] stats;

    /**
     * Reads the tree from XGBoost's legacy binary format: six int header fields, a reserved
     * area (skipped via {@code readReserved(31)}), then {@code num_nodes} nodes followed by
     * {@code num_nodes} node statistics.
     *
     * @param input the binary model stream, positioned at the start of the tree
     * @throws IOException if reading from {@code input} fails
     */
    @Override
    public void loadBinary(XGBoostDataInput input) throws IOException {
        this.num_roots = input.readInt();
        this.num_nodes = input.readInt();
        this.num_deleted = input.readInt();
        this.max_depth = input.readInt();
        this.num_feature = input.readInt();
        this.size_leaf_vector = input.readInt();

        input.readReserved(31);

        this.nodes = (Node[]) input.readObjectArray(BinaryNode.class, this.num_nodes);
        this.stats = (NodeStat[]) input.readObjectArray(BinaryNodeStat.class, this.num_nodes);
    }

    /**
     * Reads the tree from its JSON representation by converting it to UBJSON and delegating
     * to {@link #loadUBJSON(UBObject)}.
     *
     * @param jsonObject the JSON object describing one tree
     */
    @Override
    public void loadJSON(JsonObject jsonObject) {
        loadUBJSON(GsonUtil.toUBValue(jsonObject).asObject());
    }

    /**
     * Reads the tree from its UBJSON representation.
     *
     * <p>Node attributes are stored column-wise (one array per attribute, indexed by node id).
     * Each row is re-packed into a per-node object and loaded into a {@link JSONNode}. If any
     * node has a categorical split type, the category sets are attached afterwards.
     *
     * @param tree the UBJSON object describing one tree
     * @throws IllegalArgumentException if the categorical split data is malformed
     */
    @Override
    public void loadUBJSON(UBObject tree) {
        UBObject treeParam = tree.get("tree_param").asObject();

        this.num_nodes = treeParam.get("num_nodes").asInt();
        this.num_deleted = treeParam.get("num_deleted").asInt();
        this.num_feature = treeParam.get("num_feature").asInt();
        this.size_leaf_vector = treeParam.get("size_leaf_vector").asInt();

        int[] parents = UBJSONUtil.toIntArray(tree.get("parents"));
        int[] leftChildren = UBJSONUtil.toIntArray(tree.get("left_children"));
        int[] rightChildren = UBJSONUtil.toIntArray(tree.get("right_children"));
        boolean[] defaultLeft = UBJSONUtil.toBooleanArray(tree.get("default_left"));
        int[] splitIndices = UBJSONUtil.toIntArray(tree.get("split_indices"));
        int[] splitTypes = UBJSONUtil.toIntArray(tree.get("split_type"));
        float[] splitConditions = UBJSONUtil.toFloatArray(tree.get("split_conditions"));

        this.nodes = new Node[this.num_nodes];
        for (int i = 0; i < this.num_nodes; i++) {
            // Re-pack column-wise row i into the per-node object layout JSONNode reads
            UBObject nodeObject = UBValueFactory.createObject();
            nodeObject.put("parent", UBValueFactory.createInt(parents[i]));
            nodeObject.put("left_child", UBValueFactory.createInt(leftChildren[i]));
            nodeObject.put("right_child", UBValueFactory.createInt(rightChildren[i]));
            nodeObject.put("default_left", UBValueFactory.createBool(defaultLeft[i]));
            nodeObject.put("split_index", UBValueFactory.createInt(splitIndices[i]));
            nodeObject.put("split_type", UBValueFactory.createInt(splitTypes[i]));
            nodeObject.put("split_condition", UBValueFactory.createFloat32(splitConditions[i]));

            JSONNode jsonNode = new JSONNode();
            jsonNode.loadUBJSON(nodeObject);
            this.nodes[i] = jsonNode;
        }

        if (Ints.contains(splitTypes, Node.SPLIT_CATEGORICAL)) {
            loadSplitCategories(tree);
        }
    }

    /**
     * Attaches category bitsets to the categorical split nodes of a tree just loaded by
     * {@link #loadUBJSON}. Nodes not listed in {@code categories_nodes} get a {@code null} bitset.
     *
     * <p>Entry {@code j} describes node {@code categories_nodes[j]}, whose categories are
     * {@code categories[categories_segments[j] .. categories_segments[j] + categories_sizes[j])}.
     * Entries are processed independently, so their order does not matter.
     *
     * @throws IllegalArgumentException if a node index is out of range or a category segment is
     *     empty or contains a negative category index
     */
    private void loadSplitCategories(UBObject tree) {
        int[] segments = UBJSONUtil.toIntArray(tree.get("categories_segments"));
        int[] sizes = UBJSONUtil.toIntArray(tree.get("categories_sizes"));
        int[] categoryNodes = UBJSONUtil.toIntArray(tree.get("categories_nodes"));
        int[] categories = UBJSONUtil.toIntArray(tree.get("categories"));

        for (int i = 0; i < this.num_nodes; i++) {
            ((JSONNode) this.nodes[i]).set_split_categories(null);
        }

        for (int j = 0; j < categoryNodes.length; j++) {
            int nodeIndex = categoryNodes[j];
            if (nodeIndex < 0 || nodeIndex >= this.num_nodes) {
                throw new IllegalArgumentException("Categorical split node index " + nodeIndex + " is out of range [0, " + this.num_nodes + ")");
            }

            ((JSONNode) this.nodes[nodeIndex]).set_split_categories(toCategorySet(categories, segments[j], sizes[j], nodeIndex));
        }
    }

    /**
     * Builds the set of category indices stored in {@code categories[offset, offset + size)}.
     *
     * @throws IllegalArgumentException if the segment is empty or holds a negative index
     */
    private static BitSet toCategorySet(int[] categories, int offset, int size, int nodeIndex) {
        if (size <= 0) {
            throw new IllegalArgumentException("Categorical split node " + nodeIndex + " has no categories");
        }

        BitSet result = new BitSet();
        for (int k = offset; k < offset + size; k++) {
            int category = categories[k];
            if (category < 0) {
                throw new IllegalArgumentException("Categorical split node " + nodeIndex + " has a negative category index " + category);
            }
            result.set(category);
        }
        return result;
    }

    /**
     * Returns the value of a single-leaf tree.
     *
     * @return the root's leaf value if the root is a leaf, otherwise {@code null}
     */
    public Float getLeafValue() {
        Node root = this.nodes[0];
        return root.is_leaf() ? Float.valueOf(root.leaf_value()) : null;
    }

    /**
     * Tests whether any internal (non-leaf) node uses a categorical split.
     *
     * @return {@code true} if at least one split node has split type {@link Node#SPLIT_CATEGORICAL}
     */
    public boolean hasCategoricalSplits() {
        for (int i = 0; i < this.num_nodes; i++) {
            Node node = this.nodes[i];
            if (!node.is_leaf() && node.split_type() == Node.SPLIT_CATEGORICAL) {
                return true;
            }
        }
        return false;
    }

    /**
     * Collects the split types used by internal nodes that split on the given feature.
     *
     * @param splitIndex the feature index
     * @return a mutable set of split types; empty if no node splits on {@code splitIndex}
     */
    public Set<Integer> getSplitType(int splitIndex) {
        Set<Integer> result = new HashSet<>();
        for (int i = 0; i < this.num_nodes; i++) {
            Node node = this.nodes[i];
            if (!node.is_leaf() && node.split_index() == splitIndex) {
                result.add(Integer.valueOf(node.split_type()));
            }
        }
        return result;
    }

    /**
     * Computes the union of all category sets used by internal nodes that split on the given feature.
     *
     * @param splitIndex the feature index
     * @return a new bitset with the union, or {@code null} if no such node carries a category set
     */
    public BitSet getSplitCategories(int splitIndex) {
        BitSet result = null;
        for (int i = 0; i < this.num_nodes; i++) {
            Node node = this.nodes[i];
            if (node.is_leaf() || node.split_index() != splitIndex) {
                continue;
            }

            BitSet categories = node.get_split_categories();
            if (categories == null) {
                continue;
            }

            if (result == null) {
                result = new BitSet();
            }
            result.or(categories);
        }
        return result;
    }

    /**
     * Encodes this tree as a PMML regression {@link TreeModel} with binary splits,
     * default-child missing value handling and a float math context.
     *
     * @param predicateManager shared predicate factory
     * @param schema the schema whose features are referenced by node split indices
     * @return the encoded tree model
     */
    public TreeModel encodeTreeModel(PredicateManager predicateManager, Schema schema) {
        return new TreeModel(MiningFunction.REGRESSION, ModelUtil.createMiningSchema(schema.getLabel()), encodeNode(0, True.INSTANCE, new CategoryManager(), predicateManager, schema))
            .setSplitCharacteristic(TreeModel.SplitCharacteristic.BINARY_SPLIT)
            .setMissingValueStrategy(TreeModel.MissingValueStrategy.DEFAULT_CHILD)
            .setMathContext(MathContext.FLOAT);
    }

    /**
     * Recursively encodes node {@code index} and its subtree as a PMML node.
     *
     * @param index the node id
     * @param predicate the predicate under which this node is entered
     * @param categoryManager the values still reachable for each feature on this path
     * @param predicateManager shared predicate factory
     * @param schema the schema whose features are referenced by split indices
     * @return a {@link LeafNode} for leaves, otherwise a {@link BranchNode} with two children
     * @throws IllegalArgumentException if the split type does not match the feature kind, the
     *     categorical split data is inconsistent, or a continuous feature has an unsupported data type
     */
    private org.dmg.pmml.tree.Node encodeNode(int index, Predicate predicate, CategoryManager categoryManager, PredicateManager predicateManager, Schema schema) {
        Integer id = Integer.valueOf(index);
        Node node = this.nodes[index];

        if (node.is_leaf()) {
            // Adding +0.0f turns a negative zero leaf value into positive zero (JLS 15.18.2)
            return new LeafNode(Float.valueOf(node.leaf_value() + 0.0f), predicate).setId(id);
        }

        Feature feature = schema.getFeature(node.split_index());
        int splitType = node.split_type();

        if (feature instanceof CategoricalFeature) {
            if (splitType != Node.SPLIT_CATEGORICAL) {
                throw new IllegalArgumentException("Expected a categorical (1) split type for categorical feature '" + feature.getName() + "', got non-categorical (" + splitType + ")");
            }
        } else if (splitType != 0) {
            throw new IllegalArgumentException("Expected a numerical (0) split type for feature '" + feature.getName() + "', got non-numerical (" + splitType + ")");
        }

        boolean defaultLeft = node.default_left();
        boolean swapChildren = false;

        CategoryManager leftCategoryManager = categoryManager;
        CategoryManager rightCategoryManager = categoryManager;

        Predicate leftPredicate;
        Predicate rightPredicate;

        if (feature instanceof CategoricalFeature) {
            CategoricalFeature categoricalFeature = (CategoricalFeature) feature;
            String name = categoricalFeature.getName();
            List<?> values = categoricalFeature.getValues();

            // A categorical split is fully described by its category bitset; the numeric condition must be NaN
            if (!Float.isNaN(Float.intBitsToFloat(node.split_cond()))) {
                throw new IllegalArgumentException("Expected a NaN split condition for categorical split node " + index);
            }

            BitSet splitCategories = node.get_split_categories();
            if (splitCategories == null) {
                throw new IllegalArgumentException("Missing split categories for categorical split node " + index);
            }

            java.util.function.Predicate<Object> valueFilter = categoryManager.getValueFilter(name);

            // Values whose position is set in the bitset go right; all other still-reachable values go left
            List<Object> leftValues = new ArrayList<>();
            List<Object> rightValues = new ArrayList<>();
            for (int i = 0; i < values.size(); i++) {
                Object value = values.get(i);
                if (!valueFilter.test(value)) {
                    continue;
                }

                if (splitCategories.get(i)) {
                    rightValues.add(value);
                } else {
                    leftValues.add(value);
                }
            }

            leftCategoryManager = categoryManager.fork(name, leftValues);
            rightCategoryManager = categoryManager.fork(name, rightValues);

            leftPredicate = predicateManager.createPredicate(categoricalFeature, leftValues);
            rightPredicate = predicateManager.createPredicate(categoricalFeature, rightValues);
        } else if (feature instanceof BinaryFeature) {
            BinaryFeature binaryFeature = (BinaryFeature) feature;
            Object value = binaryFeature.getValue();

            leftPredicate = predicateManager.createSimplePredicate(binaryFeature, SimplePredicate.Operator.NOT_EQUAL, value);
            rightPredicate = predicateManager.createSimplePredicate(binaryFeature, SimplePredicate.Operator.EQUAL, value);
        } else if (feature instanceof MissingValueFeature) {
            MissingValueFeature missingValueFeature = (MissingValueFeature) feature;

            leftPredicate = predicateManager.createSimplePredicate(missingValueFeature, SimplePredicate.Operator.IS_NOT_MISSING, (Object) null);
            rightPredicate = predicateManager.createSimplePredicate(missingValueFeature, SimplePredicate.Operator.IS_MISSING, (Object) null);
        } else if (feature instanceof ThresholdFeature) {
            ThresholdFeature thresholdFeature = (ThresholdFeature) feature;
            String name = thresholdFeature.getName();
            Object missingValue = thresholdFeature.getMissingValue();
            float splitValue = Float.intBitsToFloat(node.split_cond());

            java.util.function.Predicate<Object> valueFilter = categoryManager.getValueFilter(name);
            if (!ValueUtil.isNaN(missingValue)) {
                // The missing value is not NaN, so NaN entries are dropped from both sides
                valueFilter = valueFilter.and(candidate -> !ValueUtil.isNaN(candidate));
            }

            List<Object> leftValues = selectValues(thresholdFeature.getValues(number -> number.floatValue() < splitValue), valueFilter);
            List<Object> rightValues = selectValues(thresholdFeature.getValues(number -> number.floatValue() >= splitValue), valueFilter);

            leftCategoryManager = categoryManager.fork(name, leftValues);
            rightCategoryManager = categoryManager.fork(name, rightValues);

            leftPredicate = ThresholdFeatureUtil.createPredicate(thresholdFeature, leftValues, missingValue, predicateManager);
            rightPredicate = ThresholdFeatureUtil.createPredicate(thresholdFeature, rightValues, missingValue, predicateManager);

            // Put the child whose predicate is reported missing-value-safe first in document order
            swapChildren = !ThresholdFeatureUtil.isMissingValueSafe(leftPredicate) && ThresholdFeatureUtil.isMissingValueSafe(rightPredicate);
        } else {
            ContinuousFeature continuousFeature = feature.toContinuousFeature();
            float splitValue = Float.intBitsToFloat(node.split_cond());

            Number threshold;
            DataType dataType = continuousFeature.getDataType();
            switch (dataType) {
                case INTEGER:
                    // For integer x: x < split  <=>  x < ceil(split)
                    threshold = Integer.valueOf((int) Math.ceil(splitValue));
                    break;
                case FLOAT:
                    threshold = Float.valueOf(splitValue);
                    break;
                case DOUBLE:
                    // The split condition is a float, so compare the feature as a float too
                    continuousFeature = continuousFeature.toContinuousFeature(DataType.FLOAT);
                    threshold = Float.valueOf(splitValue);
                    break;
                default:
                    throw new IllegalArgumentException("Expected integer or floating-point data type for continuous feature '" + continuousFeature.getName() + "', got " + dataType.value() + " data type");
            }

            leftPredicate = predicateManager.createSimplePredicate(continuousFeature, SimplePredicate.Operator.LESS_THAN, threshold);
            rightPredicate = predicateManager.createSimplePredicate(continuousFeature, SimplePredicate.Operator.GREATER_OR_EQUAL, threshold);
        }

        org.dmg.pmml.tree.Node leftChild = encodeNode(node.left_child(), leftPredicate, leftCategoryManager, predicateManager, schema);
        org.dmg.pmml.tree.Node rightChild = encodeNode(node.right_child(), rightPredicate, rightCategoryManager, predicateManager, schema);

        org.dmg.pmml.tree.Node result = new BranchNode((Object) null, predicate)
            .setId(id)
            .setDefaultChild(defaultLeft ? leftChild.getId() : rightChild.getId())
            .addNodes(leftChild, rightChild);

        if (swapChildren) {
            Collections.swap(result.getNodes(), 0, 1);
        }

        return result;
    }

    /**
     * Copies the elements of {@code values} accepted by {@code filter} into a new list, preserving order.
     */
    private static List<Object> selectValues(Iterable<?> values, java.util.function.Predicate<Object> filter) {
        List<Object> result = new ArrayList<>();
        for (Object value : values) {
            if (filter.test(value)) {
                result.add(value);
            }
        }
        return result;
    }
}
