package org.jpmml.lightgbm;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.Predicate;
import org.dmg.pmml.SimplePredicate;
import org.dmg.pmml.True;
import org.dmg.pmml.tree.Node;
import org.dmg.pmml.tree.TreeModel;
import org.jpmml.converter.BinaryFeature;
import org.jpmml.converter.CategoricalFeature;
import org.jpmml.converter.ContinuousFeature;
import org.jpmml.converter.Feature;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.PredicateManager;
import org.jpmml.converter.Schema;
import org.jpmml.converter.ValueUtil;

/**
 * A single LightGBM decision tree, loaded from a {@link Section} of the text model format,
 * together with its encoding as a PMML {@link TreeModel}.
 *
 * <p>Node indexing follows the model file: split (internal) nodes use non-negative indices into
 * the per-split arrays, while a child index {@code c < 0} refers to leaf {@code ~c}.
 */
public class Tree {
    private int num_leaves_;
    private int num_cat_;

    // Per split node (length num_leaves_ - 1)
    private int[] left_child_;
    private int[] right_child_;
    private int[] split_feature_real_;
    private double[] threshold_;
    private int[] decision_type_;

    // Per leaf (length num_leaves_)
    private double[] leaf_value_;
    private int[] leaf_count_;

    // Per split node (length num_leaves_ - 1)
    private double[] internal_value_;
    private int[] internal_count_;

    // Only loaded when num_cat_ > 0
    private int[] cat_boundaries_;
    private long[] cat_threshold_;

    /** Bit of a {@code decision_type} value that marks a categorical split. */
    private static final int MASK_CATEGORICAL = 1;

    /** Bit of a {@code decision_type} value that makes the left child the default child. */
    private static final int MASK_DEFAULT_LEFT = 2;

    /**
     * Populates this tree from one tree section of a LightGBM text model.
     *
     * @param section the section holding this tree's key/value entries
     */
    public void load(Section section) {
        this.num_leaves_ = section.getInt("num_leaves");
        this.num_cat_ = section.getInt("num_cat");

        // A binary tree with N leaves has N - 1 split nodes.
        int numSplits = this.num_leaves_ - 1;

        this.left_child_ = section.getIntArray("left_child", numSplits);
        this.right_child_ = section.getIntArray("right_child", numSplits);
        this.split_feature_real_ = section.getIntArray("split_feature", numSplits);
        this.threshold_ = section.getDoubleArray("threshold", numSplits);
        this.decision_type_ = section.getIntArray("decision_type", numSplits);
        this.leaf_value_ = section.getDoubleArray("leaf_value", this.num_leaves_);
        this.leaf_count_ = section.getIntArray("leaf_count", this.num_leaves_);
        this.internal_value_ = section.getDoubleArray("internal_value", numSplits);
        this.internal_count_ = section.getIntArray("internal_count", numSplits);

        if (this.num_cat_ > 0) {
            // num_cat_ + 1 offsets; consecutive pairs delimit one categorical split's bitset words
            this.cat_boundaries_ = section.getIntArray("cat_boundaries", this.num_cat_ + 1);
            // NOTE(review): -1 presumably means "length not checked" — confirm against Section
            this.cat_threshold_ = section.getUnsignedIntArray("cat_threshold", -1);
        }
    }

    /**
     * Encodes this tree as a PMML regression {@link TreeModel} with binary splits and
     * default-child missing value handling.
     *
     * @param predicateManager factory used to create (and share) split predicates
     * @param schema the schema whose features are referenced by {@code split_feature}
     * @return the encoded tree model
     */
    public TreeModel encodeTreeModel(PredicateManager predicateManager, Schema schema) {
        Node root = new Node().setPredicate(new True());

        // A single-leaf tree has no split nodes; its root is then leaf 0 (referenced as ~0).
        int rootIndex = (this.num_leaves_ > 1) ? 0 : ~0;
        encodeNode(root, predicateManager, Collections.emptyMap(), rootIndex, schema);

        return new TreeModel(MiningFunction.REGRESSION, ModelUtil.createMiningSchema(schema.getLabel()), root)
            .setSplitCharacteristic(TreeModel.SplitCharacteristic.BINARY_SPLIT)
            .setMissingValueStrategy(TreeModel.MissingValueStrategy.DEFAULT_CHILD);
    }

