package hex.tree.xgboost;

import biz.k11i.xgboost.Predictor;
import biz.k11i.xgboost.gbm.GBTree;
import biz.k11i.xgboost.gbm.GradBooster;
import biz.k11i.xgboost.tree.RegTree;
import biz.k11i.xgboost.tree.RegTreeNode;
import hex.DataInfo;
import hex.KeyValue;
import hex.Model;
import hex.ModelCategory;
import hex.ModelMetrics;
import hex.ModelMetricsBinomial;
import hex.ModelMetricsMultinomial;
import hex.ModelMetricsRegression;
import hex.genmodel.algos.tree.ConvertTreeOptions;
import hex.genmodel.algos.tree.SharedTreeGraph;
import hex.genmodel.algos.tree.SharedTreeGraphConverter;
import hex.genmodel.algos.tree.SharedTreeNode;
import hex.genmodel.algos.tree.SharedTreeSubgraph;
import hex.genmodel.algos.xgboost.XGBoostJavaMojoModel;
import hex.genmodel.algos.xgboost.XGBoostMojoModel;
import hex.genmodel.utils.DistributionFamily;
import hex.tree.PlattScalingHelper;
import hex.tree.xgboost.XGBoostUtils;
import hex.tree.xgboost.predict.AssignLeafNodeTask;
import hex.tree.xgboost.predict.MutableOneHotEncoderFVec;
import hex.tree.xgboost.predict.PredictTreeSHAPTask;
import hex.tree.xgboost.predict.XGBoostBigScorePredict;
import hex.tree.xgboost.predict.XGBoostJavaBigScorePredict;
import hex.tree.xgboost.predict.XGBoostModelMetrics;
import hex.tree.xgboost.predict.XGBoostNativeBigScorePredict;
import hex.tree.xgboost.util.BoosterHelper;
import hex.tree.xgboost.util.PredictConfiguration;
import java.io.ByteArrayInputStream;
import java.io.File;
import java.io.FileOutputStream;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.stream.Stream;
import ml.dmlc.xgboost4j.java.Booster;
import ml.dmlc.xgboost4j.java.PredictorFactory;
import ml.dmlc.xgboost4j.java.XGBoostModelInfo;
import okhttp3.internal.cache.DiskLruCache;
import water.AutoBuffer;
import water.DKV;
import water.Futures;
import water.H2O;
import water.H2ONode;
import water.IcedUtils;
import water.Job;
import water.JobUpdatePostMap;
import water.Key;
import water.Keyed;
import water.codegen.CodeGeneratorPipeline;
import water.fvec.Frame;
import water.util.ArrayUtils;
import water.util.JCodeGen;
import water.util.Log;
import water.util.SBPrintStream;

/* loaded from: input_file:hex/tree/xgboost/XGBoostModel.class */
public class XGBoostModel extends Model<XGBoostModel, XGBoostParameters, XGBoostOutput> implements SharedTreeGraphConverter, Model.LeafNodeAssignment, Model.Contributions {
    // System property whose value, when set, is passed to XGBoost as its "verbosity" parameter
    // (otherwise "silent" is derived from _quiet_mode in createParams).
    // NOTE(review): the literal contains a doubled dot ("h2o..xgboost"), unlike PROP_NTHREAD below;
    // System.getProperty(PROP_VERBOSITY) looks up exactly this string — confirm the spelling is intended.
    private static final String PROP_VERBOSITY = "sys.ai.h2o..xgboost.verbosity";
    // System property that, when set, is returned by getMaxNThread instead of the per-host computed thread cap.
    private static final String PROP_NTHREAD = "sys.ai.h2o.xgboost.nthreadMax";
    // Booster state backing this model (serialized booster bytes, feature map, DataInfo accessors).
    private XGBoostModelInfo model_info;
    // Decompiler-emitted synthetic assertion flag, read by the explicit `if (!$assertionsDisabled && ...)`
    // checks in this class. NOTE(review): its initializer is not visible in this view.
    static final /* synthetic */ boolean $assertionsDisabled;

    /**
     * User-facing XGBoost hyper-parameters.
     *
     * <p>Several settings exist twice, once under the H2O name and once under the native XGBoost name
     * (e.g. {@code _ntrees}/{@code _n_estimators}, {@code _learn_rate}/{@code _eta},
     * {@code _min_rows}/{@code _min_child_weight}). This class only stores them; the outer model's
     * {@code createParams} decides which one of each pair is used.
     */
    public static class XGBoostParameters extends Model.Parameters implements Model.GetNTrees, PlattScalingHelper.ParamsWithCalibration {
        public int _n_estimators;
        public KeyValue[] _monotone_constraints;
        public float _gamma;
        public String _save_matrix_directory;
        public boolean _calibrate_model;
        public Key<Frame> _calibration_frame;
        static String[] CHECKPOINT_NON_MODIFIABLE_FIELDS = {"_tree_method", "_grow_policy", "_booster", "_sample_rate", "_max_depth", "_min_rows"};
        public boolean _quiet_mode = true;
        public int _ntrees = 50;
        public int _max_depth = 6;
        public double _min_rows = 1.0d;
        public double _min_child_weight = 1.0d;
        public double _learn_rate = 0.3d;
        public double _eta = 0.3d;
        public double _learn_rate_annealing = 1.0d;
        public double _sample_rate = 1.0d;
        public double _subsample = 1.0d;
        public double _col_sample_rate = 1.0d;
        public double _colsample_bylevel = 1.0d;
        public double _col_sample_rate_per_tree = 1.0d;
        public double _colsample_bytree = 1.0d;
        public float _max_abs_leafnode_pred = 0.0f;
        public float _max_delta_step = 0.0f;
        public int _score_tree_interval = 0;
        public int _initial_score_interval = 4000;
        public int _score_interval = 4000;
        public float _min_split_improvement = 0.0f;
        public int _nthread = -1;
        public boolean _build_tree_one_node = false;
        public int _max_bins = 256;
        public int _max_leaves = 0;
        public float _min_sum_hessian_in_leaf = 100.0f;
        public float _min_data_in_leaf = 0.0f;
        public TreeMethod _tree_method = TreeMethod.auto;
        public GrowPolicy _grow_policy = GrowPolicy.depthwise;
        public Booster _booster = Booster.gbtree;
        public DMatrixType _dmatrix_type = DMatrixType.auto;
        public float _reg_lambda = 1.0f;
        public float _reg_alpha = 0.0f;
        public DartSampleType _sample_type = DartSampleType.uniform;
        public DartNormalizeType _normalize_type = DartNormalizeType.tree;
        public float _rate_drop = 0.0f;
        public boolean _one_drop = false;
        public float _skip_drop = 0.0f;
        public int _gpu_id = 0;
        public Backend _backend = Backend.auto;

