package water.rapids.ast.prims.advmath;

import java.util.Arrays;
import java.util.Random;
import water.MRTask;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.NewChunk;
import water.fvec.Vec;
import water.rapids.Env;
import water.rapids.ast.AstPrimitive;
import water.rapids.ast.AstRoot;
import water.rapids.vals.ValFrame;
import water.util.RandomUtils;
import water.util.VecUtils;

/* loaded from: input_file:water/rapids/ast/prims/advmath/AstStratifiedSplit.class */
/**
 * Rapids primitive {@code h2o.random_stratified_split}.
 *
 * <p>Given a single integer or categorical column, produces a one-column frame named
 * {@code "test_train_split"} with levels {@code {"train", "test"}}. Each row gets 1 ("test") or
 * 0 ("train"). Every distinct value ("class") of the input column gets its own RNG seed. A row is
 * marked "test" when the RNG draw for that row and its class seed is {@code <= test_frac}, so the
 * split is made independently within each class.
 */
public class AstStratifiedSplit extends AstPrimitive {
    /** Names of the Rapids arguments accepted by this primitive. */
    @Override // water.rapids.ast.AstPrimitive
    public String[] args() {
        return new String[]{"ary", "test_frac", "seed"};
    }

    /** Number of AST slots; {@link #apply} reads the three arguments from indices 1-3. */
    @Override // water.rapids.ast.AstPrimitive
    public int nargs() {
        return 4;
    }

    @Override // water.rapids.ast.AstRoot
    public String str() {
        return "h2o.random_stratified_split";
    }

    /**
     * Computes the stratified train/test assignment for the given column.
     *
     * @param env        evaluation environment
     * @param stackHelp  stack helper used to track the evaluated input frame
     * @param astRootArr AST slots; indices 1, 2 and 3 hold {@code ary}, {@code test_frac} and
     *                   {@code seed}. A seed of {@code -1} picks a fresh random seed.
     * @return a one-column frame {@code "test_train_split"} with levels {@code {"train", "test"}},
     *         row-aligned with {@code ary}. NA in the input column yields NA in the output.
     * @throws IllegalArgumentException if {@code ary} does not have exactly one column, or that
     *         column is neither categorical nor integer-valued numeric
     */
    @Override // water.rapids.ast.AstPrimitive
    public ValFrame apply(Env env, Env.StackHelp stackHelp, AstRoot[] astRootArr) {
        Frame frame = stackHelp.track(astRootArr[1].exec(env)).getFrame();
        if (frame.numCols() != 1) {
            throw new IllegalArgumentException("Must give a single column to stratify against. Got: " + frame.numCols() + " columns.");
        }
        final Vec stratVec = frame.anyVec();
        if (!stratVec.isCategorical() && (!stratVec.isNumeric() || !stratVec.isInt())) {
            throw new IllegalArgumentException("stratification only applies to integer and categorical columns. Got: " + stratVec.get_type_str());
        }
        final double testFrac = astRootArr[2].exec(env).getNum();
        final long userSeed = (long) astRootArr[3].exec(env).getNum();
        // -1 requests a non-reproducible split.
        final long seed = userSeed == -1 ? new Random().nextLong() : userSeed;

        // Distinct values present in the column; a row's class is its index in this array.
        // Sorting is defensive (required by binarySearch below). It does not reorder anything if
        // CollectDomain already returns sorted values, so per-class seeds are unchanged.
        final long[] classes = new VecUtils.CollectDomain().doAll(stratVec).domain();
        Arrays.sort(classes);

        // One independent seed per class, derived deterministically from the user seed.
        final long[] classSeeds = new long[classes.length];
        for (int c = 0; c < classes.length; c++) {
            classSeeds[c] = RandomUtils.getRNG(seed + c).nextLong();
        }

        Frame split = new MRTask() {
            /** True when the row (global index) falls into the test split for its class seed. */
            private boolean isTest(long row, long classSeed) {
                return RandomUtils.getRNG(row + classSeed).nextDouble() <= testFrac;
            }

            @Override // water.MRTask
            public void map(Chunk chunk, NewChunk newChunk) {
                // long avoids overflow of the global row index on very large frames.
                final long start = chunk.start();
                // Emit exactly one value per input row, in input-row order, so the output column
                // stays aligned with the frame being split.
                for (int row = 0; row < chunk._len; row++) {
                    if (chunk.isNA(row)) {
                        newChunk.addNA();
                        continue;
                    }
                    final long value = chunk.at8(row);
                    final int cls = Arrays.binarySearch(classes, value);
                    if (cls < 0) {
                        // Cannot happen if the collected domain covers every non-NA value.
                        throw new IllegalStateException("Value " + value + " not found in the collected stratification domain");
                    }
                    newChunk.addNum(isTest(start + row, classSeeds[cls]) ? 1L : 0L, 0);
                }
            }
        // (byte) 3 is the output type byte kept from the decompiled code; presumably Vec.T_NUM,
        // with the {"train","test"} domain attached by outputFrame. TODO confirm.
        }.doAll(1, (byte) 3, new Frame(stratVec))
                .outputFrame(new String[]{"test_train_split"}, new String[][]{{"train", "test"}});
        return new ValFrame(split);
    }
}