    /**
     * Recursively fills {@code node} (whose predicate is already set) from the tree node at
     * {@code index}.
     *
     * @param node the PMML node to populate
     * @param predicateManager factory for split predicates
     * @param parentValues for each categorical field already split on along the current path, the
     *     set of categories still reachable at this node; fields absent from the map are unconstrained
     * @param index a split node index ({@code >= 0}) or an encoded leaf index ({@code ~leaf < 0})
     * @param schema the schema whose features are referenced by {@code split_feature}
     * @throws IllegalArgumentException if a split is inconsistent with the type of its feature
     */
    public void encodeNode(Node node, PredicateManager predicateManager, Map<FieldName, Set<String>> parentValues, int index, Schema schema) {
        node.setId(String.valueOf(index));

        if (index < 0) {
            int leafIndex = ~index;

            node.setScore(ValueUtil.formatValue(Double.valueOf(this.leaf_value_[leafIndex])));
            node.setRecordCount(Double.valueOf(this.leaf_count_[leafIndex]));
            return;
        }

        node.setScore((String) null);
        node.setRecordCount(Double.valueOf(this.internal_count_[index]));

        Feature feature = schema.getFeature(this.split_feature_real_[index]);
        double threshold = this.threshold_[index];
        int decisionType = this.decision_type_[index];

        Map<FieldName, Set<String>> leftValues = parentValues;
        Map<FieldName, Set<String>> rightValues = parentValues;
        boolean defaultLeft = hasDefaultLeftMask(decisionType);

        Predicate leftPredicate;
        Predicate rightPredicate;

        if (feature instanceof BinaryFeature) {
            BinaryFeature binaryFeature = (BinaryFeature) feature;

            // A 0/1 indicator is split numerically at 0.5: "<= 0.5" (left) means the value is absent.
            if (hasCategoricalMask(decisionType) || threshold != 0.5d) {
                throw new IllegalArgumentException("Binary feature split must be a non-categorical split at threshold 0.5, got decision_type " + decisionType + " and threshold " + threshold);
            }

            String value = binaryFeature.getValue();

            leftPredicate = predicateManager.createSimplePredicate(binaryFeature, SimplePredicate.Operator.NOT_EQUAL, value);
            rightPredicate = predicateManager.createSimplePredicate(binaryFeature, SimplePredicate.Operator.EQUAL, value);
        } else if (feature instanceof CategoricalFeature) {
            CategoricalFeature categoricalFeature = (CategoricalFeature) feature;

            if (!hasCategoricalMask(decisionType)) {
                throw new IllegalArgumentException("Categorical feature split must have the categorical decision_type bit set, got decision_type " + decisionType);
            }

            FieldName name = categoricalFeature.getName();
            // Direct categorical features use the category index itself as the category value.
            boolean indexAsValue = categoricalFeature instanceof DirectCategoricalFeature;
            List<String> values = categoricalFeature.getValues();

            Set<String> reachableValues = parentValues.get(name);
            if (reachableValues == null) {
                reachableValues = new HashSet<>(values);
            }

            // For categorical splits, the threshold holds an index into cat_boundaries_.
            int catIdx = ValueUtil.asInt(Double.valueOf(threshold));

            List<String> leftCategories = selectValues(indexAsValue, values, reachableValues, catIdx, true);
            List<String> rightCategories = selectValues(indexAsValue, values, reachableValues, catIdx, false);

            leftValues = new HashMap<>(parentValues);
            leftValues.put(name, new HashSet<>(leftCategories));

            rightValues = new HashMap<>(parentValues);
            rightValues.put(name, new HashSet<>(rightCategories));

            leftPredicate = predicateManager.createSimpleSetPredicate(categoricalFeature, leftCategories);
            rightPredicate = predicateManager.createSimpleSetPredicate(categoricalFeature, rightCategories);

            // The default child of a categorical split is always the right one.
            defaultLeft = false;
        } else {
            ContinuousFeature continuousFeature = feature.toContinuousFeature();

            if (hasCategoricalMask(decisionType)) {
                throw new IllegalArgumentException("Continuous feature split must not have the categorical decision_type bit set, got decision_type " + decisionType);
            }

            String value = ValueUtil.formatValue(Double.valueOf(threshold));

            leftPredicate = predicateManager.createSimplePredicate(continuousFeature, SimplePredicate.Operator.LESS_OR_EQUAL, value);
            rightPredicate = predicateManager.createSimplePredicate(continuousFeature, SimplePredicate.Operator.GREATER_THAN, value);
        }

        Node leftChild = new Node().setPredicate(leftPredicate);
        encodeNode(leftChild, predicateManager, leftValues, this.left_child_[index], schema);

        Node rightChild = new Node().setPredicate(rightPredicate);
        encodeNode(rightChild, predicateManager, rightValues, this.right_child_[index], schema);

        node.addNodes(new Node[]{leftChild, rightChild});
        node.setDefaultChild(defaultLeft ? leftChild.getId() : rightChild.getId());
    }

