package hex.tree.xgboost.task;

import hex.DataInfo;
import hex.tree.xgboost.XGBoostModel;
import hex.tree.xgboost.XGBoostModelInfo;
import hex.tree.xgboost.XGBoostOutput;
import hex.tree.xgboost.XGBoostUtils;
import hex.tree.xgboost.exec.XGBoostHttpClient;
import hex.tree.xgboost.matrix.MatrixFactoryUtils;
import hex.tree.xgboost.matrix.SparseMatrixDimensions;
import hex.tree.xgboost.matrix.SparseMatrixFactory;
import hex.tree.xgboost.remote.RemoteXGBoostUploadServlet;
import java.util.Optional;
import org.apache.commons.math3.optimization.direct.CMAESOptimizer;
import org.apache.log4j.Logger;
import water.BootstrapFreezable;
import water.H2O;
import water.Iced;
import water.LocalMR;
import water.MemoryManager;
import water.MrFun;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.Vec;
import water.util.ArrayUtils;
import water.util.VecUtils;

/* loaded from: input_file:hex/tree/xgboost/task/XGBoostUploadMatrixTask.class */
public class XGBoostUploadMatrixTask extends AbstractXGBoostTask<XGBoostUploadMatrixTask> {
    /** Class logger; assigned in the static initializer at the bottom of the class. */
    private static final Logger LOG;
    /** Remote node addresses; {@link #makeClient()} picks the entry at {@code H2O.SELF.index()}. */
    private final String[] remoteNodes;
    /** Forwarded to {@link XGBoostHttpClient}; presumably selects HTTPS over HTTP — confirm in the client. */
    private final boolean https;
    /** Appended to the selected remote node address when building the client. */
    private final String contextPath;
    /** Forwarded to {@link XGBoostHttpClient}; presumably the login used for the remote endpoint. */
    private final String userName;
    /** Forwarded to {@link XGBoostHttpClient}; presumably the password matching {@link #userName}. */
    private final String password;
    /** Frame whose node-local chunks are converted and uploaded. */
    private final Frame train;
    /** Result of {@code model.model_info()}; source of the {@link DataInfo} used for the matrix layout. */
    private final XGBoostModelInfo modelInfo;
    /** Model parameters; supply the response, weights and offset column names. */
    private final XGBoostModel.XGBoostParameters parms;
    /** Copied from the model output's {@code _sparse} flag; selects the CSR upload path over the dense one. */
    private final boolean sparse;
    // Decompiler-exposed assertion flag: true when assertions are disabled for this class.
    // It is referenced by execute() and set in the static initializer, so it must not be removed on its own.
    static final /* synthetic */ boolean $assertionsDisabled;

    /* loaded from: input_file:hex/tree/xgboost/task/XGBoostUploadMatrixTask$DenseMatrixChunk.class */
    /**
     * Payload carrying one chunk of a dense matrix: a chunk identifier and a flat float buffer.
     * The buffer is filled row by row by {@code UploadDenseChunkFun} before being uploaded.
     */
    public static class DenseMatrixChunk extends Iced<DenseMatrixChunk> implements BootstrapFreezable<DenseMatrixChunk> {
        /** Identifier of this chunk as assigned by the creator. */
        public final int id;
        /** Flat value buffer, allocated zero-filled with the requested length. */
        public final float[] data;

        /**
         * @param chunkId    identifier stored in {@link #id}
         * @param dataLength number of floats to allocate for {@link #data}
         */
        DenseMatrixChunk(int chunkId, int dataLength) {
            id = chunkId;
            data = new float[dataLength];
        }
    }

    /* loaded from: input_file:hex/tree/xgboost/task/XGBoostUploadMatrixTask$DenseMatrixDimensions.class */
    /**
     * Describes the overall shape of a dense matrix that is uploaded in chunks:
     * row count, column count and the starting row of every chunk.
     */
    public static class DenseMatrixDimensions extends Iced<DenseMatrixDimensions> implements BootstrapFreezable<DenseMatrixDimensions> {
        /** Number of rows. */
        public final int rows;
        /** Number of columns. */
        public final int cols;
        /** Row offsets per chunk; the array is stored by reference, not copied. */
        public final int[] rowOffsets;

