package hex.tree.xgboost.matrix;

import ai.h2o.xgboost4j.java.DMatrix;
import ai.h2o.xgboost4j.java.XGBoostError;
import ai.h2o.xgboost4j.java.util.BigDenseMatrix;
import hex.DataInfo;
import hex.tree.xgboost.matrix.MatrixLoader;
import java.util.Objects;
import org.apache.commons.math3.optimization.direct.CMAESOptimizer;
import org.apache.log4j.Logger;
import org.apache.logging.log4j.message.ParameterizedMessage;
import water.H2O;
import water.LocalMR;
import water.MrFun;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.Vec;

/* loaded from: input_file:hex/tree/xgboost/matrix/DenseMatrixFactory.class */
public class DenseMatrixFactory {
    /** Class logger; assigned in the static initializer at the bottom of this class. */
    private static final Logger LOG;
    // Decompiler-emitted stand-in for the compiler-generated assertion flag. It is assigned in the static
    // initializer and checked explicitly in place of `assert` statements, so this class must not also use
    // the `assert` keyword (javac would synthesize a conflicting field of the same name).
    static final /* synthetic */ boolean $assertionsDisabled;

    /**
     * {@link MatrixLoader.DMatrixProvider} backed by an off-heap {@link BigDenseMatrix}.
     *
     * <p>The provider owns the matrix: {@link #dispose()} releases it and clears the reference. After that,
     * {@link #makeDMatrix()} and {@link #print(int)} fail with {@link IllegalStateException}.
     */
    public static class DenseDMatrixProvider extends MatrixLoader.DMatrixProvider {
        /** Feature matrix; {@code null} once {@link #dispose()} has run. */
        private BigDenseMatrix data;

        /**
         * Creates a provider that takes ownership of {@code bigDenseMatrix}.
         *
         * @param j              number of rows, forwarded to the superclass
         * @param fArr           response array, forwarded to the superclass
         * @param fArr2          weights array (may be {@code null}), forwarded to the superclass
         * @param fArr3          offsets array (may be {@code null}), forwarded to the superclass
         * @param bigDenseMatrix off-heap feature matrix; released by {@link #dispose()}
         */
        public DenseDMatrixProvider(long j, float[] fArr, float[] fArr2, float[] fArr3, BigDenseMatrix bigDenseMatrix) {
            super(j, fArr, fArr2, fArr3);
            this.data = bigDenseMatrix;
        }

        /** Returns the live matrix, or fails clearly if it was already disposed. */
        private BigDenseMatrix requireData() {
            if (this.data == null) {
                throw new IllegalStateException("Dense matrix has already been disposed.");
            }
            return this.data;
        }

        /**
         * Dumps rows to standard output, one line per row, formatted as {@code row:v0, v1, ..., response}.
         *
         * @param i number of rows to print; a non-positive value prints all rows. Values larger than the
         *          matrix are clamped so the dump never reads past the last row.
         */
        @Override
        public void print(int i) {
            final BigDenseMatrix matrix = requireData();
            final int rows = i > 0 ? Math.min(i, matrix.nrow) : matrix.nrow;
            for (int r = 0; r < rows; r++) {
                final StringBuilder line = new StringBuilder().append(r).append(':');
                for (int c = 0; c < matrix.ncol; c++) {
                    line.append(matrix.get(r, c)).append(", ");
                }
                line.append(this.response[r]);
                System.out.println(line);
            }
        }

