package org.jpmml.xgboost;

import com.devsmart.ubjson.GsonUtil;
import com.devsmart.ubjson.UBArray;
import com.devsmart.ubjson.UBObject;
import com.google.common.primitives.Floats;
import com.google.gson.JsonObject;
import java.io.IOException;
import java.util.Arrays;
import java.util.BitSet;
import java.util.HashSet;
import java.util.Set;
import org.dmg.pmml.mining.MiningModel;
import org.jpmml.converter.Schema;

/* loaded from: input_file:org/jpmml/xgboost/GBTree.class */
/**
 * Gradient booster whose base learners are {@link RegTree} instances.
 *
 * <p>Model state can be loaded from the legacy binary format ({@link #loadBinary}), from JSON
 * ({@link #loadJSON}), or from UBJSON ({@link #loadUBJSON}). JSON is converted to UBJSON first, so
 * both text-based formats share one parsing path.
 */
public class GBTree extends GradientBooster {
    private int num_trees;
    // num_roots, num_feature and num_output_group are populated by loadBinary only;
    // loadUBJSON does not read them.
    private int num_roots;
    private int num_feature;
    private int num_output_group;
    private int size_leaf_vector;
    private RegTree[] trees;
    private int[] tree_info;

    @Override // org.jpmml.xgboost.GradientBooster
    public String getAlgorithmName() {
        return "GBTree";
    }

    /**
     * Reads the booster parameters, the trees and the per-tree info array from the binary format.
     *
     * <p>Read order: num_trees, num_roots, num_feature, 3 reserved values, num_output_group,
     * size_leaf_vector, 32 reserved values, then {@code num_trees} trees followed by
     * {@code num_trees} tree_info ints.
     *
     * @param xGBoostDataInput the binary input positioned at the start of the gbtree section
     * @throws IOException if reading from the input fails
     */
    @Override // org.jpmml.xgboost.BinaryLoadable
    public void loadBinary(XGBoostDataInput xGBoostDataInput) throws IOException {
        this.num_trees = xGBoostDataInput.readInt();
        this.num_roots = xGBoostDataInput.readInt();
        this.num_feature = xGBoostDataInput.readInt();
        xGBoostDataInput.readReserved(3);
        this.num_output_group = xGBoostDataInput.readInt();
        this.size_leaf_vector = xGBoostDataInput.readInt();
        xGBoostDataInput.readReserved(32);
        this.trees = (RegTree[]) xGBoostDataInput.readObjectArray(RegTree.class, this.num_trees);
        this.tree_info = xGBoostDataInput.readIntArray(this.num_trees);
    }

    /**
     * Loads the model from a Gson JSON object by converting it to UBJSON and delegating to
     * {@link #loadUBJSON(UBObject)}.
     *
     * @param jsonObject the gbtree JSON object
     */
    public void loadJSON(JsonObject jsonObject) {
        loadUBJSON(GsonUtil.toUBValue(jsonObject).asObject());
    }

    /**
     * Loads the model from a UBJSON object.
     *
     * <p>Reads {@code model.gbtree_model_param.num_trees} and
     * {@code model.gbtree_model_param.size_leaf_vector}. It then loads the first
     * {@code num_trees} entries of {@code model.trees} and reads {@code model.tree_info}.
     *
     * @param uBObject the gbtree UBJSON object (must contain a {@code "model"} entry)
     */
    public void loadUBJSON(UBObject uBObject) {
        UBObject model = uBObject.get("model").asObject();

        UBObject modelParam = model.get("gbtree_model_param").asObject();
        this.num_trees = modelParam.get("num_trees").asInt();
        this.size_leaf_vector = modelParam.get("size_leaf_vector").asInt();

        UBArray treeArray = model.get("trees").asArray();
        this.trees = new RegTree[this.num_trees];
        for (int treeIndex = 0; treeIndex < this.num_trees; treeIndex++) {
            RegTree tree = new RegTree();
            tree.loadUBJSON(treeArray.get(treeIndex).asObject());
            this.trees[treeIndex] = tree;
        }

        this.tree_info = UBJSONUtil.toIntArray(model.get("tree_info"));
    }