        /**
         * @param rowCount        value stored in {@link #rows}
         * @param colCount        value stored in {@link #cols}
         * @param chunkRowOffsets array stored (without copying) in {@link #rowOffsets}
         */
        public DenseMatrixDimensions(int rowCount, int colCount, int[] chunkRowOffsets) {
            rows = rowCount;
            cols = colCount;
            rowOffsets = chunkRowOffsets;
        }
    }

    /* loaded from: input_file:hex/tree/xgboost/task/XGBoostUploadMatrixTask$MatrixData.class */
    /**
     * Per-row side data accompanying an uploaded matrix: responses plus optional
     * weights and offsets, together with the row count and shape filled in by the uploader.
     */
    public static class MatrixData extends Iced<MatrixData> implements BootstrapFreezable<MatrixData> {
        /** Response values, one slot per row. */
        public final float[] resp;
        /** Row weights, or {@code null} when no weights vector was supplied. */
        public final float[] weights;
        /** Row offsets, or {@code null} when no offsets vector was supplied. */
        public final float[] offsets;
        /** Number of rows actually written; assigned by the owning task after the upload. */
        public int actualRows;
        /** Assigned by the owning task on the sparse path only; stays 0 otherwise. */
        public int shape;

        /**
         * Allocates the per-row buffers.
         *
         * @param rowCount   length of every allocated buffer
         * @param weightsVec only tested for {@code null}; decides whether {@link #weights} is allocated
         * @param offsetsVec only tested for {@code null}; decides whether {@link #offsets} is allocated
         */
        MatrixData(int rowCount, Vec weightsVec, Vec offsetsVec) {
            resp = MemoryManager.malloc4f(rowCount);
            weights = weightsVec == null ? null : MemoryManager.malloc4f(rowCount);
            offsets = offsetsVec == null ? null : MemoryManager.malloc4f(rowCount);
        }
    }

    /* loaded from: input_file:hex/tree/xgboost/task/XGBoostUploadMatrixTask$SparseMatrixChunk.class */
    /**
     * Payload carrying one chunk of a sparse matrix: row headers, the non-zero values
     * and the column index of each value.
     */
    public static class SparseMatrixChunk extends Iced<SparseMatrixChunk> implements BootstrapFreezable<SparseMatrixChunk> {
        /** Identifier of this chunk as assigned by the creator. */
        public final int id;
        /** Row header entries (positions into the value stream), one per row plus a terminator. */
        public final long[] rowHeader;
        /** Stored values; same length as {@link #colIndices}. */
        public final float[] data;
        /** Column index of each entry in {@link #data}. */
        public final int[] colIndices;