        /**
         * Compares the superclass state plus matrix dimensions and contents.
         *
         * <p>Cells are compared numerically, except that two NaN cells count as equal. Missing values are
         * stored as NaN, so plain {@code !=} would make any matrix with missing values unequal to its copy.
         * Two disposed providers compare equal on the matrix part.
         */
        @Override
        public boolean equals(Object obj) {
            if (this == obj) {
                return true;
            }
            if (obj == null || getClass() != obj.getClass() || !super.equals(obj)) {
                return false;
            }
            final DenseDMatrixProvider other = (DenseDMatrixProvider) obj;
            if (this.data == null || other.data == null) {
                return this.data == other.data;
            }
            if (other.data.ncol != this.data.ncol || other.data.nrow != this.data.nrow) {
                return false;
            }
            for (int r = 0; r < this.data.nrow; r++) {
                for (int c = 0; c < this.data.ncol; c++) {
                    final float mine = this.data.get(r, c);
                    final float theirs = other.data.get(r, c);
                    if (mine != theirs && !(Float.isNaN(mine) && Float.isNaN(theirs))) {
                        return false;
                    }
                }
            }
            return true;
        }

        /**
         * Combines the superclass hash with the matrix dimensions.
         *
         * <p>The matrix reference itself is not hashed: {@link #equals(Object)} compares contents, so hashing
         * the reference would give equal providers different hash codes. Dimensions are cheap and consistent
         * with {@code equals}.
         */
        @Override
        public int hashCode() {
            final int dataHash = this.data == null ? 0 : 31 * this.data.nrow + this.data.ncol;
            return super.hashCode() + dataHash;
        }

        /**
         * Builds an XGBoost {@link DMatrix} from the dense matrix, treating NaN cells as missing.
         *
         * @throws IllegalStateException if the provider was already disposed
         */
        @Override
        public DMatrix makeDMatrix() throws XGBoostError {
            return new DMatrix(requireData(), Float.NaN);
        }

        /** Releases the off-heap matrix (idempotent). */
        @Override
        protected void dispose() {
            if (this.data != null) {
                this.data.dispose();
                this.data = null;
            }
        }
    }

    /**
     * Local map/reduce function that copies a set of frame chunks into one shared off-heap dense matrix.
     *
     * <p>{@code map(i)} processes chunk {@code _chunks[i]}. Rows whose weight is exactly zero are skipped.
     * The remaining rows are written contiguously starting at matrix row {@code _rowOffsets[i]}, together
     * with their response, weight and offset. Each call writes a disjoint slice of the output, and
     * {@code _rowOffsets[i + 1]} marks where the next chunk's slice begins.
     */
    public static class WriteDenseChunkFun extends MrFun<WriteDenseChunkFun> {
        private final Frame _f;
        private final int[] _chunks;
        private final int[] _rowOffsets;
        private final Vec _weightsVec;
        private final Vec _offsetsVec;
        private final Vec _respVec;
        private final DataInfo _di;
        private final BigDenseMatrix _data;
        private final float[] _resp;
        private final float[] _weights;
        private final float[] _offsets;
        /** Number of rows actually written for each entry of {@code _chunks}. */
        private final int[] _nRowsByChunk;
        static final /* synthetic */ boolean $assertionsDisabled;

        /**
         * @param frame          frame holding the feature columns
         * @param iArr           chunk indices to process, one per map call
         * @param iArr2          first output row for each entry of {@code iArr}, plus one trailing end offset
         * @param vec            weights column, or {@code null}
         * @param vec2           offsets column, or {@code null}
         * @param vec3           response column
         * @param dataInfo       column layout used to expand rows
         * @param bigDenseMatrix destination matrix
         * @param fArr           destination for responses
         * @param fArr2          destination for weights (written only when {@code vec} is non-null)
         * @param fArr3          destination for offsets (written only when {@code vec2} is non-null)
         */
        private WriteDenseChunkFun(Frame frame, int[] iArr, int[] iArr2, Vec vec, Vec vec2, Vec vec3, DataInfo dataInfo, BigDenseMatrix bigDenseMatrix, float[] fArr, float[] fArr2, float[] fArr3) {
            this._f = frame;
            this._chunks = iArr;
            this._rowOffsets = iArr2;
            this._weightsVec = vec;
            this._offsetsVec = vec2;
            this._respVec = vec3;
            this._di = dataInfo;
            this._data = bigDenseMatrix;
            this._resp = fArr;
            this._weights = fArr2;
            this._offsets = fArr3;
            this._nRowsByChunk = new int[iArr.length];
        }

