package io.prestosql.sql.query;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import io.prestosql.sql.planner.Plan;
import io.prestosql.sql.planner.assertions.PlanMatchPattern;
import io.prestosql.sql.planner.optimizations.PlanNodeSearcher;
import io.prestosql.sql.planner.plan.AggregationNode;
import io.prestosql.sql.planner.plan.ExchangeNode;
import io.prestosql.sql.planner.plan.JoinNode;
import io.prestosql.sql.planner.plan.ProjectNode;
import io.prestosql.sql.planner.plan.ValuesNode;
import java.util.function.Consumer;
import org.intellij.lang.annotations.Language;
import org.testng.Assert;
import org.testng.annotations.AfterClass;
import org.testng.annotations.BeforeClass;
import org.testng.annotations.Test;

/* loaded from: input_file:io/prestosql/sql/query/TestSubqueries.class */
/**
 * Query tests for correlated subqueries: EXISTS, scalar and LATERAL subqueries combined with
 * LIMIT, GROUP BY / HAVING, coercions and UNNEST. Some EXISTS cases also verify the shape of the
 * chosen plan (aggregation placed below or above the join).
 */
public class TestSubqueries {
    // Expected-failure message for rejected correlated subqueries. Contains ".*", so presumably
    // matched as a regex by QueryAssertions#assertFails -- confirm there.
    private static final String UNSUPPORTED_CORRELATED_SUBQUERY_ERROR_MSG = "line .*: Given correlated subquery is not supported";

    private QueryAssertions assertions;

    @BeforeClass
    public void init() {
        this.assertions = new QueryAssertions();
    }

    @AfterClass(alwaysRun = true)
    public void teardown() {
        // alwaysRun = true means this can run even when init() never assigned the field;
        // guard against null so a setup failure is not masked by an NPE here.
        if (this.assertions != null) {
            this.assertions.close();
            this.assertions = null;
        }
    }

    @Test
    public void testCorrelatedExistsSubqueriesWithOrPredicateAndNull() {
        assertExistsRewrittenToAggregationAboveJoin(
                "SELECT EXISTS(SELECT 1 FROM (VALUES null, 10) t(x) WHERE y > x OR y + 10 > x) FROM (values 11 + if(rand() >= 0, 0)) t2(y)",
                "VALUES true",
                false);
        assertExistsRewrittenToAggregationAboveJoin(
                "SELECT EXISTS(SELECT 1 FROM (VALUES null) t(x) WHERE y > x OR y + 10 > x) FROM (VALUES 11 + if(rand() >= 0, 0)) t2(y)",
                "VALUES false",
                false);
    }

    @Test
    public void testUnsupportedSubqueriesWithCoercions() {
        this.assertions.assertFails(
                "SELECT EXISTS(SELECT 1 FROM (VALUES (1, null)) t(a, b) WHERE t.a=t2.b GROUP BY t.b) FROM (VALUES 1.0, 2.0) t2(b)",
                UNSUPPORTED_CORRELATED_SUBQUERY_ERROR_MSG);
        this.assertions.assertFails(
                "SELECT EXISTS(SELECT 1 FROM (VALUES (null, null)) t(a, b) WHERE t.a=t2.b GROUP BY t.b) FROM (VALUES 1, 2) t2(b)",
                UNSUPPORTED_CORRELATED_SUBQUERY_ERROR_MSG);
    }