        /**
         * @param chunkId        identifier stored in {@link #id}
         * @param rowHeaderCount length of {@link #rowHeader}
         * @param valueCount     length of both {@link #data} and {@link #colIndices}
         */
        SparseMatrixChunk(int chunkId, int rowHeaderCount, int valueCount) {
            id = chunkId;
            rowHeader = new long[rowHeaderCount];
            data = new float[valueCount];
            colIndices = new int[valueCount];
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:hex/tree/xgboost/task/XGBoostUploadMatrixTask$UploadDenseChunkFun.class */
    /**
     * Local map function that converts one node-local chunk of the training frame into a
     * {@link DenseMatrixChunk} and uploads it, while also filling the shared response,
     * weight and offset buffers for the rows of that chunk.
     */
    public class UploadDenseChunkFun extends MrFun<UploadDenseChunkFun> {
        private final Frame _f;
        /** Chunk indices to process; {@link #map(int)} receives a position into this array. */
        private final int[] _chunks;
        /** Start row of each processed chunk in the shared buffers; has one trailing end entry. */
        private final int[] _rowOffsets;
        private final Vec _weightsVec;
        private final Vec _offsetsVec;
        private final Vec _respVec;
        private final DataInfo _di;
        private final float[] _resp;
        private final float[] _weights;
        private final float[] _offsets;
        /** Number of rows actually written for each processed chunk. */
        private final int[] _nRowsByChunk;

        private UploadDenseChunkFun(Frame frame, int[] chunks, int[] rowOffsets, Vec weightsVec, Vec offsetsVec, Vec respVec, DataInfo di, float[] resp, float[] weights, float[] offsets) {
            this._f = frame;
            this._chunks = chunks;
            this._rowOffsets = rowOffsets;
            this._weightsVec = weightsVec;
            this._offsetsVec = offsetsVec;
            this._respVec = respVec;
            this._di = di;
            this._resp = resp;
            this._weights = weights;
            this._offsets = offsets;
            this._nRowsByChunk = new int[chunks.length];
        }

        /**
         * Builds and uploads the dense data of one chunk.
         *
         * @param chunkPos position into the chunk-index array (not the chunk index itself)
         * @throws ArithmeticException if rows times columns for this chunk does not fit in an {@code int}
         */
        @Override // water.MrFun
        public void map(int chunkPos) {
            final int chunkIdx = this._chunks[chunkPos];
            final Chunk[] chunks = new Chunk[this._f.numCols()];
            for (int col = 0; col < chunks.length; col++) {
                chunks[col] = this._f.vec(col).chunkForChunkIdx(chunkIdx);
            }
            final Chunk weightChunk = this._weightsVec != null ? this._weightsVec.chunkForChunkIdx(chunkIdx) : null;
            final Chunk offsetChunk = this._offsetsVec != null ? this._offsetsVec.chunkForChunkIdx(chunkIdx) : null;
            final Chunk respChunk = this._respVec.chunkForChunkIdx(chunkIdx);

            final int firstRow = this._rowOffsets[chunkPos];
            final int expectedRows = this._rowOffsets[chunkPos + 1] - firstRow;
            // multiplyExact: fail loudly instead of allocating a wrongly sized buffer on int overflow.
            final DenseMatrixChunk chunkData = new DenseMatrixChunk(chunkPos, Math.multiplyExact(expectedRows, this._di.fullN()));

            int idx = 0;
            int writtenRows = 0;
            for (int row = 0; row < chunks[0]._len; row++) {
                // Rows with a weight of exactly zero are left out of the matrix entirely.
                if (weightChunk != null && weightChunk.atd(row) == 0.0) {
                    continue;
                }
                idx = writeDenseRow(this._di, chunks, row, chunkData.data, idx);
                final int target = firstRow + writtenRows;
                this._resp[target] = (float) respChunk.atd(row);
                if (weightChunk != null) {
                    this._weights[target] = (float) weightChunk.atd(row);
                }
                if (offsetChunk != null) {
                    this._offsets[target] = (float) offsetChunk.atd(row);
                }
                writtenRows++;
            }
            assert idx == chunkData.data.length : "idx should be " + chunkData.data.length + " but it is " + idx;
            this._nRowsByChunk[chunkPos] = writtenRows;
            XGBoostUploadMatrixTask.this.makeClient().uploadObject(XGBoostUploadMatrixTask.this._modelKey, RemoteXGBoostUploadServlet.RequestType.denseMatrixChunk, chunkData);
        }

        /**
         * Writes one expanded row into {@code data} starting at {@code idx}: for every categorical
         * column a single 1.0 inside that column's slot range, followed by the numeric columns
         * (NA becomes {@code Float.NaN}).
         *
         * @return the write position following the row just written
         */
        private int writeDenseRow(DataInfo di, Chunk[] chunks, int row, float[] data, int idx) {
            for (int cat = 0; cat < di._cats; cat++) {
                final int width = di._catOffsets[cat + 1] - di._catOffsets[cat];
                // NA is handed to getCategoricalId as NaN; how it maps to a level is defined by DataInfo.
                final double level = chunks[cat].isNA(row) ? Double.NaN : chunks[cat].at8(row);
                final int slot = di.getCategoricalId(cat, level) - di._catOffsets[cat];
                data[idx + slot] = 1.0f;
                idx += width;
            }
            for (int num = 0; num < di._nums; num++) {
                final Chunk numChunk = chunks[di._cats + num];
                data[idx++] = numChunk.isNA(row) ? Float.NaN : (float) numChunk.atd(row);
            }
            return idx;
        }

        /** @return the sum of the rows written across all processed chunks */
        public int getTotalRows() {
            int total = 0;
            for (int rows : this._nRowsByChunk) {
                total += rows;
            }
            return total;
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:hex/tree/xgboost/task/XGBoostUploadMatrixTask$UploadSparseMatrixFun.class */
    /**
     * Local map function that converts one node-local chunk of the training frame into a
     * {@link SparseMatrixChunk} (row headers, values, column indices) and uploads it, while
     * also filling the shared response, weight and offset buffers.
     */
    public class UploadSparseMatrixFun extends MrFun<UploadSparseMatrixFun> {
        Frame _frame;
        /** Chunk indices to process; {@link #map(int)} receives a position into this array. */
        int[] _chunks;
        Vec _weightVec;
        Vec _offsetsVec;
        DataInfo _di;
        /** Precomputed per-chunk row and non-zero counts used to size and position each chunk. */
        SparseMatrixDimensions _dims;
        Vec _respVec;
        float[] _resp;
        float[] _weights;
        float[] _offsets;
        /** Number of rows actually written for each processed chunk. */
        int[] _actualRows;

        UploadSparseMatrixFun(Frame frame, int[] chunks, Vec weightVec, Vec offsetsVec, DataInfo di, SparseMatrixDimensions dims, Vec respVec, float[] resp, float[] weights, float[] offsets) {
            this._actualRows = new int[chunks.length];
            this._frame = frame;
            this._chunks = chunks;
            this._weightVec = weightVec;
            this._offsetsVec = offsetsVec;
            this._di = di;
            this._dims = dims;
            this._respVec = respVec;
            this._resp = resp;
            this._weights = weights;
            this._offsets = offsets;
        }

        /**
         * Builds and uploads the sparse data of one chunk.
         *
         * @param chunkPos position into the chunk-index array (not the chunk index itself)
         * @throws ArithmeticException if the non-zero count of this chunk does not fit in an {@code int}
         */
        @Override // water.MrFun
        public void map(int chunkPos) {
            final int chunkIdx = this._chunks[chunkPos];
            // Running position in the overall value stream; stored into the row headers.
            long nonZeroPos = this._dims._precedingNonZeroElementsCounts[chunkPos];
            // Running row position in the shared response/weight/offset buffers.
            int rowPos = this._dims._precedingRowCounts[chunkPos];

            final int rowHeaderCount;
            final long nonZeroCount;
            if (chunkPos == this._dims._precedingNonZeroElementsCounts.length - 1) {
                // Last chunk: sizes come from the overall totals.
                rowHeaderCount = this._dims._rowHeadersCount - rowPos;
                nonZeroCount = this._dims._nonZeroElementsCount - nonZeroPos;
            } else {
                // +1 leaves room for the terminating row header written after the loop.
                rowHeaderCount = (this._dims._precedingRowCounts[chunkPos + 1] - rowPos) + 1;
                nonZeroCount = this._dims._precedingNonZeroElementsCounts[chunkPos + 1] - nonZeroPos;
            }
            assert nonZeroCount < Integer.MAX_VALUE;

            final Chunk weightChunk = this._weightVec != null ? this._weightVec.chunkForChunkIdx(chunkIdx) : null;
            final Chunk offsetChunk = this._offsetsVec != null ? this._offsetsVec.chunkForChunkIdx(chunkIdx) : null;
            final Chunk respChunk = this._respVec.chunkForChunkIdx(chunkIdx);
            final Chunk[] featureChunks = new Chunk[this._frame.vecs().length];
            for (int col = 0; col < featureChunks.length; col++) {
                featureChunks[col] = this._frame.vecs()[col].chunkForChunkIdx(chunkIdx);
            }

            // toIntExact: with assertions disabled a plain (int) cast would silently truncate.
            final SparseMatrixChunk chunkData = new SparseMatrixChunk(chunkPos, rowHeaderCount, Math.toIntExact(nonZeroCount));
            int valuePos = 0;
            int headerPos = 0;
            for (int row = 0; row < respChunk._len; row++) {
                // Rows with a weight of exactly zero are left out of the matrix entirely.
                if (weightChunk != null && weightChunk.atd(row) == 0.0) {
                    continue;
                }
                chunkData.rowHeader[headerPos++] = nonZeroPos;
                this._actualRows[chunkPos]++;
                // Every categorical column contributes exactly one entry with value 1.
                for (int cat = 0; cat < this._di._cats; cat++) {
                    chunkData.data[valuePos] = 1.0f;
                    if (featureChunks[cat].isNA(row)) {
                        chunkData.colIndices[valuePos] = this._di.getCategoricalId(cat, Double.NaN);
                    } else {
                        chunkData.colIndices[valuePos] = this._di.getCategoricalId(cat, featureChunks[cat].at8(row));
                    }
                    valuePos++;
                    nonZeroPos++;
                }
                // Numeric columns are stored only when their float value is non-zero.
                for (int num = 0; num < this._di._nums; num++) {
                    final float value = (float) featureChunks[this._di._cats + num].atd(row);
                    if (value != 0.0f) {
                        chunkData.data[valuePos] = value;
                        chunkData.colIndices[valuePos] = this._di._catOffsets[this._di._catOffsets.length - 1] + num;
                        valuePos++;
                        nonZeroPos++;
                    }
                }
                rowPos = MatrixFactoryUtils.setResponseWeightAndOffset(weightChunk, offsetChunk, respChunk, this._resp, this._weights, this._offsets, rowPos, row);
            }
            // Terminating header: end position of the last written row.
            chunkData.rowHeader[headerPos] = nonZeroPos;
            XGBoostUploadMatrixTask.this.makeClient().uploadObject(XGBoostUploadMatrixTask.this._modelKey, RemoteXGBoostUploadServlet.RequestType.sparseMatrixChunk, chunkData);
        }
    }

    /**
     * Creates an upload task for the given model and training frame.
     *
     * @param model       model whose key, model info, parameters and output sparsity flag are captured
     * @param trainFrame  frame whose local chunks will be uploaded
     * @param flags       forwarded unchanged to the {@code AbstractXGBoostTask} constructor
     * @param nodes       remote node addresses, later indexed by {@code H2O.SELF.index()}
     * @param useHttps    forwarded to the HTTP client
     * @param ctxPath     appended to the remote node address
     * @param user        forwarded to the HTTP client
     * @param pass        forwarded to the HTTP client
     */
    public XGBoostUploadMatrixTask(XGBoostModel model, Frame trainFrame, boolean[] flags, String[] nodes, boolean useHttps, String ctxPath, String user, String pass) {
        super(model._key, flags);
        // Connection settings for the remote endpoint.
        this.remoteNodes = nodes;
        this.https = useHttps;
        this.contextPath = ctxPath;
        this.userName = user;
        this.password = pass;
        // Model-derived state.
        this.modelInfo = model.model_info();
        this.parms = (XGBoostModel.XGBoostParameters) model._parms;
        this.sparse = ((XGBoostOutput) model._output)._sparse;
        this.train = trainFrame;
    }

    /* JADX INFO: Access modifiers changed from: private */
    /**
     * Creates a new HTTP client aimed at the remote node paired with this H2O node.
     * The target is the entry of {@code remoteNodes} at {@code H2O.SELF.index()} followed by the context path.
     *
     * @return a freshly constructed client; nothing is cached
     */
    public XGBoostHttpClient makeClient() {
        final String target = this.remoteNodes[H2O.SELF.index()] + this.contextPath;
        return new XGBoostHttpClient(target, this.https, this.userName, this.password);
    }

    /**
     * Uploads the node-local part of the training frame to the remote node: first the matrix
     * itself (sparse or dense, chunk by chunk), then the accompanying {@link MatrixData}
     * holding responses, optional weights/offsets and the actual row count.
     *
     * @throws IllegalArgumentException if the local row count exceeds {@code Integer.MAX_VALUE}
     */
    @Override // hex.tree.xgboost.task.AbstractXGBoostTask
    protected void execute() {
        final XGBoostHttpClient client = makeClient();
        LOG.info("Starting matrix upload for " + this._modelKey);
        final long startMillis = System.currentTimeMillis();
        // Explicit form of `assert dataInfo() != null`; kept because this class declares the
        // assertion flag by hand, and mixing it with a real assert statement is unsafe.
        if (!$assertionsDisabled && this.modelInfo.dataInfo() == null) {
            throw new AssertionError();
        }
        final int[] localChunkIds = VecUtils.getLocalChunkIds(this.train.anyVec());
        final Vec respVec = this.train.vec(this.parms._response_column);
        final Vec weightsVec = this.train.vec(this.parms._weights_column);
        final Vec offsetsVec = this.train.vec(this.parms._offset_column);
        // Handed to sumChunksLength; presumably receives the per-chunk row counts — confirm in XGBoostUtils.
        final int[] rowsPerChunk = new int[localChunkIds.length];
        final long totalRows = XGBoostUtils.sumChunksLength(localChunkIds, respVec, Optional.ofNullable(weightsVec), rowsPerChunk);
        if (totalRows > Integer.MAX_VALUE) {
            throw new IllegalArgumentException("XGBoost currently doesn't support datasets with more than " + Integer.MAX_VALUE + " rows per node. To train a XGBoost model on this dataset add more nodes to your H2O cluster and use distributed training.");
        }
        final int rowCount = (int) totalRows;
        final MatrixData matrixData = new MatrixData(rowCount, weightsVec, offsetsVec);
        if (this.sparse) {
            LOG.debug("Treating matrix as sparse.");
            matrixData.shape = this.modelInfo.dataInfo().fullN();
            matrixData.actualRows = csr(client, localChunkIds, weightsVec, offsetsVec, respVec, this.modelInfo.dataInfo(), matrixData.resp, matrixData.weights, matrixData.offsets);
        } else {
            LOG.debug("Treating matrix as dense.");
            matrixData.actualRows = dense(client, localChunkIds, rowCount, rowsPerChunk, weightsVec, offsetsVec, respVec, this.modelInfo.dataInfo(), matrixData.resp, matrixData.weights, matrixData.offsets);
        }
        client.uploadObject(this._modelKey, RemoteXGBoostUploadServlet.RequestType.matrixData, matrixData);
        // Elapsed time is reported in seconds.
        LOG.debug("Matrix upload finished in " + ((System.currentTimeMillis() - startMillis) / 1000.0d));
    }

    /**
     * Uploads the local data as a dense matrix: first the dimensions (including the cumulative
     * row offset of every chunk), then each chunk through a local map/reduce run.
     *
     * @param client       client used for the dimensions upload
     * @param chunkIds     node-local chunk indices to process
     * @param totalRows    row count reported in the uploaded dimensions
     * @param rowsPerChunk per-chunk row counts, turned into cumulative offsets here
     * @return the number of rows actually written across all chunks
     */
    private int dense(XGBoostHttpClient client, int[] chunkIds, int totalRows, int[] rowsPerChunk, Vec weightsVec, Vec offsetsVec, Vec respVec, DataInfo dataInfo, float[] resp, float[] weights, float[] offsets) {
        // Prefix sums: rowOffsets[k] is the first row of chunk k, the last entry is the end.
        final int[] rowOffsets = new int[rowsPerChunk.length + 1];
        int pos = 0;
        while (pos < chunkIds.length) {
            rowOffsets[pos + 1] = rowsPerChunk[pos] + rowOffsets[pos];
            pos++;
        }
        final DenseMatrixDimensions dimensions = new DenseMatrixDimensions(totalRows, dataInfo.fullN(), rowOffsets);
        client.uploadObject(this._modelKey, RemoteXGBoostUploadServlet.RequestType.denseMatrixDimensions, dimensions);

        final UploadDenseChunkFun chunkFun = new UploadDenseChunkFun(this.train, chunkIds, rowOffsets, weightsVec, offsetsVec, respVec, dataInfo, resp, weights, offsets);
        final LocalMR localRun = (LocalMR) H2O.submitTask(new LocalMR(chunkFun, chunkIds.length));
        localRun.join();
        return chunkFun.getTotalRows();
    }

    /**
     * Uploads the local data as a sparse matrix: first the precomputed dimensions, then each
     * chunk through a local map/reduce run.
     *
     * @param client   client used for the dimensions upload
     * @param chunkIds node-local chunk indices to process
     * @return the number of rows actually written across all chunks
     */
    private int csr(XGBoostHttpClient client, int[] chunkIds, Vec weightsVec, Vec offsetsVec, Vec respVec, DataInfo dataInfo, float[] resp, float[] weights, float[] offsets) {
        final SparseMatrixDimensions dimensions = SparseMatrixFactory.calculateCSRMatrixDimensions(this.train, chunkIds, weightsVec, dataInfo);
        client.uploadObject(this._modelKey, RemoteXGBoostUploadServlet.RequestType.sparseMatrixDimensions, dimensions);

        final UploadSparseMatrixFun chunkFun = new UploadSparseMatrixFun(this.train, chunkIds, weightsVec, offsetsVec, dataInfo, dimensions, respVec, resp, weights, offsets);
        final LocalMR localRun = (LocalMR) H2O.submitTask(new LocalMR(chunkFun, chunkIds.length));
        localRun.join();
        return ArrayUtils.sum(chunkFun._actualRows);
    }

    // Class initializer: records whether assertions are disabled for this class
    // (read by the explicit assertion check in execute()) and creates the class logger.
    static {
        $assertionsDisabled = !XGBoostUploadMatrixTask.class.desiredAssertionStatus();
        LOG = Logger.getLogger((Class<?>) XGBoostUploadMatrixTask.class);
    }
}