        /** Writes all non-zero-weight rows of chunk {@code _chunks[i]} into its slice of the output. */
        @Override
        public void map(int i) {
            final int chunkIdx = this._chunks[i];
            final Chunk[] chunks = new Chunk[this._f.numCols()];
            for (int c = 0; c < chunks.length; c++) {
                chunks[c] = this._f.vec(c).chunkForChunkIdx(chunkIdx);
            }
            final Chunk weightsChk = this._weightsVec != null ? this._weightsVec.chunkForChunkIdx(chunkIdx) : null;
            final Chunk offsetsChk = this._offsetsVec != null ? this._offsetsVec.chunkForChunkIdx(chunkIdx) : null;
            final Chunk respChk = this._respVec.chunkForChunkIdx(chunkIdx);

            final int firstRow = this._rowOffsets[i];
            // Widen BEFORE multiplying: row * ncol can exceed Integer.MAX_VALUE for large matrices.
            long idx = (long) firstRow * this._data.ncol;
            int written = 0;
            for (int row = 0; row < chunks[0]._len; row++) {
                if (weightsChk != null && weightsChk.atd(row) == 0) {
                    continue; // zero-weight rows are not part of the training matrix
                }
                idx = DenseMatrixFactory.writeDenseRow(this._di, chunks, row, this._data, idx);
                final int target = firstRow + written;
                this._resp[target] = (float) respChk.atd(row);
                if (weightsChk != null) {
                    this._weights[target] = (float) weightsChk.atd(row);
                }
                if (offsetsChk != null) {
                    this._offsets[target] = (float) offsetsChk.atd(row);
                }
                written++;
            }
            final long expectedEnd = (long) this._rowOffsets[i + 1] * this._data.ncol;
            if (!$assertionsDisabled && idx != expectedEnd) {
                throw new AssertionError("Chunk " + chunkIdx + " ended at element " + idx + ", expected " + expectedEnd);
            }
            this._nRowsByChunk[i] = written;
        }

        /** Returns the total number of rows written across all processed chunks. */
        public int getTotalRows() {
            int total = 0;
            for (int rows : this._nRowsByChunk) {
                total += rows;
            }
            return total;
        }

        static {
            $assertionsDisabled = !DenseMatrixFactory.class.desiredAssertionStatus();
        }
    }

    /**
     * Builds an XGBoost {@link DMatrix} from one set of aligned chunks, expanding every row densely.
     *
     * <p>An off-heap {@link BigDenseMatrix} with one row per chunk row is filled and passed to the
     * {@code DMatrix} constructor, with NaN as the missing value. The temporary matrix is always disposed
     * before returning, including on failure, so an exception no longer leaks off-heap memory. Disposing it
     * right after construction matches the previous success path. NOTE(review): this relies on
     * {@code DMatrix} not retaining the buffer — confirm.
     *
     * @param chunkArr chunks to read; {@code chunkArr[0].len()} determines the row count
     * @param dataInfo column layout used to expand each row
     * @param i        column index forwarded to {@code MatrixFactoryUtils.setResponseAndWeightAndOffset}
     *                 (presumably the response column — TODO confirm)
     * @param fArr     forwarded to the same helper (presumably the response output — TODO confirm)
     * @param fArr2    forwarded to the same helper (presumably the weights output — TODO confirm)
     * @param i2       forwarded to the same helper (presumably the offset column — TODO confirm)
     * @param fArr3    forwarded to the same helper (presumably the offsets output — TODO confirm)
     * @return the constructed matrix
     * @throws XGBoostError if XGBoost fails to create the matrix
     */
    public static DMatrix dense(Chunk[] chunkArr, DataInfo dataInfo, int i, float[] fArr, float[] fArr2, int i2, float[] fArr3) throws XGBoostError {
        LOG.debug("Treating matrix as dense.");
        BigDenseMatrix bigDenseMatrix = null;
        try {
            bigDenseMatrix = allocateDenseMatrix(chunkArr[0].len(), dataInfo);
            final long rowsWritten = denseChunk(bigDenseMatrix, chunkArr, i, dataInfo, fArr, fArr2, i2, fArr3);
            if (!$assertionsDisabled && rowsWritten != bigDenseMatrix.nrow) {
                throw new AssertionError("Expected " + bigDenseMatrix.nrow + " rows but wrote " + rowsWritten);
            }
            return new DMatrix(bigDenseMatrix, Float.NaN);
        } finally {
            if (bigDenseMatrix != null) {
                bigDenseMatrix.dispose();
            }
        }
    }