    @Test
    public void testCorrelatedSubqueriesWithLimit() {
        this.assertions.assertQuery(
                "SELECT (SELECT t.a FROM (VALUES 1, 2) t(a) WHERE t.a=t2.b LIMIT 1) FROM (VALUES 1) t2(b)",
                "VALUES 1");
        this.assertions.assertQuery(
                "SELECT (SELECT t.a FROM (VALUES 1, 2) t(a) WHERE t.a=t2.b LIMIT 2) FROM (VALUES 1) t2(b)",
                "VALUES 1");
        this.assertions.assertFails(
                "SELECT (SELECT t.a FROM (VALUES 1, 2, 3) t(a) WHERE t.a=t2.b LIMIT 2) from (VALUES 1) t2(b)",
                UNSUPPORTED_CORRELATED_SUBQUERY_ERROR_MSG);
        this.assertions.assertQuery(
                "SELECT (SELECT sum(t.a) FROM (VALUES 1, 2) t(a) WHERE t.a=t2.b group by t.a LIMIT 2) FROM (VALUES 1) t2(b)",
                "VALUES BIGINT '1'");
        this.assertions.assertQuery(
                "SELECT (SELECT count(*) FROM (SELECT t.a FROM (VALUES 1, 1, null, 3) t(a) LIMIT 1) t WHERE t.a=t2.b) FROM (VALUES 1, 2) t2(b)",
                "VALUES BIGINT '1', BIGINT '0'");
        assertExistsRewrittenToAggregationBelowJoin(
                "SELECT EXISTS(SELECT 1 FROM (VALUES 1, 1, 3) t(a) WHERE t.a=t2.b LIMIT 1) FROM (VALUES 1, 2) t2(b)",
                "VALUES true, false",
                false);
        this.assertions.assertQuery(
                "SELECT (SELECT count(*) FROM (VALUES 1, 1, 3) t(a) WHERE t.a=t2.b LIMIT 1) FROM (VALUES 1) t2(b)",
                "VALUES BIGINT '2'");
        assertExistsRewrittenToAggregationBelowJoin(
                "SELECT EXISTS(SELECT 1 FROM (VALUES ('x', 1)) u(x, cid) WHERE x = 'x' AND t.cid = cid LIMIT 1) FROM (VALUES 1) t(cid)",
                "VALUES true",
                false);
    }

    @Test
    public void testCorrelatedSubqueriesWithGroupBy() {
        this.assertions.assertFails(
                "SELECT (SELECT count(*) FROM (VALUES 1, 2, 3, null) t(a) WHERE t.a<t2.b GROUP BY t.a) FROM (VALUES 1, 2, 3) t2(b)",
                "Scalar sub-query has returned multiple rows");
        this.assertions.assertQuery(
                "SELECT (SELECT count(*) FROM (VALUES 1, 1, 2, 3, null) t(a) WHERE t.a<t2.b GROUP BY t.a HAVING count(*) > 1) FROM (VALUES 1, 2) t2(b)",
                "VALUES null, BIGINT '2'");
        assertExistsRewrittenToAggregationBelowJoin(
                "SELECT EXISTS(SELECT 1 FROM (VALUES 1, 1, 3) t(a) WHERE t.a=t2.b GROUP BY t.a) FROM (VALUES 1, 2) t2(b)",
                "VALUES true, false",
                false);
        assertExistsRewrittenToAggregationBelowJoin(
                "SELECT EXISTS(SELECT 1 FROM (VALUES (1, 2), (1, 2), (null, null), (3, 3)) t(a, b) WHERE t.a=t2.b GROUP BY t.a, t.b) FROM (VALUES 1, 2) t2(b)",
                "VALUES true, false",
                true);
        assertExistsRewrittenToAggregationAboveJoin(
                "SELECT EXISTS(SELECT 1 FROM (VALUES (1, 2), (1, 2), (null, null), (3, 3)) t(a, b) WHERE t.a<t2.b GROUP BY t.a, t.b) FROM (VALUES 1, 2) t2(b)",
                "VALUES false, true",
                true);
        this.assertions.assertFails(
                "SELECT EXISTS(SELECT 1 FROM (VALUES (1, 1), (1, 1), (null, null), (3, 3)) t(a, b) WHERE t.a+t.b<t2.b GROUP BY t.a) FROM (VALUES 1, 2) t2(b)",
                UNSUPPORTED_CORRELATED_SUBQUERY_ERROR_MSG);
        assertExistsRewrittenToAggregationAboveJoin(
                "SELECT EXISTS(SELECT 1 FROM (VALUES (1, 1), (1, 1), (null, null), (3, 3)) t(a, b) WHERE t.a+t.b<t2.b GROUP BY t.a, t.b) FROM (VALUES 1, 4) t2(b)",
                "VALUES false, true",
                true);
        assertExistsRewrittenToAggregationBelowJoin(
                "SELECT EXISTS(SELECT 1 FROM (VALUES (1, 2), (1, 2), (null, null), (3, 3)) t(a, b) WHERE t.a=t2.b GROUP BY t.b) FROM (VALUES 1, 2) t2(b)",
                "VALUES true, false",
                true);
        assertExistsRewrittenToAggregationBelowJoin(
                "SELECT EXISTS(SELECT * FROM (VALUES 1, 1, 2, 3) t(a) WHERE t.a=t2.b GROUP BY t.a HAVING count(*) > 1) FROM (VALUES 1, 2) t2(b)",
                "VALUES true, false",
                false);
        this.assertions.assertQuery(
                "SELECT EXISTS(SELECT * FROM (SELECT t.a FROM (VALUES (1, 1), (1, 1), (1, 2), (1, 2), (3, 3)) t(a, b) WHERE t.b=t2.b GROUP BY t.a HAVING count(*) > 1) t WHERE t.a=t2.b) FROM (VALUES 1, 2) t2(b)",
                "VALUES true, false");
        assertExistsRewrittenToAggregationBelowJoin(
                "SELECT EXISTS(SELECT * FROM (VALUES 1, 1, 2, 3) t(a) WHERE t.a=t2.b GROUP BY (t.a) HAVING count(*) > 1) FROM (VALUES 1, 2) t2(b)",
                "VALUES true, false",
                false);
    }