        /** Requested compute backend. */
        public enum Backend {
            auto,
            gpu,
            cpu
        }

        /** XGBoost booster type; the constant names are passed to XGBoost as the "booster" value. */
        public enum Booster {
            gbtree,
            gblinear,
            dart
        }

        /** Requested DMatrix layout. */
        public enum DMatrixType {
            auto,
            dense,
            sparse
        }

        /** DART "normalize_type" values. */
        public enum DartNormalizeType {
            tree,
            forest
        }

        /** DART "sample_type" values. */
        public enum DartSampleType {
            uniform,
            weighted
        }

        /** XGBoost "grow_policy" values. */
        public enum GrowPolicy {
            depthwise,
            lossguide
        }

        /** XGBoost "tree_method" values. */
        public enum TreeMethod {
            auto,
            exact,
            approx,
            hist
        }

        @Override // hex.Model.Parameters
        public String algoName() {
            return "XGBoost";
        }

        @Override // hex.Model.Parameters
        public String fullName() {
            return "XGBoost";
        }

        @Override // hex.Model.Parameters
        public String javaName() {
            return XGBoostModel.class.getName();
        }

        /** Progress is measured in trees, so the total work equals {@code _ntrees}. */
        @Override // hex.Model.Parameters
        public long progressUnits() {
            return this._ntrees;
        }

        /**
         * Collects the settings that prevent the GPU backend from being used.
         *
         * @return map from parameter name to an explanation (or the offending value); empty when every
         *     checked setting is GPU-compatible
         */
        public Map<String, Object> gpuIncompatibleParams() {
            Map<String, Object> incompatible = new HashMap<>();
            // gblinear does not go through a tree method, so tree_method is irrelevant for it.
            if (TreeMethod.auto != this._tree_method && TreeMethod.hist != this._tree_method && Booster.gblinear != this._booster) {
                incompatible.put("tree_method", "Only auto and hist are supported tree_method on GPU backend.");
            }
            if (this._max_depth > 15 || this._max_depth < 1) {
                incompatible.put("max_depth", this._max_depth + " . Max depth must be greater than 0 and lower than 16 for GPU backend.");
            }
            if (this._grow_policy == GrowPolicy.lossguide) {
                incompatible.put("grow_policy", GrowPolicy.lossguide);
            }
            return incompatible;
        }

        /**
         * Normalizes the user's monotone constraints into a direction per feature.
         *
         * <p>Entries whose value is exactly 0 are skipped. Negative values map to -1 and all other
         * non-zero values map to +1.
         *
         * @return feature name to direction (-1 or +1); an empty map when no constraints are set
         * @throws IllegalStateException if the same feature has more than one non-zero constraint
         */
        public Map<String, Integer> monotoneConstraints() {
            if (this._monotone_constraints == null || this._monotone_constraints.length == 0) {
                return Collections.emptyMap();
            }
            Map<String, Integer> constraints = new HashMap<>(this._monotone_constraints.length);
            for (KeyValue constraint : this._monotone_constraints) {
                double direction = constraint.getValue();
                if (direction == 0.0d) {
                    continue;
                }
                if (constraints.containsKey(constraint.getKey())) {
                    throw new IllegalStateException("Duplicate definition of constraint for feature '" + constraint.getKey() + "'.");
                }
                constraints.put(constraint.getKey(), direction < 0.0d ? -1 : 1);
            }
            return constraints;
        }

        @Override // hex.Model.GetNTrees
        public int getNTrees() {
            return this._ntrees;
        }

        /** @return the calibration frame, or {@code null} when no calibration frame key is set */
        @Override // hex.tree.PlattScalingHelper.ParamsWithCalibration
        public Frame getCalibrationFrame() {
            if (this._calibration_frame != null) {
                return this._calibration_frame.get();
            }
            return null;
        }

        @Override // hex.tree.PlattScalingHelper.ParamsWithCalibration
        public boolean calibrateModel() {
            return this._calibrate_model;
        }

        @Override // hex.tree.PlattScalingHelper.ParamsWithCalibration
        public Model.Parameters getParams() {
            return this;
        }
    }

    /**
     * Accessor for the booster state of this model.
     *
     * @return the {@link XGBoostModelInfo} instance held by this model (may be shared, not copied)
     */
    public XGBoostModelInfo model_info() {
        final XGBoostModelInfo info = model_info;
        return info;
    }

    /**
     * Creates the metric builder that matches this model's category.
     *
     * @param strArr response domain, forwarded to the binomial/multinomial builders
     * @return a binomial, multinomial or regression metric builder
     * @throws RuntimeException from {@code H2O.unimpl()} for any other model category
     */
    @Override // hex.Model
    public ModelMetrics.MetricBuilder makeMetricBuilder(String[] strArr) {
        final XGBoostOutput out = (XGBoostOutput) this._output;
        final ModelCategory category = out.getModelCategory();
        ModelMetrics.MetricBuilder builder;
        switch (category) {
            case Binomial:
                builder = new ModelMetricsBinomial.MetricBuilderBinomial(strArr);
                break;
            case Multinomial:
                builder = new ModelMetricsMultinomial.MetricBuilderMultinomial(out.nclasses(), strArr);
                break;
            case Regression:
                builder = new ModelMetricsRegression.MetricBuilderRegression();
                break;
            default:
                throw H2O.unimpl();
        }
        return builder;
    }

    /**
     * Builds a model shell.
     *
     * <p>Steps:
     * <ul>
     *   <li>creates the {@link DataInfo} from the two frames;</li>
     *   <li>publishes the DataInfo to the DKV;</li>
     *   <li>copies its column layout into the output;</li>
     *   <li>wraps it in a fresh {@link XGBoostModelInfo}.</li>
     * </ul>
     *
     * @param key model key
     * @param xGBoostParameters model parameters
     * @param xGBoostOutput model output (its class count drives DataInfo creation)
     * @param frame first frame passed to {@code XGBoost.makeDataInfo} (presumably the training frame — confirm)
     * @param frame2 second frame passed to {@code XGBoost.makeDataInfo} (presumably the validation frame — confirm)
     */
    public XGBoostModel(Key<XGBoostModel> key, XGBoostParameters xGBoostParameters, XGBoostOutput xGBoostOutput, Frame frame, Frame frame2) {
        super(key, xGBoostParameters, xGBoostOutput);
        final int nclasses = xGBoostOutput.nclasses();
        final DataInfo dinfo = XGBoost.makeDataInfo(frame, frame2, (XGBoostParameters) this._parms, nclasses);
        DKV.put(dinfo);
        setDataInfoToOutput(dinfo);
        model_info = new XGBoostModelInfo(xGBoostParameters, dinfo);
    }

