package ml.dmlc.xgboost4j.java;

import hex.tree.xgboost.XGBoostExtension;
import hex.tree.xgboost.XGBoostModel;
import hex.tree.xgboost.XGBoostOutput;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import ml.dmlc.xgboost4j.java.PersistentDMatrix;
import water.ExtensionManager;
import water.H2O;
import water.MRTask;
import water.util.IcedHashMapGeneric;
import water.util.Log;

/* loaded from: input_file:ml/dmlc/xgboost4j/java/XGBoostUpdateTask.class */
/**
 * Distributed task that runs one XGBoost update on each H2O node that has a training matrix registered
 * in {@code _nodeToMatrixWrapper} (or creates a fresh booster when none is supplied).
 *
 * <p>Each participating node joins a Rabit session configured from {@code rabitEnv}, trains on its local
 * {@link DMatrix}, and stores the serialized result in {@code _rawBooster}. {@link #reduce} keeps a booster
 * that was actually trained. Call {@link #getBooster()} on the finished task to obtain the result.
 */
public class XGBoostUpdateTask extends MRTask<XGBoostUpdateTask> {
    /** Per-node matrix wrappers taken from the setup task, keyed by {@code H2O.SELF.toString()}. */
    private final IcedHashMapGeneric.IcedHashMapStringObject _nodeToMatrixWrapper;
    private final XGBoostOutput _output;
    /** Booster materialized on this node; not serialized, rebuilt lazily from {@code _rawBooster}. */
    private transient Booster _booster;
    /**
     * Serialized booster. On input: the booster to continue from (null means "create a new one").
     * After the task: the booster produced by a node that trained (see {@code _updated}).
     */
    private byte[] _rawBooster;
    /**
     * True once this copy of the task has trained and stored a fresh booster in {@code _rawBooster}.
     * Copies that skipped training keep the input bytes from the constructor, which must not win in
     * {@link #reduce} over a freshly trained booster.
     */
    private boolean _updated;
    private final XGBoostModel.XGBoostParameters _parms;
    /** Iteration index passed to {@link Booster#update(DMatrix, int)}. */
    private final int _tid;
    /** Rabit environment; {@code DMLC_TASK_ID} is added per node right before {@code Rabit.init}. */
    private IcedHashMapGeneric.IcedHashMapStringString rabitEnv = new IcedHashMapGeneric.IcedHashMapStringString();

    /**
     * @param xGBoostSetupTask  finished setup task providing the per-node matrix wrappers
     * @param booster           booster to continue training from; serialized via
     *                          {@code hex.tree.xgboost.XGBoost.getRawArray}
     * @param xGBoostOutput     model output used to build the XGBoost parameters
     * @param xGBoostParameters model parameters used to build the XGBoost parameters
     * @param i                 iteration index passed to {@code Booster.update}
     * @param map               Rabit environment entries copied into {@code rabitEnv}
     */
    public XGBoostUpdateTask(XGBoostSetupTask xGBoostSetupTask, Booster booster, XGBoostOutput xGBoostOutput, XGBoostModel.XGBoostParameters xGBoostParameters, int i, Map<String, String> map) {
        this._nodeToMatrixWrapper = xGBoostSetupTask._nodeToMatrixWrapper;
        this._output = xGBoostOutput;
        this._parms = xGBoostParameters;
        this._tid = i;
        this._rawBooster = hex.tree.xgboost.XGBoost.getRawArray(booster);
        this.rabitEnv.putAll(map);
    }