    /**
     * Fills an off-heap dense matrix from the selected chunks of {@code frame} in parallel and wraps it in a
     * {@link DenseDMatrixProvider}, which takes ownership of the matrix.
     *
     * <p>If anything fails before ownership is handed over, including the row-count assertion, the matrix is
     * disposed here so off-heap memory is not leaked.
     *
     * @param frame    frame holding the feature columns
     * @param iArr     chunk indices to process
     * @param i        total number of rows to allocate
     * @param iArr2    rows to write per entry of {@code iArr} (turned into row offsets)
     * @param vec      weights column, or {@code null}
     * @param vec2     offsets column, or {@code null}
     * @param vec3     response column
     * @param dataInfo column layout used to expand rows
     * @param fArr     destination for responses
     * @param fArr2    destination for weights
     * @param fArr3    destination for offsets
     * @return a provider owning the filled matrix
     * @throws RuntimeException wrapping any exception raised while building the matrix
     */
    public static DenseDMatrixProvider dense(Frame frame, int[] iArr, int i, int[] iArr2, Vec vec, Vec vec2, Vec vec3, DataInfo dataInfo, float[] fArr, float[] fArr2, float[] fArr3) {
        BigDenseMatrix bigDenseMatrix = null;
        boolean handedOff = false;
        try {
            bigDenseMatrix = allocateDenseMatrix(i, dataInfo);
            final int actualRows = denseChunk(bigDenseMatrix, iArr, iArr2, frame, vec, vec2, vec3, dataInfo, fArr, fArr2, fArr3);
            if (!$assertionsDisabled && bigDenseMatrix.nrow != actualRows) {
                throw new AssertionError("Expected " + bigDenseMatrix.nrow + " rows but wrote " + actualRows);
            }
            final DenseDMatrixProvider provider = new DenseDMatrixProvider(actualRows, fArr, fArr2, fArr3, bigDenseMatrix);
            handedOff = true; // the provider now owns (and will dispose) the matrix
            return provider;
        } catch (Exception e) {
            throw new RuntimeException("Error while create off-heap matrix.", e);
        } finally {
            if (!handedOff && bigDenseMatrix != null) {
                bigDenseMatrix.dispose();
            }
        }
    }

    /**
     * Turns per-chunk row counts into row offsets, then runs {@link WriteDenseChunkFun} over every chunk in
     * {@code iArr} as a local parallel task and waits for it to finish.
     *
     * @return the total number of rows written across all chunks
     */
    private static int denseChunk(BigDenseMatrix bigDenseMatrix, int[] iArr, int[] iArr2, Frame frame, Vec vec, Vec vec2, Vec vec3, DataInfo dataInfo, float[] fArr, float[] fArr2, float[] fArr3) {
        // Running sum: rowOffsets[c] is the first output row of chunk c, rowOffsets[c + 1] its end.
        final int[] rowOffsets = new int[iArr2.length + 1];
        int running = 0;
        for (int c = 0; c < iArr.length; c++) {
            running += iArr2[c];
            rowOffsets[c + 1] = running;
        }
        final WriteDenseChunkFun writer = new WriteDenseChunkFun(frame, iArr, rowOffsets, vec, vec2, vec3, dataInfo, bigDenseMatrix, fArr, fArr2, fArr3);
        final LocalMR task = new LocalMR(writer, iArr.length);
        ((LocalMR) H2O.submitTask(task)).join();
        return writer.getTotalRows();
    }