    /**
     * Prints a text dump of the trained booster to standard output.
     *
     * <p>Steps:
     * <ul>
     *   <li>loads the booster from the serialized model bytes;</li>
     *   <li>writes the feature map to a temporary file so the dump can show feature names;</li>
     *   <li>prints every tree returned by {@code getModelDump} (with statistics enabled).</li>
     * </ul>
     * Any {@link Exception} is logged rather than propagated. The temporary file and the native booster
     * are always released.
     *
     * @param str dump format forwarded to {@code Booster.getModelDump} (e.g. "text" or "json")
     */
    public void dump(String str) {
        File featureMapFile = null;
        Booster booster = null;
        try {
            booster = BoosterHelper.loadModel(new ByteArrayInputStream(this.model_info._boosterBytes));
            featureMapFile = File.createTempFile("xgboost-feature-map", ".bin");
            // try-with-resources: the stream is closed even when write() fails.
            // UTF-8 rather than the platform charset, so feature names round-trip through the native dump.
            try (FileOutputStream out = new FileOutputStream(featureMapFile)) {
                out.write(this.model_info._featureMap.getBytes(java.nio.charset.StandardCharsets.UTF_8));
            }
            for (String line : booster.getModelDump(featureMapFile.getAbsolutePath(), true, str)) {
                System.out.println(line);
            }
        } catch (Exception e) {
            Log.err(e);
        } finally {
            if (booster != null) {
                // The booster holds native memory that the GC does not reclaim by itself.
                booster.dispose();
            }
            if (featureMapFile != null) {
                featureMapFile.delete();
            }
        }
    }

    /**
     * Resolves the requested backend to the one that will actually be used.
     *
     * <p>The GPU is only chosen when all of the following hold:
     * <ul>
     *   <li>it was requested (auto or gpu);</li>
     *   <li>the cloud has a single node;</li>
     *   <li>no parameter is GPU-incompatible;</li>
     *   <li>the first cloud member reports a GPU with the configured id.</li>
     * </ul>
     * Every decision is logged.
     *
     * @param xGBoostParameters user parameters (read only)
     * @return {@code Backend.gpu} or {@code Backend.cpu}; never {@code auto}
     */
    public static XGBoostParameters.Backend getActualBackend(XGBoostParameters xGBoostParameters) {
        final XGBoostParameters.Backend requested = xGBoostParameters._backend;
        final boolean gpuAllowed = requested == XGBoostParameters.Backend.auto || requested == XGBoostParameters.Backend.gpu;
        if (!gpuAllowed) {
            Log.info("Using CPU backend.");
            return XGBoostParameters.Backend.cpu;
        }
        if (H2O.getCloudSize() > 1) {
            Log.info("GPU backend not supported in distributed mode. Using CPU backend.");
            return XGBoostParameters.Backend.cpu;
        }
        final Map<String, Object> incompatible = xGBoostParameters.gpuIncompatibleParams();
        if (!incompatible.isEmpty()) {
            Log.info("GPU backend not supported for the choice of parameters (" + incompatible + "). Using CPU backend.");
            return XGBoostParameters.Backend.cpu;
        }
        final int gpuId = xGBoostParameters._gpu_id;
        if (!XGBoost.hasGPU(H2O.CLOUD.members()[0], gpuId)) {
            Log.info("No GPU (gpu_id: " + gpuId + ") found. Using CPU backend.");
            return XGBoostParameters.Backend.cpu;
        }
        Log.info("Using GPU backend (gpu_id: " + gpuId + ").");
        return XGBoostParameters.Backend.gpu;
    }

