package hex;

import java.util.Arrays;
import java.util.concurrent.atomic.AtomicInteger;
import water.DKV;
import water.H2O;
import water.Iced;
import water.Key;
import water.MRTask;
import water.MemoryManager;
import water.TAtomic;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.NewChunk;
import water.fvec.Vec;
import water.util.ArrayUtils;
import water.util.Log;

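/**
 * Distributed matrix utilities over H2O {@link Frame}s: transposing an all-numeric frame
 * ({@code transpose}) and matrix multiplication ({@code mmul}).
 * (Decompiled from input_file:hex/DMatrix.class.)
 */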
public class DMatrix {
    /** Package-private static counter. NOTE(review): not referenced anywhere in this file; confirm other users before removing. */
    static int cnt;

    /**
     * Collects the non-zero entries of a single vector as two parallel arrays.
     * {@link #_idxs} holds global row indices and {@link #_vals} holds the matching values.
     * <p>
     * Each {@code map} gathers the non-zeros of one chunk. Partial results are combined in
     * {@code reduce} via {@code ArrayUtils.sortedMerge}, presumably keeping rows in ascending
     * order. The task fails with a {@link RuntimeException} as soon as more than
     * {@code _maxsz} non-zeros are seen.
     * <p>
     * Declared public here; the decompiler notes the original declaration was private.
     */
    public static class GetNonZerosTsk extends MRTask<GetNonZerosTsk> {
        /** Maximum number of non-zeros this task is willing to collect. */
        final int _maxsz;
        /** Global row indices of the collected non-zeros. */
        int[] _idxs;
        /** Values of the collected non-zeros, parallel to {@link #_idxs}. */
        double[] _vals;

        /** Creates the task with the default limit of 10,000,000 non-zeros. */
        public GetNonZerosTsk(H2O.H2OCountedCompleter cmp) {
            this(cmp, 10000000);
        }

        /** Creates the task with an explicit non-zero limit {@code maxsz}. */
        public GetNonZerosTsk(H2O.H2OCountedCompleter cmp, int maxsz) {
            super(cmp);
            this._maxsz = maxsz;
        }

        @Override // water.MRTask
        public void map(Chunk chunk) {
            final int start = (int) chunk.start();
            // Row indices are stored as int, so this chunk's row range must fit in an int.
            assert chunk.start() + chunk._len == start + chunk._len;
            final int nnz = chunk.sparseLenZero();
            // Fail before allocating arrays we would immediately reject.
            if (nnz > _maxsz) {
                throw new RuntimeException("too many nonzeros! found at least " + nnz + " nonzeros.");
            }
            _idxs = MemoryManager.malloc4(nnz);
            _vals = MemoryManager.malloc8d(nnz);
            int k = 0;
            for (int r = chunk.nextNZ(-1); r < chunk._len; r = chunk.nextNZ(r), k++) {
                _idxs[k] = r + start;
                _vals[k] = chunk.atd(r);
            }
            assert k == nnz;
        }

        @Override // water.MRTask
        public void reduce(GetNonZerosTsk other) {
            final int total = _idxs.length + other._idxs.length;
            if (total > _maxsz) {
                throw new RuntimeException("too many nonzeros! found at least " + total + " nonzeros.");
            }
            int[] mergedIdxs = MemoryManager.malloc4(total);
            double[] mergedVals = MemoryManager.malloc8d(total);
            ArrayUtils.sortedMerge(_idxs, _vals, other._idxs, other._vals, mergedIdxs, mergedVals);
            _idxs = mergedIdxs;
            _vals = mergedVals;
        }
    }

    /**
     * Progress and statistics record for a matrix multiplication.
     * <p>
     * It is updated once per finished output chunk by {@code UpdateProgress}.
     * {@link #chunkTypes} and {@link #chunkCnts} are parallel arrays: a chunk's frozen type id
     * and how many output chunks of that type were written.
     */
    public static class MatrixMulStats extends Iced {
        /** Key supplied at construction. NOTE(review): presumably the owning job's key. */
        public final Key jobKey;
        /** Total number of chunks expected; denominator of {@link #progress()}. */
        public final long chunksTotal;
        /** Wall-clock time (ms) of the most recent update. */
        public long lastUpdateAt;
        /** Number of chunks completed so far. */
        public long chunksDone;
        /** Accumulated size, in bytes, of the written chunks. */
        public long size;
        public int[] chunkTypes = new int[0];
        public long[] chunkCnts = new long[0];
        /** Wall-clock time (ms) at which this record was created. */
        public final long _startTime = System.currentTimeMillis();

        public MatrixMulStats(long chunksTotal, Key jobKey) {
            this.chunksTotal = chunksTotal;
            this.jobKey = jobKey;
        }

        /**
         * Returns the completed fraction {@code chunksDone / chunksTotal}.
         * <p>
         * Floating-point division is used. The previous {@code long / long} form truncated to
         * 0 until every chunk was done.
         *
         * @return fraction of chunks done; 1.0 when no chunks were scheduled
         */
        public float progress() {
            if (chunksTotal <= 0) {
                // Nothing to do: report completion instead of dividing by zero.
                return 1f;
            }
            return (float) ((double) chunksDone / chunksTotal);
        }
    }

    /**
     * Computes {@code z = x * y} one output column at a time.
     * <p>
     * For column {@code i} of {@code y}, a {@link GetNonZerosTsk} collects its non-zeros. A
     * {@link VecTsk} then adds the corresponding columns of {@code x}, weighted by those
     * values, into column {@code i} of {@code z}.
     * <p>
     * At most {@code maxP} columns are started up front. Each finished column (see
     * {@link Callback}) claims and forks the next unprocessed one.
     */
    public static class MatrixMulTsk extends H2O.H2OCountedCompleter {
        final transient Frame _x;
        Frame _y;
        /** Result frame; created in {@link #compute2()}. */
        Frame _z;
        /** Key of the progress record ({@code MatrixMulStats}), or null for no progress tracking. */
        final Key _progressKey;
        /** Index of the most recently claimed column; bumped atomically by each {@link Callback}. */
        AtomicInteger _cntr;

        /**
         * Completion callback for one column's {@link VecTsk}. It claims the next column index
         * and forks it, if any columns remain.
         * <p>
         * Declared public here; the decompiler notes the original declaration was private.
         */
        public class Callback extends H2O.H2OCallback {
            public Callback() {
                super(MatrixMulTsk.this);
            }

            @Override // water.H2O.H2OCallback
            public void callback(H2O.H2OCountedCompleter cc) {
                int next = MatrixMulTsk.this._cntr.incrementAndGet();
                if (next < MatrixMulTsk.this._y.numCols()) {
                    MatrixMulTsk.this.forkVecTask(next);
                }
            }
        }

        /**
         * @throws IllegalArgumentException if {@code x.numCols() != y.numRows()}
         */
        public MatrixMulTsk(H2O.H2OCountedCompleter cmp, Key progressKey, Frame x, Frame y) {
            super(cmp);
            if (x.numCols() != y.numRows()) {
                throw new IllegalArgumentException("dimensions do not match! x.numcols = " + x.numCols() + ", y.numRows = " + y.numRows());
            }
            this._x = x;
            this._y = y;
            this._progressKey = progressKey;
        }

        @Override // water.H2O.H2OCountedCompleter
        public void compute2() {
            _z = new Frame(_x.anyVec().makeZeros(_y.numCols()));
            final int totalCores = H2O.CLOUD.size() * H2O.NUMCPUS;
            final int chunksPerCol = _y.anyVec().nChunks();
            // Clamp to >= 1. With maxP == 0 no column would ever be forked, the pending count
            // set below would never drain, and the task would hang forever.
            final int maxP = Math.max(1, 256 * totalCores / chunksPerCol);
            Log.info("maxP = " + maxP);
            // The first Callback's incrementAndGet() then yields maxP, the first unstarted column.
            _cntr = new AtomicInteger(maxP - 1);
            // Two completions per column: the GetNonZerosTsk callback and the VecTsk Callback.
            addToPendingCount((2 * _y.numCols()) - 1);
            for (int i = 0; i < Math.min(_y.numCols(), maxP); i++) {
                forkVecTask(i);
            }
        }

        /**
         * Starts the two-stage computation of output column {@code i}: non-zero extraction on
         * {@code y.vec(i)}, then a weighted column sum into {@code z.vec(i)}.
         * <p>
         * Declared public here; the decompiler notes the original declaration was private.
         */
        public void forkVecTask(final int i) {
            new GetNonZerosTsk(new H2O.H2OCallback<GetNonZerosTsk>(this) {
                @Override // water.H2O.H2OCallback
                public void callback(GetNonZerosTsk nz) {
                    // Input layout for VecTsk: the selected x columns, then the output column last.
                    Vec[] cols = (Vec[]) ArrayUtils.append(MatrixMulTsk.this._x.vecs(nz._idxs), MatrixMulTsk.this._z.vec(i));
                    new VecTsk(new Callback(), MatrixMulTsk.this._progressKey, nz._vals).dfork(cols);
                }
            }).dfork(this._y.vec(i));
        }
    }

    /**
     * Writes the transpose of the frame it is run over into a pre-allocated target frame.
     * <p>
     * {@code map} receives the chunks of every source column for rows
     * {@code [start, start + len)}. Source row {@code start + r} becomes target column
     * {@code start + r}, and source column {@code c} becomes target row {@code c}.
     * <p>
     * Target chunks are built one row block at a time, following the target's espc layout,
     * and are closed into {@code _fs}.
     */
    public static class TransposeTsk extends MRTask<TransposeTsk> {
        /** Target frame; must already have the transposed shape and the desired chunk layout. */
        final Frame _tgt;

        public TransposeTsk(Frame tgt) {
            this._tgt = tgt;
        }

        @Override // water.MRTask
        public void map(Chunk[] chks) {
            final Frame tgt = _tgt;
            final long[] espc = tgt.anyVec().espc();
            final int colStart = (int) chks[0].start();
            final int nOut = chks[0]._len;
            for (int tc = 0; tc < espc.length - 1; tc++) {
                final int rowsInChunk = (int) (espc[tc + 1] - espc[tc]);
                final NewChunk[] out = new NewChunk[nOut];
                for (int j = 0; j < nOut; j++) {
                    out[j] = new NewChunk(tgt.vec(j + colStart), tc);
                }
                for (int c = (int) espc[tc]; c < (int) espc[tc + 1]; c++) {
                    final Chunk src = chks[c];
                    final int tgtRow = (int) (c - espc[tc]);
                    if (src.isSparseZero()) {
                        // Visit only non-zero rows. The loop is bounded: the decompiled
                        // while(true) had lost its break and never terminated.
                        for (int r = src.nextNZ(-1); r < src._len; r = src.nextNZ(r)) {
                            appendAt(out[r], tgtRow, src, r);
                        }
                    } else {
                        for (int r = 0; r < src._len; r++) {
                            appendAt(out[r], tgtRow, src, r);
                        }
                    }
                }
                for (int j = 0; j < out.length; j++) {
                    // Pad trailing zeros so every target chunk has exactly rowsInChunk rows.
                    out[j].addZeros(rowsInChunk - out[j]._len);
                    out[j].close(_fs);
                    out[j] = null;
                }
            }
        }

        /** Pads {@code nc} with zeros up to position {@code row}, then appends row {@code r} of {@code src}. */
        private static void appendAt(NewChunk nc, int row, Chunk src, int r) {
            nc.addZeros(row - nc._len);
            src.extractRows(nc, r);
        }
    }

    /**
     * Atomic update of a {@code MatrixMulStats} record after one output chunk is written.
     * <p>
     * It bumps the done-chunk counter and the per-chunk-type counter, growing the parallel
     * type/count arrays when a new type appears. It also records the update time and adds the
     * chunk's byte size.
     */
    private static class UpdateProgress extends TAtomic<MatrixMulStats> {
        final int _chunkSz;
        final int _chunkType;

        public UpdateProgress(int chunkSz, int chunkType) {
            this._chunkSz = chunkSz;
            this._chunkType = chunkType;
        }

        @Override // water.TAtomic
        public MatrixMulStats atomic(MatrixMulStats stats) {
            // Work on a copy of the counts array rather than mutating the old one in place.
            stats.chunkCnts = stats.chunkCnts.clone();
            int slot = -1;
            for (int k = 0; slot < 0 && k < stats.chunkTypes.length; k++) {
                if (stats.chunkTypes[k] == _chunkType) {
                    slot = k;
                }
            }
            if (slot < 0) {
                // First chunk of this type: append a new (type, count) slot.
                final int last = stats.chunkTypes.length;
                stats.chunkTypes = Arrays.copyOf(stats.chunkTypes, last + 1);
                stats.chunkCnts = Arrays.copyOf(stats.chunkCnts, stats.chunkCnts.length + 1);
                stats.chunkTypes[last] = _chunkType;
                slot = last;
            }
            stats.chunksDone++;
            stats.chunkCnts[slot]++;
            stats.lastUpdateAt = System.currentTimeMillis();
            stats.size += _chunkSz;
            return stats;
        }
    }

    /**
     * Produces one output column of a matrix product.
     * <p>
     * The task runs over {@code [x_0, ..., x_{k-1}, out]}. The last vec is the output column;
     * {@code _y[i]} is the weight applied to {@code x_i}.
     * <p>
     * For each chunk it computes {@code sum_i _y[i] * x_i}, compresses the result, and writes it
     * straight into the DKV under the output vec's chunk key. If {@code _progressKey} is
     * non-null, an {@code UpdateProgress} is forked for each written chunk.
     * <p>
     * Declared public here; the decompiler notes the original declaration was private.
     */
    public static class VecTsk extends MRTask<VecTsk> {
        /** Weights, one per input (non-output) vec. */
        double[] _y;
        /** Progress record key, or null to skip progress updates. */
        Key _progressKey;

        public VecTsk(H2O.H2OCountedCompleter cmp, Key progressKey, double[] y) {
            super(cmp);
            this._progressKey = progressKey;
            this._y = y;
        }

        @Override // water.MRTask
        public void setupLocal() {
            // Called on the output vec before map() overwrites its chunks via DKV.put.
            this._fr.lastVec().preWriting();
        }

        @Override // water.MRTask
        public void map(Chunk[] chks) {
            final Chunk outChunk = chks[chks.length - 1];
            final double[] acc = MemoryManager.malloc8d(chks[0]._len);
            for (int i = 0; i < _y.length; i++) {
                final double w = _y[i];
                final Chunk xc = chks[i];
                // Bounded loop over non-zero rows. The decompiled while(true) had lost its
                // break and never terminated.
                for (int r = xc.nextNZ(-1); r < acc.length; r = xc.nextNZ(r)) {
                    acc[r] += w * xc.atd(r);
                }
            }
            final Chunk compressed = new NewChunk(acc).setSparseRatio(2).compress();
            if (_progressKey != null) {
                new UpdateProgress(compressed.getBytes().length, compressed.frozenType()).fork(_progressKey);
            }
            DKV.put(outChunk.vec().chunkKey(outChunk.cidx()), compressed, _fs);
        }

        @Override // water.MRTask
        public void closeLocal() {
            // Drop references once local work is done.
            this._y = null;
            this._progressKey = null;
        }
    }

    /**
     * Transposes an all-numeric frame into a newly allocated frame.
     * <p>
     * The result's rows (the source's columns) are split into {@code max(1, ncols / 10000)}
     * chunks of near-equal size. The first {@code ncols % nChunks} chunks get one extra row.
     *
     * @param frame source frame; its row count must fit in an {@code int}
     * @return the transposed frame
     * @throws RuntimeException from {@code H2O.unimpl()} when the row count exceeds the int range
     */
    public static Frame transpose(Frame frame) {
        if ((int) frame.numRows() != frame.numRows()) {
            throw H2O.unimpl();
        }
        final int ncols = frame.numCols();
        final int nChunks = Math.max(1, ncols / 10000);
        final int perChunk = ncols / nChunks;
        final int extra = ncols % nChunks;
        // espc[c] is the first row of chunk c; espc[nChunks] is the total row count.
        final long[] espc = new long[nChunks + 1];
        long offset = 0;
        for (int c = 0; c < espc.length; c++) {
            espc[c] = offset;
            offset += perChunk + (c < extra ? 1 : 0);
        }
        final Key<Vec> key = Vec.newKey();
        return transpose(frame, new Frame(new Vec(key, Vec.ESPC.rowLayout(key, espc)).makeZeros((int) frame.numRows())));
    }

    /**
     * Transposes {@code frame} into the caller-supplied {@code frame2}, keeping
     * {@code frame2}'s chunk layout.
     *
     * @param frame source frame; must be all-numeric
     * @param frame2 target frame with {@code frame}'s dimensions swapped
     * @return {@code frame2}, filled with the transpose
     * @throws IllegalArgumentException if the dimensions are not swapped, if any source column
     *     is categorical, or if any source column has more than 1,000,000 rows
     *     (NOTE(review): the message says "< 1M rows", but exactly 1,000,000 is accepted)
     */
    public static Frame transpose(Frame frame, Frame frame2) {
        final boolean shapesMatch = frame.numRows() == frame2.numCols() && frame.numCols() == frame2.numRows();
        if (!shapesMatch) {
            throw new IllegalArgumentException("dimension do not match!");
        }
        final Vec[] cols = frame.vecs();
        for (int c = 0; c < cols.length; c++) {
            final Vec col = cols[c];
            if (col.isCategorical()) {
                throw new IllegalArgumentException("transpose can only be applied to all-numeric frames (representing a matrix)");
            }
            if (col.length() > 1000000) {
                throw new IllegalArgumentException("too many rows, transpose only works for frames with < 1M rows.");
            }
        }
        new TransposeTsk(frame2).doAll(frame);
        return frame2;
    }

    /**
     * Multiplies two frames as matrices and blocks until the result is ready.
     * <p>
     * On an H2O fork/join worker thread the task is forked and joined in place. Otherwise it
     * is submitted to H2O and joined.
     *
     * @param frame left operand {@code x}
     * @param frame2 right operand {@code y}
     * @return a new frame holding {@code x * y}
     * @throws IllegalArgumentException if {@code x.numCols() != y.numRows()}
     */
    public static Frame mmul(Frame frame, Frame frame2) {
        final MatrixMulTsk task = new MatrixMulTsk(null, null, frame, frame2);
        final boolean onWorkerThread = Thread.currentThread() instanceof H2O.FJWThr;
        if (!onWorkerThread) {
            ((MatrixMulTsk) H2O.submitTask(task)).join();
        } else {
            task.fork().join();
        }
        return task._z;
    }
}
