package water.rapids;

import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.Ignore;
import org.junit.Test;
import water.DKV;
import water.Key;
import water.Keyed;
import water.Scope;
import water.TestUtil;
import water.fvec.Frame;
import water.rapids.vals.ValFrame;
import water.util.Log;

/* loaded from: input_file:water/rapids/GroupByTest.class */
/**
 * JUnit tests that run Rapids {@code GB}, {@code ddply}, {@code table} and
 * {@code h2o.impute} expressions against small data sets and check the shape
 * and selected cells of the resulting frames.
 *
 * <p>Most tests parse their input under the key {@code "hex"} (see the
 * {@code chkTree} helpers) and release every frame they create in a
 * {@code finally} block so that a failing assertion does not leave keys behind.
 */
public class GroupByTest extends TestUtil {
    /** Key under which the {@code chkTree} helpers store the parsed input file. */
    private static final String HEX = "hex";
    /** Key used by the tests that parse their input themselves. */
    private static final String COV = "cov";
    private static final String IRIS = "smalldata/iris/iris_wheader.csv";
    private static final String CARS = "smalldata/junit/cars.csv";

    @BeforeClass
    public static void setup() {
        stall_till_cloudsize(1);
    }

    /** Runs a single-aggregate {@code GB} on iris and spot-checks the 2x23 result. */
    @Test
    public void testBasic() {
        Frame fr = null;
        try {
            fr = chkTree("(GB hex [1] mean 2 \"all\")", IRIS);
            chkDim(fr, 2, 23);
            chkFr(fr, 0, 0, 2.0d);
            chkFr(fr, 1, 0, 3.5d);
            chkFr(fr, 0, 1, 2.2d);
            chkFr(fr, 1, 1, 4.5d);
            chkFr(fr, 0, 7, 2.8d);
            chkFr(fr, 1, 7, 5.042857142857143d);
            chkFr(fr, 0, 22, 4.4d);
            chkFr(fr, 1, 22, 1.5d);
        } finally {
            cleanUp(fr);
        }
    }

    /**
     * Runs a {@code GB} whose result has a categorical first column (checked via
     * its domain labels), then a {@code mode} aggregate whose shape is checked.
     */
    @Test
    public void testCatGroup() {
        Frame fr = null;
        try {
            fr = chkTree("(GB hex [4] nrow 0 \"all\" mean 2 \"all\")", IRIS);
            chkDim(fr, 3, 3);
            chkFr(fr, 0, 0, "Iris-setosa");
            chkFr(fr, 1, 0, 50.0d);
            chkFr(fr, 2, 0, 1.464d);
            chkFr(fr, 0, 1, "Iris-versicolor");
            chkFr(fr, 1, 1, 50.0d);
            chkFr(fr, 2, 1, 4.26d);
            chkFr(fr, 0, 2, "Iris-virginica");
            chkFr(fr, 1, 2, 50.0d);
            chkFr(fr, 2, 2, 5.552d);
            fr.delete();
            fr = null; // already released; keeps the finally block from deleting it twice

            fr = chkTree("(GB hex [1] mode 4 \"all\" )", IRIS);
            chkDim(fr, 2, 23);
        } finally {
            cleanUp(fr);
        }
    }

    /**
     * Runs {@code nrow} and {@code mean} aggregates on the cars data with the
     * three mode strings {@code "all"}, {@code "rm"} and {@code "ignore"} and
     * checks the differing results, including an NA cell for {@code "all"}.
     */
    @Test
    public void testNAHandle() {
        Frame fr = null;
        try {
            fr = chkTree("(GB hex [7] nrow 0 \"all\" mean 1 \"all\")", CARS);
            chkDim(fr, 3, 13);
            chkFr(fr, 0, 0, 70.0d);
            chkFr(fr, 1, 0, 35.0d);
            chkFr(fr, 2, 0, Double.NaN); // NaN means "expect an NA cell", see chkFr
            chkFr(fr, 0, 2, 72.0d);
            chkFr(fr, 1, 2, 28.0d);
            chkFr(fr, 2, 2, 18.714d, 0.1d);
            fr.delete();
            fr = null;

            fr = chkTree("(GB hex [7] nrow 1 \"all\" nrow 1 \"rm\" nrow 1 \"ignore\")", CARS);
            chkDim(fr, 4, 13);
            chkFr(fr, 0, 0, 70.0d);
            chkFr(fr, 1, 0, 35.0d);
            chkFr(fr, 2, 0, 29.0d);
            chkFr(fr, 3, 0, 29.0d);
            fr.delete();
            fr = null;

            fr = chkTree("(GB hex [7] mean 1 \"all\" mean 1 \"rm\" mean 1 \"ignore\")", CARS);
            chkDim(fr, 4, 13);
            chkFr(fr, 0, 0, 70.0d);
            chkFr(fr, 1, 0, Double.NaN);
            chkFr(fr, 2, 0, 17.69d, 0.1d);
            chkFr(fr, 3, 0, 14.66d, 0.1d);
        } finally {
            cleanUp(fr);
        }
    }