    /**
     * Translates H2O-level XGBoost parameters into the parameter map handed to the native booster.
     *
     * <p>Side effect: for each H2O/XGBoost alias pair (ntrees/n_estimators, learn_rate/eta,
     * sample_rate/subsample, col_sample_rate_per_tree/colsample_bytree,
     * col_sample_rate/colsample_bylevel, max_abs_leafnode_pred/max_delta_step,
     * min_rows/min_child_weight, min_split_improvement/gamma), the value actually used is written
     * back into both fields of {@code xGBoostParameters}. The XGBoost-named field wins whenever it
     * differs from its field initializer.
     *
     * @param xGBoostParameters user parameters; mutated as described above
     * @param i number of response classes: 1 selects a regression objective, 2 binary logistic,
     *     anything else multi-class softprob
     * @param strArr feature names in model column order (callers pass {@code DataInfo.coefNames()});
     *     used to lay out the monotone-constraints vector
     * @return immutable booster parameters
     * @throws UnsupportedOperationException for a regression distribution other than
     *     AUTO, gaussian, gamma, tweedie or poisson
     */
    public static BoosterParms createParams(XGBoostParameters xGBoostParameters, int i, String[] strArr) {
        final XGBoostParameters p = xGBoostParameters;
        final Map<String, Object> params = new HashMap<>();

        if (p._n_estimators != 0) {
            Log.info("Using user-provided parameter n_estimators instead of ntrees.");
            params.put("nround", p._n_estimators);
            p._ntrees = p._n_estimators;
        } else {
            params.put("nround", p._ntrees);
            p._n_estimators = p._ntrees;
        }
        if (p._eta != 0.3d) {
            Log.info("Using user-provided parameter eta instead of learn_rate.");
            params.put("eta", p._eta);
            p._learn_rate = p._eta;
        } else {
            params.put("eta", p._learn_rate);
            p._eta = p._learn_rate;
        }
        params.put("max_depth", p._max_depth);
        final String verbosity = xgboostVerbosityOverride();
        if (verbosity != null) {
            params.put("verbosity", verbosity);
        } else {
            params.put("silent", p._quiet_mode);
        }
        if (p._subsample != 1.0d) {
            Log.info("Using user-provided parameter subsample instead of sample_rate.");
            params.put("subsample", p._subsample);
            p._sample_rate = p._subsample;
        } else {
            params.put("subsample", p._sample_rate);
            p._subsample = p._sample_rate;
        }
        if (p._colsample_bytree != 1.0d) {
            Log.info("Using user-provided parameter colsample_bytree instead of col_sample_rate_per_tree.");
            params.put("colsample_bytree", p._colsample_bytree);
            p._col_sample_rate_per_tree = p._colsample_bytree;
        } else {
            params.put("colsample_bytree", p._col_sample_rate_per_tree);
            p._colsample_bytree = p._col_sample_rate_per_tree;
        }
        if (p._colsample_bylevel != 1.0d) {
            Log.info("Using user-provided parameter colsample_bylevel instead of col_sample_rate.");
            params.put("colsample_bylevel", p._colsample_bylevel);
            p._col_sample_rate = p._colsample_bylevel;
        } else {
            params.put("colsample_bylevel", p._col_sample_rate);
            p._colsample_bylevel = p._col_sample_rate;
        }
        if (p._max_delta_step != 0.0f) {
            Log.info("Using user-provided parameter max_delta_step instead of max_abs_leafnode_pred.");
            params.put("max_delta_step", p._max_delta_step);
            p._max_abs_leafnode_pred = p._max_delta_step;
        } else {
            params.put("max_delta_step", p._max_abs_leafnode_pred);
            p._max_delta_step = p._max_abs_leafnode_pred;
        }
        // XGBoost takes a 32-bit seed: fold the 64-bit H2O seed into int range.
        params.put("seed", (int) (p._seed % 2147483647L));
        params.put("grow_policy", p._grow_policy.toString());
        if (p._grow_policy == XGBoostParameters.GrowPolicy.lossguide) {
            params.put("max_bin", p._max_bins);
            params.put("max_leaves", p._max_leaves);
            params.put("min_sum_hessian_in_leaf", p._min_sum_hessian_in_leaf);
            params.put("min_data_in_leaf", p._min_data_in_leaf);
        }
        params.put("booster", p._booster.toString());
        if (p._booster == XGBoostParameters.Booster.dart) {
            params.put("sample_type", p._sample_type.toString());
            params.put("normalize_type", p._normalize_type.toString());
            params.put("rate_drop", p._rate_drop);
            params.put("one_drop", p._one_drop ? "1" : "0");
            params.put("skip_drop", p._skip_drop);
        }
        if (getActualBackend(p) == XGBoostParameters.Backend.gpu) {
            params.put("gpu_id", p._gpu_id);
            if (p._booster == XGBoostParameters.Booster.gblinear) {
                Log.info("Using gpu_coord_descent updater.");
                params.put("updater", "gpu_coord_descent");
            } else {
                Log.info("Using gpu_hist tree method.");
                params.put("max_bin", p._max_bins);
                params.put("updater", "grow_gpu_hist");
            }
        } else if (p._booster == XGBoostParameters.Booster.gblinear) {
            Log.info("Using coord_descent updater.");
            params.put("updater", "coord_descent");
        } else if (H2O.CLOUD.size() > 1 && p._tree_method == XGBoostParameters.TreeMethod.auto && p._monotone_constraints != null) {
            // In a multi-node cloud with monotone constraints, "auto" is replaced by "hist".
            Log.info("Using hist tree method for distributed computation with monotone_constraints.");
            params.put("tree_method", XGBoostParameters.TreeMethod.hist.toString());
            params.put("max_bin", p._max_bins);
        } else {
            Log.info("Using " + p._tree_method.toString() + " tree method.");
            params.put("tree_method", p._tree_method.toString());
            if (p._tree_method == XGBoostParameters.TreeMethod.hist) {
                params.put("max_bin", p._max_bins);
            }
        }
        if (p._min_child_weight != 1.0d) {
            Log.info("Using user-provided parameter min_child_weight instead of min_rows.");
            params.put("min_child_weight", p._min_child_weight);
            p._min_rows = p._min_child_weight;
        } else {
            params.put("min_child_weight", p._min_rows);
            p._min_child_weight = p._min_rows;
        }
        if (p._gamma != 0.0f) {
            Log.info("Using user-provided parameter gamma instead of min_split_improvement.");
            params.put("gamma", p._gamma);
            p._min_split_improvement = p._gamma;
        } else {
            params.put("gamma", p._min_split_improvement);
            p._gamma = p._min_split_improvement;
        }
        params.put("lambda", p._reg_lambda);
        params.put("alpha", p._reg_alpha);
        putObjectiveParams(params, p, i);
        // Explicit form of an `assert`; the class declares $assertionsDisabled itself.
        if (!$assertionsDisabled && XGBoostMojoModel.ObjectiveType.fromXGBoost((String) params.get("objective")) == null) {
            throw new AssertionError();
        }
        final int maxNThread = getMaxNThread();
        final int nthread = p._nthread != -1 ? Math.min(p._nthread, maxNThread) : maxNThread;
        if (nthread < p._nthread) {
            Log.warn("Requested nthread=" + p._nthread + " but the cluster has only " + maxNThread
                    + " available. Training will use nthread=" + nthread + " instead of the user specified value.");
        }
        params.put("nthread", nthread);
        final Map<String, Integer> monotoneConstraints = p.monotoneConstraints();
        if (!monotoneConstraints.isEmpty()) {
            params.put("monotone_constraints", formatMonotoneConstraints(monotoneConstraints, strArr));
        }
        Log.info("XGBoost Parameters:");
        for (Map.Entry<String, Object> entry : params.entrySet()) {
            Log.info(" " + entry.getKey() + " = " + entry.getValue());
        }
        Log.info("");
        return BoosterParms.fromMap(Collections.unmodifiableMap(params));
    }

    /**
     * Returns the user's XGBoost verbosity override, or {@code null} when none is set.
     *
     * <p>{@code PROP_VERBOSITY} is spelled with a doubled dot ("sys.ai.h2o..xgboost.verbosity"),
     * unlike "sys.ai.h2o.xgboost.nthreadMax". The correctly spelled name is tried first, and the
     * legacy spelling keeps working for existing deployments.
     */
    private static String xgboostVerbosityOverride() {
        final String verbosity = System.getProperty("sys.ai.h2o.xgboost.verbosity");
        return verbosity != null ? verbosity : System.getProperty(PROP_VERBOSITY);
    }