    /**
     * Reports whether any loaded tree contains a categorical split.
     *
     * @return {@code true} as soon as one tree reports categorical splits, {@code false} otherwise
     *     (including when no trees are loaded)
     */
    public boolean hasCategoricalSplits() {
        // Bounded by num_trees (not trees.length) so an unloaded instance returns false
        // instead of dereferencing a null array.
        for (int treeIndex = 0; treeIndex < this.num_trees; treeIndex++) {
            if (this.trees[treeIndex].hasCategoricalSplits()) {
                return true;
            }
        }
        return false;
    }

    /**
     * Collects, across all trees, the split types that {@link RegTree#getSplitType(int)} reports
     * for the given index.
     *
     * @param i the index forwarded to each tree (presumably a feature index; confirm against
     *     RegTree)
     * @return the union of all per-tree results; empty if no trees are loaded
     */
    public Set<Integer> getSplitType(int i) {
        // Typed set: the previous raw HashSet compiled only through an unchecked conversion.
        Set<Integer> splitTypes = new HashSet<>();
        for (int treeIndex = 0; treeIndex < this.num_trees; treeIndex++) {
            splitTypes.addAll(this.trees[treeIndex].getSplitType(i));
        }
        return splitTypes;
    }

    /**
     * Computes the union, across all trees, of the category bitsets that
     * {@link RegTree#getSplitCategories(int)} reports for the given index.
     *
     * @param i the index forwarded to each tree (presumably a feature index; confirm against
     *     RegTree)
     * @return a new {@link BitSet} holding the union, or {@code null} if every tree returned
     *     {@code null} (or no trees are loaded)
     */
    public BitSet getSplitCategories(int i) {
        BitSet union = null;
        for (int treeIndex = 0; treeIndex < this.num_trees; treeIndex++) {
            BitSet treeCategories = this.trees[treeIndex].getSplitCategories(i);
            if (treeCategories == null) {
                continue;
            }
            // Allocate lazily so that "no categories anywhere" stays distinguishable (null).
            if (union == null) {
                union = new BitSet();
            }
            union.or(treeCategories);
        }
        return union;
    }

    /**
     * Encodes this booster's trees as a PMML mining model by delegating to the objective function.
     *
     * <p>Trees and weights are obtained through the overridable {@link #trees()} and
     * {@link #tree_weights()} accessors. A {@code null} weight array is passed on as a
     * {@code null} list.
     *
     * @param objFunction the objective function that performs the encoding
     * @param f passed through unchanged to {@code ObjFunction.encodeMiningModel}
     * @param num passed through unchanged to {@code ObjFunction.encodeMiningModel}
     * @param z passed through unchanged to {@code ObjFunction.encodeMiningModel}
     * @param schema passed through unchanged to {@code ObjFunction.encodeMiningModel}
     * @return the mining model produced by {@code objFunction}
     */
    public MiningModel encodeMiningModel(ObjFunction objFunction, float f, Integer num, boolean z, Schema schema) {
        RegTree[] regTrees = trees();
        float[] weights = tree_weights();
        return objFunction.encodeMiningModel(
            Arrays.asList(regTrees), weights != null ? Floats.asList(weights) : null, f, num, z, schema);
    }

    /**
     * @return the number of trees recorded by the last load
     */
    public int num_trees() {
        return this.num_trees;
    }

    /**
     * @return the internal tree array (not a copy); {@code null} before any load
     */
    public RegTree[] trees() {
        return this.trees;
    }

    /**
     * Returns per-tree weights used by {@link #encodeMiningModel}.
     *
     * @return always {@code null} in this class, which makes {@code encodeMiningModel} pass a
     *     {@code null} weight list
     */
    public float[] tree_weights() {
        return null;
    }
}