    /** Runs {@code nrow}, {@code mean}, {@code sum}, {@code min} and {@code max} in one {@code GB} on iris. */
    @Test
    public void testAllAggs() {
        Frame fr = null;
        try {
            fr = chkTree("(GB hex [4] nrow 0 \"rm\"  mean 1 \"rm\"  sum 1 \"rm\"  min 1 \"rm\"  max 1 \"rm\" )", IRIS);
            chkDim(fr, 6, 3);

            chkFr(fr, 0, 0, "Iris-setosa");
            chkFr(fr, 1, 0, 50.0d);
            chkFr(fr, 2, 0, 3.418d);
            chkFr(fr, 3, 0, 170.9d);
            chkFr(fr, 4, 0, 2.3d);
            chkFr(fr, 5, 0, 4.4d);

            chkFr(fr, 0, 1, "Iris-versicolor");
            chkFr(fr, 1, 1, 50.0d);
            chkFr(fr, 2, 1, 2.77d);
            chkFr(fr, 3, 1, 138.5d);
            chkFr(fr, 4, 1, 2.0d);
            chkFr(fr, 5, 1, 3.4d);

            chkFr(fr, 0, 2, "Iris-virginica");
            chkFr(fr, 1, 2, 50.0d);
            chkFr(fr, 2, 2, 2.974d);
            chkFr(fr, 3, 2, 148.7d);
            chkFr(fr, 4, 2, 2.2d);
            chkFr(fr, 5, 2, 3.8d);
        } finally {
            cleanUp(fr);
        }
    }

    /**
     * Runs {@code h2o.impute} twice on the cars data (with {@code []} and with
     * {@code [7]}) and checks the frame stored under {@code "hex"} afterwards:
     * column 1 must have no NAs and row 26 must hold the expected value.
     */
    @Test
    public void testImpute() {
        Frame hex = null;
        Frame result = null;
        try {
            // The value returned by this first call is intentionally not kept,
            // exactly as before; only the frame under "hex" is inspected.
            chkTree("(h2o.impute hex 1 \"mean\" \"low\" [] _ _)", CARS, 1.0f);
            hex = (Frame) DKV.getGet(HEX);
            chkDim(hex, 8, 406);
            Assert.assertEquals(0L, hex.vec(1).naCnt());
            Assert.assertEquals(23.51d, hex.vec(1).at(26L), 0.1d);
            hex.delete();
            hex = null; // already released; avoid a second delete in finally

            result = chkTree("(h2o.impute hex 1 \"mean\" \"low\" [7] _ _)", CARS, 1.0f);
            hex = (Frame) DKV.getGet(HEX);
            chkDim(hex, 8, 406);
            Assert.assertEquals(0L, hex.vec(1).naCnt());
            Assert.assertEquals(17.69d, hex.vec(1).at(26L), 0.1d);
        } finally {
            cleanUp(hex, result);
        }
    }