    /**
     * Chooses the XGBoost objective (and its companion settings) from the class count and, for
     * regression, the requested distribution.
     */
    private static void putObjectiveParams(Map<String, Object> params, XGBoostParameters p, int nclasses) {
        if (nclasses == 2) {
            params.put("objective", XGBoostMojoModel.ObjectiveType.BINARY_LOGISTIC.getId());
        } else if (nclasses != 1) {
            params.put("objective", XGBoostMojoModel.ObjectiveType.MULTI_SOFTPROB.getId());
            params.put("num_class", nclasses);
        } else if (p._distribution == DistributionFamily.gamma) {
            params.put("objective", XGBoostMojoModel.ObjectiveType.REG_GAMMA.getId());
        } else if (p._distribution == DistributionFamily.tweedie) {
            params.put("objective", XGBoostMojoModel.ObjectiveType.REG_TWEEDIE.getId());
            params.put("tweedie_variance_power", Double.valueOf(p._tweedie_power));
        } else if (p._distribution == DistributionFamily.poisson) {
            params.put("objective", XGBoostMojoModel.ObjectiveType.COUNT_POISSON.getId());
        } else {
            if (p._distribution != DistributionFamily.gaussian && p._distribution != DistributionFamily.AUTO) {
                throw new UnsupportedOperationException("No support for distribution=" + p._distribution.toString());
            }
            params.put("objective", XGBoostMojoModel.ObjectiveType.REG_SQUAREDERROR.getId());
        }
    }

    /**
     * Renders the constraints as XGBoost's "(c1,c2,...)" vector, with one entry per feature in
     * {@code featureNames} order; unconstrained features get "0". An empty feature list yields "()".
     */
    private static String formatMonotoneConstraints(Map<String, Integer> constraints, String[] featureNames) {
        final String[] entries = new String[featureNames.length];
        int matched = 0;
        for (int k = 0; k < featureNames.length; k++) {
            final Integer direction = constraints.get(featureNames[k]);
            if (direction != null) {
                entries[k] = direction.toString();
                matched++;
            } else {
                entries[k] = "0";
            }
        }
        // Every constrained feature is expected to be one of the model's features.
        if (!$assertionsDisabled && matched != constraints.size()) {
            throw new AssertionError();
        }
        return "(" + String.join(",", entries) + ")";
    }

    /**
     * Returns a deep copy of this model under a new key.
     *
     * <p>All metrics are dropped from the copy's output, including training and validation metrics.
     *
     * @param key key for the copy
     * @return the copied model
     */
    public XGBoostModel deepClone(Key<XGBoostModel> key) {
        final XGBoostModel copy = (XGBoostModel) IcedUtils.deepCopy(this);
        copy._key = key;
        final XGBoostOutput copiedOutput = (XGBoostOutput) copy._output;
        copiedOutput.clearModelMetrics(false);
        copiedOutput._training_metrics = null;
        copiedOutput._validation_metrics = null;
        return copy;
    }

    /**
     * Upper bound on the number of XGBoost threads per node.
     *
     * <p>If the {@code PROP_NTHREAD} system property holds a valid integer (parsed like
     * {@link Integer#getInteger}, i.e. via {@link Integer#decode}), that value is returned as is.
     * Otherwise the result is {@code H2O.ARGS.nthreads} divided by the largest number of cloud
     * members sharing one IP, so co-located nodes do not oversubscribe the host. The result is at
     * least 1.
     *
     * @return maximum thread count to hand to XGBoost
     */
    static int getMaxNThread() {
        final String configured = System.getProperty(PROP_NTHREAD);
        if (configured != null) {
            try {
                return Integer.decode(configured);
            } catch (NumberFormatException e) {
                // Previously this path threw an NPE from unboxing a null Integer; fall back instead.
                Log.warn("Ignoring system property " + PROP_NTHREAD + "='" + configured + "': not an integer.");
            }
        }
        // Single pass: count members per IP and keep the running maximum.
        final Map<String, Integer> membersPerHost = new HashMap<>();
        int maxMembersPerHost = 1;
        for (H2ONode node : H2O.CLOUD.members()) {
            final int count = membersPerHost.merge(node.getIp(), 1, Integer::sum);
            maxMembersPerHost = Math.max(maxMembersPerHost, count);
        }
        return Math.max(1, H2O.ARGS.nthreads / maxMembersPerHost);
    }

    /**
     * Passes the DataInfo key held by the model info to {@code putKey} before delegating to the
     * superclass, so the DataInfo travels with the model.
     */
    @Override // hex.Model, water.Keyed
    public AutoBuffer writeAll_impl(AutoBuffer autoBuffer) {
        final XGBoostModelInfo info = model_info();
        autoBuffer.putKey(info.getDataInfoKey());
        return super.writeAll_impl(autoBuffer);
    }

    /**
     * Mirror of {@code writeAll_impl}: passes the model info's DataInfo key to {@code getKey}
     * (with the given futures) before delegating to the superclass.
     */
    @Override // hex.Model, water.Keyed
    public Keyed readAll_impl(AutoBuffer autoBuffer, Futures futures) {
        final XGBoostModelInfo info = model_info();
        autoBuffer.getKey(info.getDataInfoKey(), futures);
        return super.readAll_impl(autoBuffer, futures);
    }

    /** @return a new MOJO writer bound to this model */
    @Override // hex.Model
    public XGBoostMojoWriter getMojo() {
        final XGBoostMojoWriter writer = new XGBoostMojoWriter(this);
        return writer;
    }

    /**
     * Computes model metrics for one frame via {@link XGBoostModelMetrics}.
     *
     * @param frame frame to score
     * @param frame2 second frame forwarded to {@code XGBoostModelMetrics} (meaning not visible here)
     * @param z flag forwarded to {@code XGBoostModelMetrics}; callers pass {@code true} for training data
     * @param str description, only used in the debug log line
     */
    private ModelMetrics makeMetrics(Frame frame, Frame frame2, boolean z, String str) {
        Log.debug("Making metrics: " + str);
        final XGBoostModelMetrics metrics = new XGBoostModelMetrics((XGBoostOutput) this._output, frame, frame2, z, this);
        return metrics.compute();
    }

    /**
     * Scores the training frame (and the validation frame, when present).
     *
     * <p>For each frame it:
     * <ul>
     *   <li>stores the metrics in the output;</li>
     *   <li>fills the scoring-history slot at index {@code _ntrees};</li>
     *   <li>registers the metrics with the model.</li>
     * </ul>
     *
     * @param frame training frame
     * @param frame2 companion of the training frame, forwarded to metric computation
     * @param frame3 validation frame, or {@code null} to skip validation scoring
     * @param frame4 companion of the validation frame
     */
    public final void doScoring(Frame frame, Frame frame2, Frame frame3, Frame frame4) {
        final XGBoostOutput out = (XGBoostOutput) this._output;

        final ModelMetrics trainMetrics = makeMetrics(frame, frame2, true, "Metrics reported on training frame");
        out._training_metrics = trainMetrics;
        out._scored_train[out._ntrees].fillFrom(trainMetrics);
        addModelMetrics(trainMetrics);

        if (frame3 == null) {
            return;
        }
        final ModelMetrics validMetrics = makeMetrics(frame3, frame4, false, "Metrics reported on validation frame");
        out._validation_metrics = validMetrics;
        out._scored_valid[out._ntrees].fillFrom(validMetrics);
        addModelMetrics(validMetrics);
    }