    @Test
    public void testCorrelatedLateralWithGroupBy() {
        this.assertions.assertQuery(
                "SELECT * FROM (VALUES 1, 2) t2(b), LATERAL (SELECT t.a FROM (VALUES 1, 1, 3) t(a) WHERE t.a=t2.b GROUP BY t.a)",
                "VALUES (1, 1)");
        this.assertions.assertQuery(
                "SELECT * FROM (VALUES 1, 2) t2(b), LATERAL (SELECT count(*) FROM (VALUES 1, 1, 2, 3) t(a) WHERE t.a=t2.b GROUP BY t.a HAVING count(*) > 1)",
                "VALUES (1, BIGINT '2')");
        this.assertions.assertFails(
                "SELECT * FROM (VALUES 1, 2) t2(b), LATERAL (SELECT t.a, t.b, count(*) FROM (VALUES (1, 1), (1, 2), (2, 2), (3, 3)) t(a, b) WHERE t.a=t2.b GROUP BY GROUPING SETS ((t.a, t.b), (t.a)))",
                UNSUPPORTED_CORRELATED_SUBQUERY_ERROR_MSG);
    }

    @Test
    public void testLateralWithUnnest() {
        this.assertions.assertFails(
                "SELECT * FROM (VALUES ARRAY[1]) t(x), LATERAL (SELECT * FROM UNNEST(x))",
                UNSUPPORTED_CORRELATED_SUBQUERY_ERROR_MSG);
    }

    @Test
    public void testCorrelatedScalarSubquery() {
        this.assertions.assertQuery(
                "SELECT * FROM (VALUES 1, 2) t2(b) WHERE (SELECT b) = 2",
                "VALUES 2");
    }

    @Test
    public void testCorrelatedSubqueryWithExplicitCoercion() {
        this.assertions.assertQuery(
                "SELECT 1 FROM (VALUES 1, 2) t1(b) WHERE 1 = (SELECT cast(b as decimal(7,2)))",
                "VALUES 1");
    }

    /**
     * Runs {@code actual} through {@code QueryAssertions#assertQueryAndPlan}, comparing it against
     * {@code expected}, and checks that the EXISTS aggregation was planned below the join.
     *
     * <p>The plan must contain a {@link JoinNode} whose first source contains a {@link ValuesNode}
     * and whose second source contains a FINAL aggregation over a LOCAL REPARTITION exchange over a
     * PARTIAL aggregation. The whole plan must contain exactly 1 FINAL aggregation, or 2 when
     * {@code extraAggregation} is set.
     *
     * @param actual query under test
     * @param expected query producing the expected rows
     * @param extraAggregation whether an additional aggregation / LOCAL REPARTITION exchange /
     *     aggregation stage is expected beneath the PARTIAL aggregation
     */
    private void assertExistsRewrittenToAggregationBelowJoin(@Language("SQL") String actual, @Language("SQL") String expected, boolean extraAggregation) {
        // Expected input of the PARTIAL aggregation: either the VALUES directly, or an extra
        // aggregation -> exchange -> aggregation chain on top of it.
        PlanMatchPattern partialAggregationInput = PlanMatchPattern.node(ValuesNode.class);
        if (extraAggregation) {
            partialAggregationInput = PlanMatchPattern.aggregation(
                    ImmutableMap.of(),
                    PlanMatchPattern.exchange(
                            ExchangeNode.Scope.LOCAL,
                            ExchangeNode.Type.REPARTITION,
                            PlanMatchPattern.aggregation(
                                    ImmutableMap.of(),
                                    PlanMatchPattern.anyTree(PlanMatchPattern.node(ValuesNode.class)))));
        }

        PlanMatchPattern expectedPlan = PlanMatchPattern.anyTree(
                PlanMatchPattern.node(
                        JoinNode.class,
                        PlanMatchPattern.anyTree(PlanMatchPattern.node(ValuesNode.class)),
                        PlanMatchPattern.anyTree(
                                PlanMatchPattern.aggregation(
                                        ImmutableMap.of(),
                                        AggregationNode.Step.FINAL,
                                        PlanMatchPattern.exchange(
                                                ExchangeNode.Scope.LOCAL,
                                                ExchangeNode.Type.REPARTITION,
                                                PlanMatchPattern.aggregation(
                                                        ImmutableMap.of(),
                                                        AggregationNode.Step.PARTIAL,
                                                        PlanMatchPattern.anyTree(partialAggregationInput)))))));

        Consumer<Plan> finalAggregationCountCheck =
                plan -> Assert.assertEquals(countFinalAggregationNodes(plan), extraAggregation ? 2 : 1);

        this.assertions.assertQueryAndPlan(actual, expected, expectedPlan, finalAggregationCountCheck);
    }

