package hex.tree.xgboost.remote;

import hex.genmodel.utils.IOUtils;
import hex.schemas.XGBoostExecReqV3;
import hex.schemas.XGBoostExecRespV3;
import hex.tree.xgboost.exec.LocalXGBoostExecutor;
import hex.tree.xgboost.exec.XGBoostExecReq;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.UncheckedIOException;
import org.apache.log4j.Logger;
import water.BootstrapFreezable;
import water.H2O;
import water.H2ONode;
import water.Iced;
import water.TypeMap;
import water.api.Handler;
import water.api.StreamingSchema;

/* loaded from: input_file:hex/tree/xgboost/remote/RemoteXGBoostHandler.class */
public class RemoteXGBoostHandler extends Handler {
    private static final Logger LOG = Logger.getLogger(RemoteXGBoostHandler.class);

    /**
     * Payload returned by {@link RemoteXGBoostHandler#init}: the address string of every cloud
     * member (as produced by {@code H2ONode#getIpPortString()}) plus this node's bootstrap
     * type-map classes.
     */
    public static class RemoteExecutors extends Iced<RemoteExecutors> implements BootstrapFreezable<RemoteExecutors> {
        /** One entry per cloud member, taken from {@code getIpPortString()}. */
        public final String[] _nodes;
        /** Result of {@link TypeMap#bootstrapClasses()} evaluated when this object is created. */
        public final String[] _typeMap = TypeMap.bootstrapClasses();

        /**
         * @param nodes member address strings to expose in {@link #_nodes}
         */
        public RemoteExecutors(String[] nodes) {
            this._nodes = nodes;
        }
    }

    /** Builds the minimal response carrying only the executor's model key. */
    private XGBoostExecRespV3 makeResponse(LocalXGBoostExecutor executor) {
        return new XGBoostExecRespV3(executor.modelKey);
    }

    /**
     * Creates a {@link LocalXGBoostExecutor} from the request's {@link XGBoostExecReq.Init}
     * payload and stores it in {@link XGBoostExecutorRegistry}.
     *
     * @param version API version (not used by this handler)
     * @param req request whose key identifies the executor and whose data is an Init payload
     * @return response with the executor's model key and the current cloud member list
     * @throws ClassCastException if the request payload is not an {@code XGBoostExecReq.Init}
     */
    public XGBoostExecRespV3 init(int version, XGBoostExecReqV3 req) {
        LocalXGBoostExecutor executor =
                new LocalXGBoostExecutor(req.key.key(), (XGBoostExecReq.Init) req.readData());
        XGBoostExecutorRegistry.storeExecutor(executor);
        return new XGBoostExecRespV3(executor.modelKey, collectNodes());
    }

    /**
     * Collects the address string of every member of the current cloud.
     */
    private RemoteExecutors collectNodes() {
        // Read the member array once: re-reading the static H2O.CLOUD for size() and for every
        // element could mix two different clouds if it is replaced while we iterate.
        H2ONode[] members = H2O.CLOUD.members();
        String[] nodes = new String[members.length];
        for (int i = 0; i < members.length; i++) {
            nodes[i] = members[i].getIpPortString();
        }
        return new RemoteExecutors(nodes);
    }

    /**
     * Runs {@code setup()} on the registered executor and streams the resulting bytes back.
     *
     * @param version API version (not used by this handler)
     * @param req request identifying the executor
     * @return streaming response with the setup bytes (empty if {@code setup()} returned null)
     */
    public StreamingSchema setup(int version, XGBoostExecReqV3 req) {
        return streamBytes(XGBoostExecutorRegistry.getExecutor(req).setup());
    }

    /**
     * Calls {@code update(treeId)} on the registered executor, using the tree id from the
     * request's {@link XGBoostExecReq.Update} payload.
     *
     * @param version API version (not used by this handler)
     * @param req request identifying the executor and carrying an Update payload
     * @return response with the executor's model key
     * @throws ClassCastException if the request payload is not an {@code XGBoostExecReq.Update}
     */
    public XGBoostExecRespV3 update(int version, XGBoostExecReqV3 req) {
        LocalXGBoostExecutor executor = XGBoostExecutorRegistry.getExecutor(req);
        executor.update(((XGBoostExecReq.Update) req.readData()).treeId);
        return makeResponse(executor);
    }

    /**
     * Streams the bytes returned by the registered executor's {@code updateBooster()}.
     *
     * @param version API version (not used by this handler)
     * @param req request identifying the executor
     * @return streaming response with the booster bytes (empty if the executor returned null)
     */
    public StreamingSchema getBooster(int version, XGBoostExecReqV3 req) {
        return streamBytes(XGBoostExecutorRegistry.getExecutor(req).updateBooster());
    }

    /**
     * Closes the registered executor and removes it from {@link XGBoostExecutorRegistry}.
     *
     * @param version API version (not used by this handler)
     * @param req request identifying the executor
     * @return response with the executor's model key
     */
    public XGBoostExecRespV3 cleanup(int version, XGBoostExecReqV3 req) {
        LocalXGBoostExecutor executor = XGBoostExecutorRegistry.getExecutor(req);
        try {
            executor.close();
        } finally {
            // Always deregister, even if close() fails, so the registry does not keep a
            // reference to a half-closed executor.
            XGBoostExecutorRegistry.removeExecutor(executor);
        }
        return makeResponse(executor);
    }

    /**
     * Wraps a byte array in a {@link StreamingSchema} that writes it verbatim to the response.
     *
     * @param data bytes to send; {@code null} is treated as an empty payload
     * @return schema that copies {@code data} to the response output stream
     */
    private StreamingSchema streamBytes(byte[] data) {
        final byte[] payload = data == null ? new byte[0] : data;
        return new StreamingSchema((outputStream, options) -> {
            try {
                IOUtils.copyStream(new ByteArrayInputStream(payload), outputStream);
            } catch (IOException e) {
                LOG.error("Failed writing data to response.", e);
                throw new UncheckedIOException("Failed writing data to response.", e);
            }
        });
    }
}