    /** Applies Platt-scaling post-processing (via {@link PlattScalingHelper}) to the prediction frame. */
    @Override // hex.Model
    protected Frame postProcessPredictions(Frame frame, Frame frame2, Job job) {
        final PlattScalingHelper.OutputWithCalibration calibrationOutput =
                (PlattScalingHelper.OutputWithCalibration) this._output;
        return PlattScalingHelper.postProcessPredictions(frame2, job, calibrationOutput);
    }

    /** Scores one row with a zero offset; see the three-argument overload. */
    @Override // hex.Model
    protected double[] score0(double[] dArr, double[] dArr2) {
        final double noOffset = 0.0d;
        return score0(dArr, dArr2, noOffset);
    }

    /**
     * Scores one row with the Java predictor built from the serialized booster.
     *
     * @param dArr input row
     * @param dArr2 prediction buffer, filled by {@code XGBoostMojoModel.toPreds}
     * @param d offset; must be 0 unless the model was trained with an offset
     * @return the filled prediction buffer as returned by {@code toPreds}
     * @throws UnsupportedOperationException if a non-zero offset is given to a model without offset
     */
    @Override // hex.Model
    public double[] score0(double[] dArr, double[] dArr2, double d) {
        final DataInfo dinfo = this.model_info.dataInfo();
        if (!$assertionsDisabled && dinfo == null) {
            throw new AssertionError();
        }
        final XGBoostOutput out = (XGBoostOutput) this._output;
        final MutableOneHotEncoderFVec encodedRow = new MutableOneHotEncoderFVec(dinfo, out._sparse);
        encodedRow.setInput(dArr);
        // NOTE: the predictor is rebuilt from the booster bytes on every call.
        final Predictor predictor = PredictorFactory.makePredictor(this.model_info._boosterBytes);
        final float[] rawPreds;
        if (!out.hasOffset()) {
            if (d != 0.0d) {
                throw new UnsupportedOperationException("Unsupported: offset != 0");
            }
            rawPreds = predictor.predict(encodedRow);
        } else {
            rawPreds = predictor.predict(encodedRow, (float) d);
        }
        return XGBoostMojoModel.toPreds(dArr, rawPreds, dArr2, out.nclasses(), out._priorClassDist, defaultThreshold());
    }

    /** Big-score hook: always delegates to {@code setupBigScorePredict(false)}; the BigScore is not inspected. */
    @Override // hex.Model
    protected XGBoostBigScorePredict setupBigScorePredict(Model<XGBoostModel, XGBoostParameters, XGBoostOutput>.BigScore bigScore) {
        final boolean flag = false;
        return setupBigScorePredict(flag);
    }

    /**
     * Creates a batch predictor: the Java implementation when {@code PredictConfiguration.useJavaScoring()}
     * is true, otherwise the native one.
     *
     * @param z flag forwarded to {@code model_info().scoringInfo}
     */
    public XGBoostBigScorePredict setupBigScorePredict(boolean z) {
        final DataInfo scoringInfo = model_info().scoringInfo(z);
        if (PredictConfiguration.useJavaScoring()) {
            return setupBigScorePredictJava(scoringInfo);
        }
        return setupBigScorePredictNative(scoringInfo);
    }

    /**
     * Builds the native batch predictor.
     *
     * <p>Note that this calls {@code createParams}, which also logs and re-synchronizes the alias
     * fields of the model parameters.
     */
    private XGBoostBigScorePredict setupBigScorePredictNative(DataInfo dataInfo) {
        final XGBoostParameters parms = (XGBoostParameters) this._parms;
        final XGBoostOutput out = (XGBoostOutput) this._output;
        final BoosterParms boosterParms = createParams(parms, out.nclasses(), dataInfo.coefNames());
        return new XGBoostNativeBigScorePredict(this.model_info, parms, out, dataInfo, boosterParms, defaultThreshold());
    }

    /** Builds the pure-Java batch predictor for the given scoring layout. */
    private XGBoostBigScorePredict setupBigScorePredictJava(DataInfo dataInfo) {
        final XGBoostOutput out = (XGBoostOutput) this._output;
        final XGBoostParameters parms = (XGBoostParameters) this._parms;
        return new XGBoostJavaBigScorePredict(this.model_info, out, dataInfo, parms, defaultThreshold());
    }

    /** Computes SHAP contributions without progress reporting (no job). */
    @Override // hex.Model.Contributions
    public Frame scoreContributions(Frame frame, Key<Frame> key) {
        final Job<Frame> noJob = null;
        return scoreContributions(frame, key, noJob);
    }

    /**
     * Computes per-feature SHAP contributions plus a trailing "BiasTerm" column.
     *
     * @param frame input frame; a shallow copy is adapted to the training layout
     * @param key key of the output frame
     * @param job optional job for progress updates (may be {@code null})
     * @return a frame with one numeric column per coefficient name plus "BiasTerm"
     */
    @Override // hex.Model.Contributions
    public Frame scoreContributions(Frame frame, Key<Frame> key, Job<Frame> job) {
        final Frame adapted = new Frame(frame);
        adaptTestForTrain(adapted, true, false);
        final DataInfo dinfo = model_info().dataInfo();
        if (!$assertionsDisabled && dinfo == null) {
            throw new AssertionError();
        }
        final String[] outputNames = (String[]) ArrayUtils.append((Object[]) dinfo.coefNames(), (Object[]) new String[]{"BiasTerm"});
        final PredictTreeSHAPTask shapTask = new PredictTreeSHAPTask(dinfo, model_info(), (XGBoostOutput) this._output);
        return shapTask
                .withPostMapAction(JobUpdatePostMap.forJob(job))
                .doAll(outputNames.length, (byte) 3, adapted)
                .outputFrame(key, outputNames, (String[][]) null);
    }

    /**
     * Assigns each row to the leaf it lands in, for every tree.
     *
     * @param frame input frame; a shallow copy is adapted to the training layout
     * @param leafNodeAssignmentType output representation, forwarded to {@link AssignLeafNodeTask}
     * @param key key of the output frame
     */
    @Override // hex.Model.LeafNodeAssignment
    public Frame scoreLeafNodeAssignment(Frame frame, Model.LeafNodeAssignment.LeafNodeAssignmentType leafNodeAssignmentType, Key<Frame> key) {
        final AssignLeafNodeTask task = AssignLeafNodeTask.make(
                this.model_info.scoringInfo(false),
                (XGBoostOutput) this._output,
                this.model_info._boosterBytes,
                leafNodeAssignmentType);
        final Frame adapted = new Frame(frame);
        adaptTestForTrain(adapted, true, false);
        return task.execute(adapted, key);
    }

