package ml.dmlc.xgboost4j.java;

import hex.tree.xgboost.BoosterParms;
import hex.tree.xgboost.XGBoostModel;
import hex.tree.xgboost.XGBoostOutput;
import hex.tree.xgboost.XGBoostUtils;
import java.io.File;
import java.util.Map;
import water.H2O;
import water.fvec.Frame;
import water.fvec.Vec;
import water.util.IcedHashMapGeneric;
import water.util.Log;

/* loaded from: input_file:ml/dmlc/xgboost4j/java/XGBoostSetupTask.class */
/**
 * Per-node setup step for distributed XGBoost training.
 *
 * <p>On each participating node this task converts the node-local portion of the training frame
 * into a {@link DMatrix}. It optionally exports that matrix to {@code _save_matrix_directory},
 * records the node index as {@code DMLC_TASK_ID} in the Rabit environment, and starts an
 * {@link XGBoostUpdater} for the model key.
 */
public class XGBoostSetupTask extends AbstractXGBoostTask<XGBoostSetupTask> {
    private final XGBoostModelInfo _sharedModel;
    private final XGBoostModel.XGBoostParameters _parms;
    private final boolean _sparse;
    private final BoosterParms _boosterParms;
    private final byte[] _checkpoint;
    private final IcedHashMapGeneric.IcedHashMapStringString _rabitEnv;
    private final Frame _trainFrame;

    /**
     * Pairs a training frame with the set of cluster nodes that are home to at least one of its
     * chunks.
     */
    public static class FrameNodes {
        final Frame _fr;
        /** {@code _nodes[i]} is {@code true} when cluster node {@code i} hosts a chunk of {@code _fr}. */
        final boolean[] _nodes;
        /** Number of {@code true} entries in {@code _nodes}. */
        final int _numNodes;

        private FrameNodes(Frame fr, boolean[] nodes) {
            this._fr = fr;
            this._nodes = nodes;
            int participating = 0;
            for (boolean hostsChunk : nodes) {
                if (hostsChunk) {
                    participating++;
                }
            }
            this._numNodes = participating;
        }

        /** Returns how many cluster nodes host at least one chunk of the frame. */
        public int getNumNodes() {
            return _numNodes;
        }
    }

    /**
     * Creates the setup task.
     *
     * @param model        model whose key, model info and output sparsity flag are used
     * @param parms        XGBoost parameters (response/weights/offset columns, matrix export dir)
     * @param boosterParms booster parameters forwarded to the {@link XGBoostUpdater}
     * @param checkpoint   checkpoint bytes forwarded to the {@link XGBoostUpdater}; may be
     *                     {@code null} — TODO confirm against XGBoostUpdater
     * @param rabitEnv     Rabit environment entries; copied, so later changes by the caller are
     *                     not seen by this task
     * @param frameNodes   training frame and the nodes that should participate
     */
    public XGBoostSetupTask(XGBoostModel model, XGBoostModel.XGBoostParameters parms, BoosterParms boosterParms,
                            byte[] checkpoint, Map<String, String> rabitEnv, FrameNodes frameNodes) {
        super(model._key, frameNodes._nodes);
        this._sharedModel = model.model_info();
        this._parms = parms;
        this._sparse = ((XGBoostOutput) model._output)._sparse;
        this._boosterParms = boosterParms;
        this._checkpoint = checkpoint;
        this._rabitEnv = new IcedHashMapGeneric.IcedHashMapStringString();
        this._rabitEnv.putAll(rabitEnv);
        this._trainFrame = frameNodes._fr;
    }

    /**
     * Builds the local DMatrix, optionally exports it, and starts the XGBoost updater on this node.
     *
     * @throws IllegalStateException if no local DMatrix could be built, or if an
     *                               {@link XGBoostError} occurs (wrapped as the cause)
     */
    @Override
    protected void execute() {
        try {
            DMatrix localMatrix = makeLocalMatrix();
            // Validate before any use. Previously the export below ran first, turning a missing
            // matrix into an uninformative NullPointerException (after creating the export dir).
            if (localMatrix == null) {
                throw new IllegalStateException("Node " + H2O.SELF + " is supposed to participate in XGB training but it doesn't have a DMatrix!");
            }
            if (_parms._save_matrix_directory != null) {
                File exportDir = new File(_parms._save_matrix_directory);
                if (exportDir.mkdirs()) {
                    Log.debug("Created directory for matrix export: " + exportDir.getAbsolutePath());
                }
                // One file per node, suffixed with this node's cluster index.
                File exportFile = new File(exportDir, "matrix.part" + H2O.SELF.index());
                Log.info("Saving node-local portion of XGBoost training dataset to " + exportFile.getAbsolutePath() + ".");
                localMatrix.saveBinary(exportFile.getAbsolutePath());
            }
            _rabitEnv.put("DMLC_TASK_ID", String.valueOf(H2O.SELF.index()));
            XGBoostUpdater.make(_modelKey, localMatrix, _boosterParms, _checkpoint, _rabitEnv).start();
        } catch (XGBoostError e) {
            throw new IllegalStateException("Failed XGBoost training.", e);
        }
    }

    /**
     * Converts the training frame into a DMatrix using the shared model's data info and the
     * response/weights/offset columns from the parameters.
     *
     * <p>NOTE(review): the caller treats a {@code null} result as "this node has no data";
     * presumably {@code XGBoostUtils.convertFrameToDMatrix} returns {@code null} in that case —
     * confirm.
     */
    private DMatrix makeLocalMatrix() throws XGBoostError {
        return XGBoostUtils.convertFrameToDMatrix(
                _sharedModel.dataInfo(), _trainFrame,
                _parms._response_column, _parms._weights_column, _parms._offset_column,
                _sparse);
    }

    /**
     * Determines which cluster nodes are home to at least one chunk of {@code frame}.
     *
     * <p>The chunk layout is read from {@code frame.anyVec()}. NOTE(review): this assumes all
     * vecs of the frame share the same layout and that the frame has at least one vec.
     *
     * @param frame training frame
     * @return the frame together with a per-node participation mask
     */
    public static FrameNodes findFrameNodes(Frame frame) {
        boolean[] nodes = new boolean[H2O.CLOUD.size()];
        Vec vec = frame.anyVec();
        int nChunks = vec.nChunks();
        for (int chunk = 0; chunk < nChunks; chunk++) {
            nodes[vec.chunkKey(chunk).home_node().index()] = true;
        }
        return new FrameNodes(frame, nodes);
    }
}
