package sklearn.ensemble.hist_gradient_boosting;

import com.google.common.primitives.Doubles;
import com.google.common.primitives.Ints;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import org.jpmml.python.PythonObject;

/* loaded from: input_file:sklearn/ensemble/hist_gradient_boosting/TreePredictor.class */
/**
 * Java-side view of a scikit-learn histogram gradient boosting {@code TreePredictor} object.
 *
 * <p>Per-node data is read column-wise from the {@code nodes} array attribute (one field name
 * per accessor). Categorical split bitsets are read from the optional
 * {@code raw_left_cat_bitsets} attribute.
 */
public class TreePredictor extends PythonObject {

    /**
     * Field names of the {@code nodes} array in the older layout, which uses {@code threshold}
     * and lacks the {@code is_categorical} / {@code bitset_idx} fields. The list is unmodifiable.
     */
    public static final List<String> DTYPE_PREDICTOR_OLD = Collections.unmodifiableList(Arrays.asList(
        "value", "count", "feature_idx", "threshold", "missing_go_to_left", "left", "right",
        "gain", "depth", "is_leaf", "bin_threshold"));

    /**
     * Field names of the {@code nodes} array in the newer layout, which renames
     * {@code threshold} to {@code num_threshold} and adds {@code is_categorical} and
     * {@code bitset_idx}. The list is unmodifiable.
     */
    public static final List<String> DTYPE_PREDICTOR_NEW = Collections.unmodifiableList(Arrays.asList(
        "value", "count", "feature_idx", "num_threshold", "missing_go_to_left", "left", "right",
        "gain", "depth", "is_leaf", "bin_threshold", "is_categorical", "bitset_idx"));

    /**
     * @param str module name, passed through to {@link PythonObject}
     * @param str2 class name, passed through to {@link PythonObject}
     */
    public TreePredictor(String str, String str2) {
        super(str, str2);
    }

    /**
     * Returns the {@code raw_left_cat_bitsets} attribute as a flat int array.
     *
     * @return the bitset words, or {@code null} when the attribute is not present
     */
    public int[] getRawLeftCatBitsets() {
        if (hasattr("raw_left_cat_bitsets")) {
            return Ints.toArray(getIntegerArray("raw_left_cat_bitsets"));
        }
        return null;
    }

    /** Returns the per-node {@code value} field. */
    public double[] getValues() {
        return Doubles.toArray(getNodeAttribute("value"));
    }

    /** Returns the per-node {@code count} field. */
    public int[] getCount() {
        return Ints.toArray(getNodeAttribute("count"));
    }

    /** Returns the per-node {@code feature_idx} field. */
    public int[] getFeatureIdx() {
        return Ints.toArray(getNodeAttribute("feature_idx"));
    }

    /**
     * Returns the per-node split thresholds.
     *
     * <p>Reads the {@code threshold} field (older layout). If that lookup yields {@code null},
     * it falls back to {@code num_threshold} (newer layout).
     *
     * @return the threshold of every node
     * @throws NullPointerException if neither field yields a value
     */
    public double[] getThreshold() {
        List<Number> thresholds = getNodeAttribute("threshold");
        if (thresholds == null) {
            // The newer node layout stores the same data under "num_threshold".
            thresholds = getNodeAttribute("num_threshold");
        }
        return Doubles.toArray(Objects.requireNonNull(thresholds,
            "Neither 'threshold' nor 'num_threshold' node field is available"));
    }

    /** Returns the per-node {@code missing_go_to_left} field as ints. */
    public int[] getMissingGoToLeft() {
        return Ints.toArray(getNodeAttribute("missing_go_to_left"));
    }

    /** Returns the per-node {@code left} field. */
    public int[] getLeft() {
        return Ints.toArray(getNodeAttribute("left"));
    }

    /** Returns the per-node {@code right} field. */
    public int[] getRight() {
        return Ints.toArray(getNodeAttribute("right"));
    }

    /** Returns the per-node {@code is_leaf} field as ints. */
    public int[] isLeaf() {
        return Ints.toArray(getNodeAttribute("is_leaf"));
    }

    /** Returns the per-node {@code bin_threshold} field. */
    public int[] getBinThreshold() {
        return Ints.toArray(getNodeAttribute("bin_threshold"));
    }

    /**
     * Returns the per-node {@code bin_threshold} field.
     *
     * @deprecated misspelled name; use {@link #getBinThreshold()} instead.
     */
    @Deprecated
    public int[] getBinThreshhold() {
        return getBinThreshold();
    }

    /**
     * Returns the per-node {@code is_categorical} field as ints.
     *
     * @return the flags, or {@code null} when the field lookup yields {@code null}
     *     (e.g. the older node layout, which has no such field)
     */
    public int[] isCategorical() {
        List<Number> flags = getNodeAttribute("is_categorical");
        return (flags != null) ? Ints.toArray(flags) : null;
    }

    /**
     * Returns the per-node {@code bitset_idx} field.
     *
     * @return the indices, or {@code null} when the field lookup yields {@code null}
     *     (e.g. the older node layout, which has no such field)
     */
    public int[] getBitsetIdx() {
        List<Number> indices = getNodeAttribute("bitset_idx");
        return (indices != null) ? Ints.toArray(indices) : null;
    }

    /**
     * Reads one named field of the {@code nodes} array.
     *
     * <p>NOTE(review): the null-checks in the callers assume {@code getArray} returns
     * {@code null} for an absent field rather than throwing. Confirm against {@link PythonObject}.
     */
    private List<Number> getNodeAttribute(String str) {
        return getArray("nodes", str);
    }
}
