package water.rapids.ast.prims.reducers;

import java.io.Serializable;
import java.util.PriorityQueue;
import water.MRTask;
import water.fvec.C8Chunk;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.Vec;
import water.rapids.Env;
import water.rapids.ast.AstPrimitive;
import water.rapids.ast.AstRoot;
import water.rapids.vals.ValFrame;

/* loaded from: input_file:water/rapids/ast/prims/reducers/AstTopN.class */
public class AstTopN extends AstPrimitive {

    /**
     * Distributed task that collects the {@code _rowSize} largest (or smallest) non-NA values of a
     * single column together with their absolute row indices.
     *
     * <p>Each {@link #map} call fills a bounded priority queue for its chunk. {@link #reduce} merges two
     * queues and trims the result back to {@code _rowSize} entries. {@link #postGlobal} copies the
     * surviving entries into {@link #_sortedOut}.
     *
     * @param <E> nominal value type. At runtime the stored values are {@link Long} when
     *            {@code _csLong} is set and {@link Double} otherwise.
     */
    public class GrabTopNPQ<E extends Comparable<E>> extends MRTask<GrabTopNPQ<E>> {
        /** Output column names: the row-index column followed by the value column. */
        final String[] _columnName;
        /**
         * Bounded queue of {@link RowValue} entries. Its head is the entry evicted first when the queue
         * grows beyond {@code _rowSize}. It is null on a task whose {@link #map} never ran.
         */
        PriorityQueue _sortQueue;
        /** Result frame built by {@link #postGlobal}. */
        Frame _sortedOut;
        /** Maximum number of rows to keep. */
        final int _rowSize;
        /** True keeps the largest values (top N); false keeps the smallest values (bottom N). */
        final boolean _increasing;
        /**
         * True reads values with {@code at8} and stores them as {@link Long}; false uses {@code atd}
         * and {@link Double}.
         *
         * <p>This is fixed for the whole task rather than detected per chunk, so that every queue holds
         * the same value type. Mixing Long and Double entries would make {@link RowValue#compareTo}
         * throw {@link ClassCastException} during {@link #reduce}.
         */
        final boolean _csLong;

        /**
         * @param columnNames   names of the two output columns
         * @param rowSize       number of rows to keep; clamped to {@code [0, Integer.MAX_VALUE]}
         * @param increasing    true to keep the top values, false to keep the bottom values
         * @param useLongValues true to read and emit values as exact 64-bit integers
         */
        private GrabTopNPQ(String[] columnNames, long rowSize, boolean increasing, boolean useLongValues) {
            this._columnName = columnNames;
            // Avoid int overflow (and negative sizes) when narrowing the requested row count.
            this._rowSize = (int) Math.max(0L, Math.min(rowSize, Integer.MAX_VALUE));
            this._increasing = increasing;
            this._csLong = useLongValues;
        }

        /** Pushes every non-NA value of the chunk into a fresh bounded queue. */
        @Override // water.MRTask
        public void map(Chunk chunk) {
            this._sortQueue = new PriorityQueue();
            long startRow = chunk.start();  // absolute index of the chunk's first row
            for (int i = 0; i < chunk._len; i++) {
                if (!chunk.isNA(i)) {
                    addOneValue(chunk, i, startRow + i, this._sortQueue);
                }
            }
        }

        /**
         * Merges the other task's queue into this one, then drops head entries until at most
         * {@code _rowSize} remain.
         */
        @Override // water.MRTask
        public void reduce(GrabTopNPQ<E> grabTopNPQ) {
            if (grabTopNPQ._sortQueue == null) {
                return;  // nothing to merge
            }
            if (this._sortQueue == null) {
                this._sortQueue = grabTopNPQ._sortQueue;
                return;
            }
            this._sortQueue.addAll(grabTopNPQ._sortQueue);
            while (this._sortQueue.size() > this._rowSize) {
                this._sortQueue.poll();
            }
        }

        /**
         * Copies the retained entries into a two-column frame stored in {@link #_sortedOut}.
         *
         * <p>Rows are written in the order they are polled from the queue. For top N that is
         * ascending value order; for bottom N it is descending.
         */
        @Override // water.MRTask
        public void postGlobal() {
            // If map never ran on this task there is no queue; treat it as empty.
            PriorityQueue queue = this._sortQueue != null ? this._sortQueue : new PriorityQueue();
            long outRows = Math.min(this._rowSize, queue.size());
            Vec[] vecArr = new Vec[2];
            for (int c = 0; c < vecArr.length; c++) {
                vecArr[c] = Vec.makeZero(outRows);
            }
            for (int row = 0; row < outRows; row++) {
                RowValue rowValue = (RowValue) queue.poll();
                vecArr[0].set(row, rowValue.getRow().longValue());
                // Deliberately if/else, not ?: -- a conditional mixing long and double promotes the
                // long branch to double and silently loses precision above 2^53.
                if (this._csLong) {
                    vecArr[1].set(row, ((Long) rowValue.getValue()).longValue());
                } else {
                    vecArr[1].set(row, ((Double) rowValue.getValue()).doubleValue());
                }
            }
            this._sortedOut = new Frame(this._columnName, vecArr);
        }