    /** Runs two {@code ddply} expressions on iris; the first is spot-checked like {@link #testBasic()}. */
    @Test
    public void testBasicDdply() {
        Frame fr = null;
        try {
            fr = chkTree("(ddply hex [1] {x . (flatten (mean (cols x 2) TRUE))})", IRIS);
            chkDim(fr, 2, 23);
            chkFr(fr, 0, 0, 2.0d);
            chkFr(fr, 1, 0, 3.5d);
            chkFr(fr, 0, 1, 2.2d);
            chkFr(fr, 1, 1, 4.5d);
            chkFr(fr, 0, 7, 2.8d);
            chkFr(fr, 1, 7, 5.042857142857143d);
            chkFr(fr, 0, 22, 4.4d);
            chkFr(fr, 1, 22, 1.5d);
            fr.delete();
            fr = null;

            fr = chkTree("(ddply hex [1] {x . (sum (* (cols x 2) (cols x 3)))})", IRIS);
            chkDim(fr, 2, 23);
        } finally {
            cleanUp(fr);
        }
    }

    /**
     * Smoke test: runs a {@code ddply} and a {@code GB} over the covtype data
     * and prints the results. Nothing is asserted; the test only fails if an
     * expression throws or does not yield a frame.
     */
    @Test
    public void testSplitCats() throws InterruptedException {
        Frame cov = parse_test_file(Key.make(COV), "smalldata/covtype/covtype.altered.gz");
        try {
            System.out.println(cov.toString(0L, 10));

            Val ddply = Rapids.exec("(ddply cov [54] nrow)");
            System.out.println(ddply.toString());
            ddply.getFrame().delete();

            Val groupBy = Rapids.exec("(GB cov [54] nrow 54 \"all\")");
            System.out.println(groupBy.toString());
            groupBy.getFrame().delete();
        } finally {
            // Release the parsed input even when one of the expressions throws.
            cov.delete();
        }
    }

    /**
     * Compares a {@code GB median} result against hard-coded per-group values.
     * Currently disabled; {@code @Test} is present so that runners report the
     * method as ignored instead of silently not seeing it.
     */
    @Test
    @Ignore
    public void testGroupbyMedian() {
        Frame fr = null;
        // Expected values, indexed by the (integer) value found in result column 0.
        double[] expected = {0.49851096435701053d, 0.5018318704735285d, 0.5018723436256065d, 0.5052896538751508d, 0.4988730254120379d};
        try {
            fr = chkTree("(GB hex [0] median 1 \"all\")", "smalldata/jira/pubdev_4727_junit_data.csv");
            for (long row = 0; row < fr.numRows(); row++) {
                int group = (int) fr.vec(0).at(row);
                Assert.assertEquals("median of group " + group, expected[group], fr.vec(1).at(row), 1.0E-12d);
            }
        } finally {
            cleanUp(fr);
        }
    }

    /**
     * Runs the same count through {@code GB ... nrow} and {@code table} on a
     * categorical id column and prints the wall-clock time of each. Nothing is
     * asserted about the timings.
     */
    @Test
    public void testGroupbyTableSpeed() {
        Frame ids = parse_test_file(Key.make(COV), "smalldata/junit/id_cols.csv");
        try {
            // Swap column 0 for its categorical conversion and drop the vec it replaced.
            ids.replace(0, ids.anyVec().toCategoricalVec()).remove();
            System.out.println(ids.toString(0L, 10));

            long start = System.currentTimeMillis();
            Val groupBy = Rapids.exec("(GB cov [0] nrow 0 \"all\")");
            System.out.println("GB Time= " + (System.currentTimeMillis() - start) + "msec");
            System.out.println(groupBy.toString());
            groupBy.getFrame().delete();

            start = System.currentTimeMillis();
            Val table = Rapids.exec("(table cov FALSE)");
            System.out.println("Table Time= " + (System.currentTimeMillis() - start) + "msec");
            System.out.println(table.toString());
            table.getFrame().delete();
        } finally {
            // Release the parsed input even when one of the expressions throws.
            ids.delete();
            Keyed.remove(Key.make(COV));
        }
    }