    /**
     * Runs {@code actual} through {@code QueryAssertions#assertQueryAndPlan}, comparing it against
     * {@code expected}, and checks that the EXISTS aggregation was planned above the join.
     *
     * <p>The plan must contain a SINGLE-step aggregation computing {@code count(NON_NULL)} directly
     * over a {@link JoinNode}. The join's first source contains a {@link ValuesNode}. Its second
     * source contains a {@link ProjectNode} that defines {@code NON_NULL} as {@code true} over a
     * {@link ValuesNode}. The plan must also contain exactly one SINGLE streamable aggregation, and
     * exactly 1 FINAL aggregation when {@code extraAggregation} is set (0 otherwise).
     *
     * @param actual query under test
     * @param expected query producing the expected rows
     * @param extraAggregation whether one FINAL aggregation is expected elsewhere in the plan
     */
    private void assertExistsRewrittenToAggregationAboveJoin(@Language("SQL") String actual, @Language("SQL") String expected, boolean extraAggregation) {
        PlanMatchPattern expectedPlan = PlanMatchPattern.anyTree(
                PlanMatchPattern.aggregation(
                        ImmutableMap.of("COUNT", PlanMatchPattern.functionCall("count", ImmutableList.of("NON_NULL"))),
                        AggregationNode.Step.SINGLE,
                        PlanMatchPattern.node(
                                JoinNode.class,
                                PlanMatchPattern.anyTree(PlanMatchPattern.node(ValuesNode.class)),
                                PlanMatchPattern.anyTree(
                                        PlanMatchPattern.node(
                                                ProjectNode.class,
                                                PlanMatchPattern.anyTree(PlanMatchPattern.node(ValuesNode.class)))
                                                .withAlias("NON_NULL", PlanMatchPattern.expression("true"))))));

        // Typed Consumer<Plan>: with the raw Consumer type the lambda parameter would be Object
        // and the count helpers, which take a Plan, would not type-check.
        Consumer<Plan> singleStreamingAggregationCheck =
                plan -> Assert.assertEquals(countSingleStreamingAggregations(plan), 1);
        Consumer<Plan> finalAggregationCountCheck =
                plan -> Assert.assertEquals(countFinalAggregationNodes(plan), extraAggregation ? 1 : 0);

        this.assertions.assertQueryAndPlan(
                actual,
                expected,
                expectedPlan,
                singleStreamingAggregationCheck.andThen(finalAggregationCountCheck));
    }

    /**
     * Counts the {@link AggregationNode}s in {@code plan} whose step is FINAL.
     *
     * @param plan plan to search from its root
     * @return number of FINAL aggregation nodes found
     */
    private static int countFinalAggregationNodes(Plan plan) {
        return PlanNodeSearcher.searchFrom(plan.getRoot())
                .where(node -> node instanceof AggregationNode
                        && ((AggregationNode) node).getStep() == AggregationNode.Step.FINAL)
                .count();
    }

    /**
     * Counts the {@link AggregationNode}s in {@code plan} whose step is SINGLE and that report
     * {@code isStreamable()}.
     *
     * @param plan plan to search from its root
     * @return number of matching aggregation nodes found
     */
    private static int countSingleStreamingAggregations(Plan plan) {
        return PlanNodeSearcher.searchFrom(plan.getRoot())
                .where(node -> node instanceof AggregationNode
                        && ((AggregationNode) node).getStep() == AggregationNode.Step.SINGLE
                        && ((AggregationNode) node).isStreamable())
                .count();
    }
}