        /**
         * Offers one chunk value to the queue and evicts the head if the queue exceeds
         * {@code _rowSize}.
         *
         * @param chunk         chunk to read from
         * @param i             row index within the chunk
         * @param j             absolute row index recorded alongside the value
         * @param priorityQueue queue of {@link RowValue} entries to update
         */
        public void addOneValue(Chunk chunk, int i, long j, PriorityQueue priorityQueue) {
            RowValue rowValue;
            if (this._csLong) {
                rowValue = new RowValue(Long.valueOf(j), Long.valueOf(chunk.at8(i)), this._increasing);
            } else {
                rowValue = new RowValue(Long.valueOf(j), Double.valueOf(chunk.atd(i)), this._increasing);
            }
            priorityQueue.offer(rowValue);
            if (priorityQueue.size() > this._rowSize) {
                priorityQueue.poll();
            }
        }
    }

    /**
     * A (row index, value) pair.
     *
     * <p>Ordering is by value, in natural order when {@code _increasing} is true and in reversed order
     * otherwise. With a min-heap {@link PriorityQueue}, the head is therefore the smallest value for
     * top N and the largest value for bottom N.
     *
     * @param <E> value type
     */
    public class RowValue<E extends Comparable<E>> implements Comparable<RowValue<E>>, Serializable {
        private Long _rowIndex;
        private E _value;
        boolean _increasing;

        public RowValue(Long l, E e, boolean z) {
            this._rowIndex = l;
            this._value = e;
            this._increasing = z;
        }

        /** @return the column value */
        public E getValue() {
            return this._value;
        }

        /** @return the absolute row index the value came from */
        public Long getRow() {
            return this._rowIndex;
        }

        /** Compares by value, reversing the order when this entry is not increasing. */
        @Override // java.lang.Comparable
        public int compareTo(RowValue<E> rowValue) {
            // Swap the operands instead of negating the result, so no int sign trick is needed.
            return this._increasing
                    ? getValue().compareTo(rowValue.getValue())
                    : rowValue.getValue().compareTo(getValue());
        }
    }

    @Override // water.rapids.ast.AstPrimitive
    public String[] args() {
        return new String[]{"frame", "col", "nPercent", "getBottomN"};
    }

    @Override // water.rapids.ast.AstRoot
    public String str() {
        return "topn";
    }

    @Override // water.rapids.ast.AstPrimitive
    public int nargs() {
        return 5;
    }

    @Override // water.rapids.ast.AstPrimitive, water.rapids.ast.AstRoot
    public String example() {
        return "(topn frame col nPercent getBottomN)";
    }

    @Override // water.rapids.ast.AstPrimitive, water.rapids.ast.AstRoot
    public String description() {
        return "Return the top N percent rows for a numerical column as a frame with two columns.  The first column will contain the original row indices of the chosen values.  The second column contains the top N rowvalues.  If getBottomN is 1, we will return the bottom N percent.  If getBottomN is 0, we will return the top N percent of rows";
    }

    /**
     * Evaluates {@code (topn frame col nPercent getBottomN)}.
     *
     * <p>Keeps {@code round(nPercent / 100 * numRows)} rows. The top values are returned when
     * {@code getBottomN == 0}; any other value returns the bottom values.
     *
     * @throws IllegalArgumentException if {@code col} is not a valid column index, or if
     *                                  {@code nPercent} is negative or NaN
     */
    @Override // water.rapids.ast.AstPrimitive
    public ValFrame apply(Env env, Env.StackHelp stackHelp, AstRoot[] astRootArr) {
        Frame frame = stackHelp.track(astRootArr[1].exec(env)).getFrame();
        int colIndex = (int) stackHelp.track(astRootArr[2].exec(env)).getNum();
        double nPercent = stackHelp.track(astRootArr[3].exec(env)).getNum();
        int getBottomN = (int) stackHelp.track(astRootArr[4].exec(env)).getNum();
        if (colIndex < 0 || colIndex >= frame.numCols()) {
            throw new IllegalArgumentException(
                    "topn: column index " + colIndex + " is out of range [0, " + frame.numCols() + ")");
        }
        if (!(nPercent >= 0)) {  // also rejects NaN
            throw new IllegalArgumentException("topn: nPercent must be non-negative, got " + nPercent);
        }
        Vec vec = frame.vec(colIndex);
        long rowsToKeep = Math.round(nPercent * 0.01d * frame.numRows());
        // Choose one value representation for the whole column: exact longs for integer columns,
        // doubles otherwise. This keeps all queues type-consistent across differently encoded chunks.
        GrabTopNPQ grabTopNPQ = new GrabTopNPQ(
                new String[]{"Original_Row_Indices", frame.name(colIndex)},
                rowsToKeep, getBottomN == 0, vec.isInt());
        grabTopNPQ.doAll(vec);
        return new ValFrame(grabTopNPQ._sortedOut);
    }
}