    /**
     * Copies the column layout of the DataInfo into the model output.
     *
     * <p>Copied items:
     * <ul>
     *   <li>names, types and domains of the adapted frame;</li>
     *   <li>numeric and categorical counts;</li>
     *   <li>categorical offsets;</li>
     *   <li>the all-factor-levels flag.</li>
     * </ul>
     */
    private void setDataInfoToOutput(DataInfo dataInfo) {
        final XGBoostOutput out = (XGBoostOutput) this._output;
        out.setNames(dataInfo._adaptedFrame.names(), dataInfo._adaptedFrame.typesStr());
        out._domains = dataInfo._adaptedFrame.domains();
        out._nums = dataInfo._nums;
        out._cats = dataInfo._cats;
        out._catOffsets = dataInfo._catOffsets;
        out._useAllFactorLevels = dataInfo._useAllFactorLevels;
    }

    /**
     * Removes the DataInfo and the calibration model (each only if present), then the model itself.
     */
    @Override // hex.Model, water.Keyed
    public Futures remove_impl(Futures futures, boolean z) {
        final DataInfo dinfo = model_info().dataInfo();
        if (dinfo != null) {
            dinfo.remove(futures);
        }
        final XGBoostOutput out = (XGBoostOutput) this._output;
        if (out._calib_model != null) {
            out._calib_model.remove(futures);
        }
        return super.remove_impl(futures, z);
    }

    /**
     * Converts one XGBoost tree into a {@link SharedTreeGraph} with a single subgraph.
     *
     * @param i zero-based tree number within the selected class
     * @param str class name, resolved to an XGBoost class index by {@code getXGBoostClassIndex}
     * @return a graph containing the converted tree
     * @throws IllegalArgumentException if the booster is not tree-based, the class does not exist,
     *     or the tree number is out of range
     */
    @Override // hex.genmodel.algos.tree.SharedTreeGraphConverter
    public SharedTreeGraph convert(int i, String str) {
        GradBooster booster = XGBoostJavaMojoModel.makePredictor(this.model_info._boosterBytes).getBooster();
        if (!(booster instanceof GBTree)) {
            throw new IllegalArgumentException("XGBoost model is not backed by a tree-based booster. Booster class is " + booster.getClass().getCanonicalName());
        }
        RegTree[][] groupedTrees = ((GBTree) booster).getGroupedTrees();
        int classIndex = getXGBoostClassIndex(str);
        // Also reject negative indices, which would otherwise surface as ArrayIndexOutOfBoundsException.
        if (classIndex < 0 || classIndex >= groupedTrees.length) {
            throw new IllegalArgumentException(String.format("Given XGBoost model does not have given class '%s'.", str));
        }
        RegTree[] classTrees = groupedTrees[classIndex];
        if (i >= classTrees.length || i < 0) {
            throw new IllegalArgumentException(String.format("There is no such tree number for given class. Total number of trees is %d.", Integer.valueOf(classTrees.length)));
        }
        RegTreeNode[] nodes = classTrees[i].getNodes();
        if (!$assertionsDisabled && nodes.length < 1) {
            throw new AssertionError();
        }
        // Training metrics can be null (deepClone clears them); fall back to a generic subgraph name.
        final ModelMetrics trainingMetrics = ((XGBoostOutput) this._output)._training_metrics;
        final String subgraphName = trainingMetrics != null ? trainingMetrics._description : "Tree " + i;
        SharedTreeGraph sharedTreeGraph = new SharedTreeGraph();
        SharedTreeSubgraph subgraph = sharedTreeGraph.makeSubgraph(subgraphName);
        constructSubgraph(nodes, subgraph.makeRootNode(), 0, subgraph, XGBoostUtils.assembleFeatureNames(this.model_info.dataInfo()), true);
        return sharedTreeGraph;
    }

    /**
     * Recursively copies XGBoost node {@code nodeIndex} and its descendants into {@code target}.
     *
     * @param nodes       the flat node array of one XGBoost tree
     * @param target      the shared-tree node receiving this node's data
     * @param nodeIndex   index of the node to copy within {@code nodes}
     * @param subgraph    subgraph used to allocate child nodes
     * @param features    feature names and one-hot flags, indexed by split index
     * @param inclusiveNa whether missing values are routed to this node
     */
    private static void constructSubgraph(RegTreeNode[] nodes, SharedTreeNode target, int nodeIndex, SharedTreeSubgraph subgraph, XGBoostUtils.FeatureProperties features, boolean inclusiveNa) {
        final RegTreeNode node = nodes[nodeIndex];
        final int featureIdx = node.getSplitIndex();

        // One-hot encoded features get a fixed split value of 1.0 instead of XGBoost's own
        // condition (presumably so the 0/1 indicator splits between its two values — confirm
        // against SharedTreeNode's split convention).
        target.setSplitValue(features._oneHotEncoded[featureIdx] ? 1.0f : node.getSplitCondition());
        target.setPredValue(node.getLeafValue());
        target.setInclusiveNa(inclusiveNa);
        target.setNodeNumber(nodeIndex);

        if (!node.isLeaf()) {
            target.setCol(featureIdx, features._names[featureIdx]);
            // Missing values follow XGBoost's default direction: exactly one child is NA-inclusive.
            final boolean naGoesLeft = node.default_left();
            constructSubgraph(nodes, subgraph.makeLeftChildNode(target), node.getLeftChildIndex(), subgraph, features, naGoesLeft);
            constructSubgraph(nodes, subgraph.makeRightChildNode(target), node.getRightChildIndex(), subgraph, features, !naGoesLeft);
        }
    }

    /**
     * Converts one XGBoost tree into a {@link SharedTreeGraph}.
     * The {@code convertTreeOptions} argument is currently ignored; this forwards to {@code convert(int, String)}.
     *
     * @param i                  zero-based tree number within the selected class
     * @param str                tree class name
     * @param convertTreeOptions unused
     * @return the converted tree graph
     */
    @Override // hex.genmodel.algos.tree.SharedTreeGraphConverter
    public SharedTreeGraph convert(int i, String str, ConvertTreeOptions convertTreeOptions) {
        final SharedTreeGraph graph = this.convert(i, str);
        return graph;
    }