    /**
     * Runs the training step on this node if it has a matrix registered for it.
     *
     * <p>Client nodes and nodes without a matrix do nothing and leave {@code _updated} false.
     * (The decompiler notes this override was originally {@code protected}.)
     *
     * @throws IllegalStateException if the XGBoost extension is not enabled on this node, or training fails
     */
    @Override // water.MRTask
    public void setupLocal() {
        if (H2O.ARGS.client) {
            return;
        }
        // Checked before looking up the matrix, so this applies to every non-client node.
        if (!ExtensionManager.getInstance().isCoreExtensionEnabled(XGBoostExtension.NAME)) {
            throw new IllegalStateException("XGBoost is not available on the node " + H2O.SELF);
        }
        try {
            PersistentDMatrix.Wrapper wrapper = (PersistentDMatrix.Wrapper) this._nodeToMatrixWrapper.get(H2O.SELF.toString());
            if (wrapper == null) {
                return;
            }
            update(wrapper.get());
        } catch (XGBoostError e) {
            // No Rabit.shutdown() here: update() always shuts Rabit down in its finally block, and
            // calling it again would finalize twice (or finalize a session that was never initialized).
            Log.err("Failed XGBoost training on node " + H2O.SELF, e);
            throw new IllegalStateException("Failed XGBoost training.", e);
        }
    }

    /**
     * Performs the training step on {@code dMatrix} inside a Rabit session.
     *
     * <p>With no input booster, a new one is created via {@code XGBoost.train} with 0 rounds. Otherwise the
     * booster is deserialized from {@code _rawBooster}, given the current parameters, and updated for
     * iteration {@code _tid}. The result is serialized back into {@code _rawBooster} and {@code _updated}
     * is set. Rabit is shut down on every path; a failing shutdown is only logged.
     *
     * @throws XGBoostError if Rabit initialization, training or booster (de)serialization fails
     * @throws IllegalStateException if the input booster bytes cannot be read
     */
    private void update(DMatrix dMatrix) throws XGBoostError {
        HashMap<String, Object> params = XGBoostModel.createParams(this._parms, this._output);
        this.rabitEnv.put("DMLC_TASK_ID", String.valueOf(H2O.SELF.index()));
        try {
            Rabit.init(this.rabitEnv);
            if (this._rawBooster == null) {
                this._booster = XGBoost.train(dMatrix, params, 0, new HashMap<String, DMatrix>(), null, null);
            } else {
                this._booster = loadBooster(this._rawBooster);
                this._booster.setParams(params);
                this._booster.update(dMatrix, this._tid);
            }
            this._rawBooster = this._booster.toByteArray();
            this._updated = true;
        } finally {
            try {
                Rabit.shutdown();
            } catch (XGBoostError e) {
                Log.debug("Rabit shutdown during update failed", e);
            }
        }
    }

    /**
     * Merges results, preferring a booster that was actually trained over the input booster that
     * non-participating copies still carry. If neither side trained, a null booster is filled from the
     * other side.
     */
    @Override // water.MRTask
    public void reduce(XGBoostUpdateTask xGBoostUpdateTask) {
        if (!this._updated && xGBoostUpdateTask._updated) {
            this._rawBooster = xGBoostUpdateTask._rawBooster;
            this._updated = true;
        } else if (null == this._rawBooster) {
            this._rawBooster = xGBoostUpdateTask._rawBooster;
        }
    }

    /**
     * Returns the booster held by this task, deserializing it from {@code _rawBooster} on first use.
     *
     * @throws IllegalStateException if no booster bytes are available or they cannot be loaded
     */
    public Booster getBooster() {
        if (null == this._booster) {
            if (null == this._rawBooster) {
                throw new IllegalStateException("Failed to load the booster: no booster was trained or supplied.");
            }
            try {
                this._booster = loadBooster(this._rawBooster);
            } catch (XGBoostError e) {
                throw new IllegalStateException("Failed to load the booster.", e);
            }
        }
        return this._booster;
    }

    /**
     * Deserializes a booster from {@code raw}.
     *
     * @throws XGBoostError if XGBoost rejects the model bytes
     * @throws IllegalStateException wrapping an {@link IOException} while reading the bytes
     */
    private static Booster loadBooster(byte[] raw) throws XGBoostError {
        try {
            Booster booster = Booster.loadModel(new ByteArrayInputStream(raw));
            Log.debug("Booster created from bytes, raw size = " + raw.length);
            return booster;
        } catch (IOException e) {
            throw new IllegalStateException("Failed to load the booster.", e);
        }
    }
}