    /**
     * Computes the categories routed to one side of a categorical split.
     *
     * <p>The split's category bitset is the slice
     * {@code cat_threshold_[cat_boundaries_[catIdx] .. cat_boundaries_[catIdx + 1])}, read as
     * 32 bits per element. Categories whose bit is set go left; all others go right.
     *
     * @param indexAsValue if true, category values are the formatted category indices; otherwise
     *     category index {@code k} maps to {@code values.get(k)}
     * @param values all category values of the feature
     * @param parentValues categories still reachable at the split node
     * @param catIdx index into {@code cat_boundaries_} for this split
     * @param left true to compute the left-hand categories, false for the right-hand ones
     * @return the selected categories, restricted to {@code parentValues}
     * @throws IllegalArgumentException if the selection yields a degenerate branch
     */
    private List<String> selectValues(boolean indexAsValue, List<String> values, Set<String> parentValues, int catIdx, boolean left) {
        List<String> result;
        if (left) {
            result = new ArrayList<>();
        } else {
            result = new ArrayList<>(values);
        }

        int begin = this.cat_boundaries_[catIdx];
        int numWords = this.cat_boundaries_[catIdx + 1] - begin;

        for (int word = 0; word < numWords; word++) {
            for (int bit = 0; bit < 32; bit++) {
                int category = (word * 32) + bit;

                if (!findInBitset(this.cat_threshold_, begin, numWords, category)) {
                    continue;
                }

                String value;
                if (indexAsValue) {
                    value = (String) LightGBMUtil.CATEGORY_FORMATTER.apply(Integer.valueOf(category));
                } else {
                    value = values.get(category);
                }

                if (left) {
                    result.add(value);
                } else {
                    result.remove(value);
                }
            }
        }

        result.retainAll(parentValues);

        if (left) {
            if (result.isEmpty()) {
                throw new IllegalArgumentException("Categorical split " + catIdx + " sends no reachable category to the left child");
            }
        } else if (new HashSet<>(result).equals(parentValues)) {
            // Compare as sets: a List never equals a Set, so comparing the list directly is always false.
            throw new IllegalArgumentException("Categorical split " + catIdx + " sends every reachable category to the right child");
        }

        return result;
    }

    /**
     * Tells whether every split on the given feature looks like a split on a 0/1 indicator.
     *
     * <p>Such a split is non-categorical with threshold 0.5.
     *
     * @param feature the feature index, as stored in {@code split_feature}
     * @return {@code TRUE} if all splits on the feature qualify, {@code FALSE} if any split does
     *     not, or {@code null} if the tree never splits on the feature
     */
    public Boolean isBinary(int feature) {
        Boolean result = null;

        for (int i = 0; i < this.split_feature_real_.length; i++) {
            if (this.split_feature_real_[i] != feature) {
                continue;
            }

            if (hasCategoricalMask(this.decision_type_[i]) || this.threshold_[i] != 0.5d) {
                return Boolean.FALSE;
            }

            result = Boolean.TRUE;
        }

        return result;
    }

    /**
     * Tells whether every split on the given feature is a categorical split.
     *
     * @param feature the feature index, as stored in {@code split_feature}
     * @return {@code TRUE} if all splits on the feature are categorical, {@code FALSE} if any is
     *     not, or {@code null} if the tree never splits on the feature
     */
    public Boolean isCategorical(int feature) {
        Boolean result = null;

        for (int i = 0; i < this.split_feature_real_.length; i++) {
            if (this.split_feature_real_[i] != feature) {
                continue;
            }

            if (!hasCategoricalMask(this.decision_type_[i])) {
                return Boolean.FALSE;
            }

            result = Boolean.TRUE;
        }

        return result;
    }

    private static boolean hasCategoricalMask(int decisionType) {
        return getDecisionType(decisionType, MASK_CATEGORICAL) == MASK_CATEGORICAL;
    }

    private static boolean hasDefaultLeftMask(int decisionType) {
        return getDecisionType(decisionType, MASK_DEFAULT_LEFT) == MASK_DEFAULT_LEFT;
    }

    /** Returns the bits of {@code decisionType} selected by {@code mask}. */
    static int getDecisionType(int decisionType, int mask) {
        return decisionType & mask;
    }

    /** Returns the 2-bit missing type field stored in bits 2-3 of {@code decisionType}. */
    static int getMissingType(int decisionType) {
        return getDecisionType(decisionType >> 2, 3);
    }

    /**
     * Tests bit {@code pos} of the 32-bits-per-element bitset {@code bits[offset .. offset + numWords)}.
     *
     * @return false when {@code pos} lies beyond the bitset
     */
    private static boolean findInBitset(long[] bits, int offset, int numWords, int pos) {
        int word = pos / 32;
        if (word >= numWords) {
            return false;
        }

        return ((bits[offset + word] >> (pos % 32)) & 1) == 1;
    }
}