    /**
     * Resolves a user-supplied tree class name to the index of XGBoost's grouped trees.
     *
     * If no class is given, the result is 0 for binomial and regression models. Otherwise the name
     * is looked up in the last domain of the output, which appears to be the response domain.
     *
     * @param str class name; null or empty means "not specified"
     * @return the class index
     * @throws IllegalArgumentException if a class is given for regression, if none is given for
     *                                  another category, if a binomial class other than the first
     *                                  is requested, or if the class is unknown
     */
    private int getXGBoostClassIndex(String str) {
        final XGBoostOutput output = (XGBoostOutput) this._output;
        final ModelCategory category = output.getModelCategory();
        final boolean classGiven = str != null && !str.isEmpty();

        if (classGiven) {
            if (ModelCategory.Regression.equals(category)) {
                throw new IllegalArgumentException("There should be no tree class specified for regression.");
            }
        } else {
            switch (category) {
                case Binomial:
                case Regression:
                    return 0;
                default:
                    throw new IllegalArgumentException(String.format("Model category '%s' requires tree class to be specified.", category));
            }
        }

        final String[][] domains = output._domains;
        final String[] responseDomain = domains[domains.length - 1];
        final int classIdx = ArrayUtils.find(responseDomain, str);
        // Binomial models only expose the tree for the first domain level.
        if (ModelCategory.Binomial.equals(category) && classIdx != 0) {
            throw new IllegalArgumentException(String.format("For binomial XGBoost model, only one tree for class %s has been built.", responseDomain[0]));
        }
        if (classIdx < 0) {
            throw new IllegalArgumentException(String.format("No such class '%s' in tree.", str));
        }
        return classIdx;
    }

    /**
     * Tells whether the named feature has non-zero variable importance.
     *
     * A name found directly in the variable-importance table is checked as-is. Otherwise, when
     * categorical offsets exist, an {@code Enum} input column counts as used if any
     * variable-importance entry starting with {@code "<name>."} is non-zero.
     *
     * @param str input feature name
     * @return true when the feature contributes to predictions
     */
    @Override // hex.Model
    public boolean isFeatureUsedInPredict(String str) {
        final XGBoostOutput output = (XGBoostOutput) this._output;
        final int varimpIdx = ArrayUtils.find(output._varimp._names, str);
        final boolean maybeExpandedCategorical = varimpIdx == -1 && output._catOffsets.length > 1;
        if (!maybeExpandedCategorical) {
            return varimpIdx != -1 && output._varimp._varimp[varimpIdx] != 0.0f;
        }

        final int columnIdx = ArrayUtils.find(output._names, str);
        if (columnIdx == -1 || !output._column_types[columnIdx].equals("Enum")) {
            return false;
        }
        for (int k = 0; k < output._varimp._names.length; k++) {
            final boolean levelOfFeature = output._varimp._names[k].startsWith(str.concat("."));
            if (levelOfFeature && output._varimp._varimp[k] != 0.0f) {
                return true;
            }
        }
        return false;
    }

    /**
     * Decides whether generated POJO code for this model would be too large.
     *
     * @return {@code true} when the model has no output yet, or when the number of trees
     *         multiplied by the configured maximum depth exceeds 1000
     */
    @Override // hex.Model
    protected boolean toJavaCheckTooBig() {
        // The decompiled original compared the reference with the int literal 0, which is not
        // valid Java; a missing output is represented by null.
        if (this._output == null) {
            return true;
        }
        // Multiply in long arithmetic so very large ntrees/max_depth values cannot overflow
        // into a negative (and therefore "small") estimate.
        // NOTE(review): if a max_depth of 0 is allowed to mean "no limit", this product is 0 and
        // such models are never considered too big — confirm that is intended.
        final long sizeEstimate = (long) ((XGBoostOutput) this._output)._ntrees * ((XGBoostParameters) this._parms)._max_depth;
        return sizeEstimate > 1000;
    }

    /**
     * Emits the POJO's model-level declarations: {@code isSupervised()} (always true) and
     * {@code nclasses()} (taken from this model's output).
     *
     * NOTE(review): the decompiler reported this method's access was changed from protected.
     *
     * @param sBPrintStream          stream receiving the generated Java source
     * @param codeGeneratorPipeline  unused here
     * @return the same stream, after writing
     */
    @Override // hex.Model
    public SBPrintStream toJavaInit(SBPrintStream sBPrintStream, CodeGeneratorPipeline codeGeneratorPipeline) {
        final SBPrintStream out = sBPrintStream;
        out.nl();
        out.ip("public boolean isSupervised() { return true; }").nl();
        out.ip("public int nclasses() { return ")
                .p(((XGBoostOutput) this._output).nclasses())
                .p("; }")
                .nl();
        return out;
    }

    /**
     * Renders the POJO prediction body with an {@code XGBoostPojoWriter} built from this model's
     * booster bytes, Java-safe key name, output and default threshold.
     *
     * @param sBPrintStream          stream receiving the generated predict body
     * @param codeGeneratorPipeline  unused here
     * @param codeGeneratorPipeline2 file-level generator pipeline passed to the writer
     * @param z                      unused here
     */
    @Override // hex.Model
    protected void toJavaPredictBody(SBPrintStream sBPrintStream, CodeGeneratorPipeline codeGeneratorPipeline, CodeGeneratorPipeline codeGeneratorPipeline2, boolean z) {
        final Predictor predictor = PredictorFactory.makePredictor(this.model_info._boosterBytes, false);
        final String javaName = JCodeGen.toJavaId(this._key.toString());
        XGBoostPojoWriter
                .make(predictor, javaName, (XGBoostOutput) this._output, defaultThreshold())
                .renderJavaPredictBody(sBPrintStream, codeGeneratorPipeline2);
    }

    /**
     * Raw-typed overload that casts {@code bigScore} and forwards to the {@code BigScore}-typed
     * {@code setupBigScorePredict} overload of this class.
     *
     * NOTE(review): the bridge/synthetic markers indicate this is a compiler-generated bridge
     * recovered by the decompiler. javac emits such bridges automatically, so keeping it in
     * source likely clashes with the real override (same erasure) — confirm and remove before
     * recompiling.
     */
    @Override // hex.Model
    protected /* bridge */ /* synthetic */ Model.BigScorePredict setupBigScorePredict(Model.BigScore bigScore) {
        return setupBigScorePredict((Model<XGBoostModel, XGBoostParameters, XGBoostOutput>.BigScore) bigScore);
    }

    static {
        $assertionsDisabled = !XGBoostModel.class.desiredAssertionStatus();
    }
}