    /**
     * Writes every row of {@code chunkArr} into {@code bigDenseMatrix} and fills the response/weight/offset
     * arrays through {@code MatrixFactoryUtils.setResponseAndWeightAndOffset}. That helper receives
     * {@code -1} in the position the helper's name suggests is a weight column, presumably meaning "no
     * weight column" — TODO confirm.
     *
     * @return the number of rows written (always the chunk length)
     */
    private static long denseChunk(BigDenseMatrix bigDenseMatrix, Chunk[] chunkArr, int i, DataInfo dataInfo, float[] fArr, float[] fArr2, int i2, float[] fArr3) {
        long elementIdx = 0;
        long rowsWritten = 0;
        int targetIdx = 0;
        for (int row = 0; row < chunkArr[0]._len; row++) {
            elementIdx = writeDenseRow(dataInfo, chunkArr, row, bigDenseMatrix, elementIdx);
            rowsWritten++;
            targetIdx = MatrixFactoryUtils.setResponseAndWeightAndOffset(chunkArr, i, -1, i2, fArr, fArr2, fArr3, targetIdx, row);
        }
        // Widen before multiplying: nrow * ncol overflows int for large matrices.
        final long expectedElements = (long) bigDenseMatrix.nrow * bigDenseMatrix.ncol;
        if (!$assertionsDisabled && expectedElements != elementIdx) {
            throw new AssertionError("Expected " + expectedElements + " elements but wrote " + elementIdx);
        }
        return rowsWritten;
    }

    /**
     * Writes row {@code i} of {@code chunkArr} into {@code bigDenseMatrix} at flat index {@code j}.
     *
     * <p>Categorical column {@code c} takes {@code _catOffsets[c + 1] - _catOffsets[c]} consecutive slots.
     * They are all set to 0 (presumably because off-heap memory is not zero-initialized — TODO confirm), and
     * then a single 1 is written at the position given by {@code getCategoricalId} (NA passed as NaN).
     * Numeric columns follow, one slot each, written as float with NaN for NA.
     *
     * @return the flat index just past the last slot written
     */
    public static long writeDenseRow(DataInfo dataInfo, Chunk[] chunkArr, int i, BigDenseMatrix bigDenseMatrix, long j) {
        long pos = j;
        for (int cat = 0; cat < dataInfo._cats; cat++) {
            final int start = dataInfo._catOffsets[cat];
            final int width = dataInfo._catOffsets[cat + 1] - start;
            final Chunk chunk = chunkArr[cat];
            final double level = chunk.isNA(i) ? Double.NaN : chunk.at8(i);
            final int hot = dataInfo.getCategoricalId(cat, level) - start;
            final long end = pos + width;
            for (long k = pos; k < end; k++) {
                bigDenseMatrix.set(k, 0.0f);
            }
            bigDenseMatrix.set(pos + hot, 1.0f);
            pos = end;
        }
        final int numBase = dataInfo._cats;
        for (int num = 0; num < dataInfo._nums; num++) {
            final Chunk chunk = chunkArr[numBase + num];
            bigDenseMatrix.set(pos++, chunk.isNA(i) ? Float.NaN : (float) chunk.atd(i));
        }
        return pos;
    }

    /** Allocates an off-heap matrix with {@code i} rows and one column per expanded feature of {@code dataInfo}. */
    private static BigDenseMatrix allocateDenseMatrix(int i, DataInfo dataInfo) {
        BigDenseMatrix matrix = new BigDenseMatrix(i, dataInfo.fullN());
        return matrix;
    }

    // Class initialization: create the logger and capture whether assertions are enabled for this class.
    static {
        LOG = Logger.getLogger(DenseMatrixFactory.class);
        $assertionsDisabled = !DenseMatrixFactory.class.desiredAssertionStatus();
    }
}