    /**
     * Repeats the same {@code GB} on the titanic data 200 times and requires
     * every run to match the first one exactly. Currently disabled.
     */
    @Test
    @Ignore
    public void testPubDev6319() {
        Scope.enter();
        try {
            Frame input = parse_test_file("./smalldata/gbm_test/titanic.csv");
            Scope.track(input);
            int survived = input.find("survived");
            String rapids = String.format("(GB %s [%d]  sum %s \"all\" nrow %s \"all\")",
                    input._key, input.find("home.dest"), survived, survived);

            Frame reference = Rapids.exec(rapids).getFrame();
            Scope.track(reference);

            for (int attempt = 0; attempt < 200; attempt++) {
                Log.info("Running attempt: " + attempt);
                Frame current = Rapids.exec(rapids).getFrame();
                Scope.track(current);
                // Compare against the fresh result; the previous code compared the
                // reference domain with itself, which could never fail.
                Assert.assertArrayEquals(reference.vec(0).domain(), current.vec(0).domain());
                assertCatVecEquals(reference.vec(0), current.vec(0));
                assertVecEquals(reference.vec(1), current.vec(1), 0.0d);
                assertVecEquals(reference.vec(2), current.vec(2), 0.0d);
            }
        } finally {
            Scope.exit();
        }
    }

    /**
     * Releases the given frames and the shared {@code "hex"} key.
     *
     * @param frames frames to delete; {@code null} entries are skipped
     */
    private static void cleanUp(Frame... frames) {
        for (Frame fr : frames) {
            if (fr != null) {
                fr.delete();
            }
        }
        Keyed.remove(Key.make(HEX));
    }

    /** Asserts that {@code frame} has exactly {@code cols} columns and {@code rows} rows. */
    private void chkDim(Frame frame, int cols, int rows) {
        Assert.assertEquals(cols, frame.numCols());
        Assert.assertEquals(rows, frame.numRows());
    }

    /** Checks one numeric cell using a tolerance of {@code Math.ulp(1.0f)}. */
    private void chkFr(Frame frame, int col, int row, double expected) {
        chkFr(frame, col, row, expected, Math.ulp(1.0f));
    }

    /**
     * Checks one numeric cell.
     *
     * @param expected  expected value; {@code NaN} means the cell must be NA
     * @param tolerance maximum allowed absolute difference (ignored for NaN)
     */
    private void chkFr(Frame frame, int col, int row, double expected, double tolerance) {
        if (Double.isNaN(expected)) {
            Assert.assertTrue(frame.vec(col).isNA(row));
        } else {
            Assert.assertEquals(expected, frame.vec(col).at(row), tolerance);
        }
    }

    /** Checks one categorical cell by looking its integer code up in the column's domain. */
    private void chkFr(Frame frame, int col, int row, String expected) {
        String[] domain = frame.vec(col).domain();
        Assert.assertEquals(expected, domain[(int) frame.vec(col).at8(row)]);
    }

    /**
     * Parses {@code file} under the key {@code "hex"}, executes {@code rapids}
     * and prints the result.
     *
     * @param unused not read; it only selects this overload, which tolerates
     *               non-frame results
     * @return the resulting frame, or {@code null} if the expression did not
     *         produce a frame
     */
    private Frame chkTree(String rapids, String file, float unused) {
        parse_test_file(Key.make(HEX), file);
        Val result = Rapids.exec(rapids);
        System.out.println(result.toString());
        return result instanceof ValFrame ? result.getFrame() : null;
    }

    /** Same as {@link #chkTree(String, String, boolean)} with {@code expectThrow == false}. */
    private Frame chkTree(String rapids, String file) {
        return chkTree(rapids, file, false);
    }

    /**
     * Parses {@code file} under the key {@code "hex"}, executes {@code rapids},
     * prints the result and returns it as a frame.
     *
     * @param expectThrow when {@code true}, an {@link IllegalArgumentException}
     *                    (including the one raised here for a non-frame result)
     *                    is swallowed: the parsed input is deleted and
     *                    {@code null} is returned
     * @return the resulting frame, or {@code null} as described above
     * @throws IllegalArgumentException if {@code expectThrow} is {@code false}
     *         and the expression fails with that exception or does not yield a frame
     */
    private Frame chkTree(String rapids, String file, boolean expectThrow) {
        Frame parsed = parse_test_file(Key.make(HEX), file);
        try {
            Val result = Rapids.exec(rapids);
            System.out.println(result.toString());
            if (!(result instanceof ValFrame)) {
                throw new IllegalArgumentException("expected a frame return");
            }
            return result.getFrame();
        } catch (IllegalArgumentException e) {
            if (!expectThrow) {
                throw e;
            }
            parsed.delete();
            return null;
        }
    }
}
