package io.prestosql.sql.query;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import io.prestosql.Session;
import io.prestosql.connector.MockConnectorFactory;
import io.prestosql.metadata.QualifiedObjectName;
import io.prestosql.plugin.tpch.TpchConnectorFactory;
import io.prestosql.spi.connector.ConnectorViewDefinition;
import io.prestosql.spi.connector.SchemaTableName;
import io.prestosql.spi.security.Identity;
import io.prestosql.spi.security.ViewExpression;
import io.prestosql.spi.type.BigintType;
import io.prestosql.spi.type.VarcharType;
import io.prestosql.sql.query.QueryAssertions;
import io.prestosql.testing.LocalQueryRunner;
import io.prestosql.testing.QueryRunner;
import io.prestosql.testing.TestingAccessControlManager;
import io.prestosql.testing.TestingSession;
import java.util.Optional;
import org.assertj.core.api.Assertions;
import org.testng.annotations.AfterClass;
import org.testng.annotations.BeforeClass;
import org.testng.annotations.Test;

/**
 * Tests for row filters registered through {@link TestingAccessControlManager#rowFilter}.
 *
 * <p>All tests share a single query runner and a single access-control manager; every test
 * calls {@code accessControl.reset()} before installing its own filters, which is why the
 * class is run single-threaded.
 */
@Test(singleThreaded = true)
public class TestRowFilter {
    private static final String CATALOG = "local";
    private static final String MOCK_CATALOG = "mock";
    private static final String USER = "user";
    private static final String VIEW_OWNER = "view-owner";
    private static final String RUN_AS_USER = "run-as-user";

    private static final Session SESSION = TestingSession.testSessionBuilder()
            .setCatalog(CATALOG)
            .setSchema("tiny")
            .setIdentity(Identity.forUser(USER).build())
            .build();

    private QueryAssertions assertions;
    private TestingAccessControlManager accessControl;

    /**
     * Creates a local query runner with two catalogs.
     * <ul>
     *   <li>a TPCH catalog named {@code local};</li>
     *   <li>a mock catalog exposing one view, {@code mock.default.nation_view}, which selects
     *       {@code nationkey, name} from {@code local.tiny.nation} and is owned by
     *       {@link #VIEW_OWNER}.</li>
     * </ul>
     */
    @BeforeClass
    public void init() {
        LocalQueryRunner runner = LocalQueryRunner.builder(SESSION).build();
        runner.createCatalog(CATALOG, new TpchConnectorFactory(1), ImmutableMap.of());

        ConnectorViewDefinition nationView = new ConnectorViewDefinition(
                "SELECT nationkey, name FROM local.tiny.nation",
                Optional.empty(),
                Optional.empty(),
                ImmutableList.of(
                        new ConnectorViewDefinition.ViewColumn("nationkey", BigintType.BIGINT.getTypeId()),
                        new ConnectorViewDefinition.ViewColumn("name", VarcharType.createVarcharType(25).getTypeId())),
                Optional.empty(),
                Optional.of(VIEW_OWNER),
                false);

        // A typed ImmutableMap.of(...) is required here: a chained ImmutableMap.builder().put(..).build()
        // infers ImmutableMap<Object, Object>, which does not match the view supplier's Map type.
        runner.createCatalog(
                MOCK_CATALOG,
                MockConnectorFactory.builder()
                        .withGetViews((connectorSession, schemaTablePrefix) ->
                                ImmutableMap.of(new SchemaTableName("default", "nation_view"), nationView))
                        .build(),
                ImmutableMap.of());

        assertions = new QueryAssertions(runner);
        accessControl = assertions.getQueryRunner().getAccessControl();
    }

    /**
     * Releases the query runner.
     *
     * <p>Because of {@code alwaysRun = true}, this also runs when {@link #init()} failed
     * part-way, so {@code assertions} may never have been assigned.
     */
    @AfterClass(alwaysRun = true)
    public void teardown() {
        if (assertions != null) {
            assertions.close();
            assertions = null;
        }
        accessControl = null;
    }

    /** A filter expression restricts rows; a NULL filter result filters every row out. */
    @Test
    public void testSimpleFilter() {
        accessControl.reset();
        accessControl.rowFilter(tinyTable("orders"), USER, new ViewExpression(USER, Optional.empty(), Optional.empty(), "orderkey < 10"));
        Assertions.assertThat(assertions.query("SELECT count(*) FROM orders")).matches("VALUES BIGINT '7'");

        accessControl.reset();
        accessControl.rowFilter(tinyTable("orders"), USER, new ViewExpression(USER, Optional.empty(), Optional.empty(), "NULL"));
        Assertions.assertThat(assertions.query("SELECT count(*) FROM orders")).matches("VALUES BIGINT '0'");
    }

    /** Two filters on the same table for the same user must both hold (count 2 rather than 7). */
    @Test
    public void testMultipleFilters() {
        accessControl.reset();
        accessControl.rowFilter(tinyTable("orders"), USER, new ViewExpression(USER, Optional.empty(), Optional.empty(), "orderkey < 10"));
        accessControl.rowFilter(tinyTable("orders"), USER, new ViewExpression(USER, Optional.empty(), Optional.empty(), "orderkey > 5"));
        Assertions.assertThat(assertions.query("SELECT count(*) FROM orders")).matches("VALUES BIGINT '2'");
    }

    /** A filter may contain a subquery correlated with the filtered table's columns. */
    @Test
    public void testCorrelatedSubquery() {
        accessControl.reset();
        accessControl.rowFilter(tinyTable("orders"), USER, new ViewExpression(USER, Optional.of(CATALOG), Optional.of("tiny"), "EXISTS (SELECT 1 FROM nation WHERE nationkey = orderkey)"));
        Assertions.assertThat(assertions.query("SELECT count(*) FROM orders")).matches("VALUES BIGINT '7'");
    }

    /**
     * Filters and views. The expectations show the following.
     * <ul>
     *   <li>Filters registered for the view owner apply to tables the view reads, whoever
     *       queries the view.</li>
     *   <li>Filters registered for the querying user on those tables do not apply.</li>
     *   <li>Filters registered on the view itself apply to the querying user.</li>
     * </ul>
     */
    @Test
    public void testView() {
        // filter on the underlying table for the view owner, view queried by another user
        accessControl.reset();
        accessControl.rowFilter(tinyTable("nation"), VIEW_OWNER, new ViewExpression(VIEW_OWNER, Optional.empty(), Optional.empty(), "nationkey = 1"));
        Assertions.assertThat(assertions.query(sessionFor(RUN_AS_USER), "SELECT name FROM mock.default.nation_view"))
                .matches("VALUES CAST('ARGENTINA' AS VARCHAR(25))");

        // filter on the underlying table for the view owner, view queried by the owner
        accessControl.reset();
        accessControl.rowFilter(tinyTable("nation"), VIEW_OWNER, new ViewExpression(VIEW_OWNER, Optional.of(CATALOG), Optional.of("tiny"), "nationkey = 1"));
        Assertions.assertThat(assertions.query(sessionFor(VIEW_OWNER), "SELECT name FROM mock.default.nation_view"))
                .matches("VALUES CAST('ARGENTINA' AS VARCHAR(25))");

        // filter on the underlying table for the invoking user is not applied inside the view
        accessControl.reset();
        accessControl.rowFilter(tinyTable("nation"), RUN_AS_USER, new ViewExpression(RUN_AS_USER, Optional.of(CATALOG), Optional.of("tiny"), "nationkey = 1"));
        Assertions.assertThat(assertions.query(sessionFor(RUN_AS_USER), "SELECT count(*) FROM mock.default.nation_view"))
                .matches("VALUES BIGINT '25'");

        // filter on the view itself
        accessControl.reset();
        accessControl.rowFilter(new QualifiedObjectName(MOCK_CATALOG, "default", "nation_view"), USER, new ViewExpression(USER, Optional.of(CATALOG), Optional.of("tiny"), "nationkey = 1"));
        Assertions.assertThat(assertions.query("SELECT name FROM mock.default.nation_view"))
                .matches("VALUES CAST('ARGENTINA' AS VARCHAR(25))");
    }

    /** A table referenced inside a WITH clause is filtered too. */
    @Test
    public void testTableReferenceInWithClause() {
        accessControl.reset();
        accessControl.rowFilter(tinyTable("orders"), USER, new ViewExpression(USER, Optional.empty(), Optional.empty(), "orderkey = 1"));
        Assertions.assertThat(assertions.query("WITH t AS (SELECT count(*) FROM orders) SELECT * FROM t")).matches("VALUES BIGINT '1'");
    }

    /** A filter's unqualified names resolve against its own catalog/schema ({@code sf1}), not the session's. */
    @Test
    public void testOtherSchema() {
        accessControl.reset();
        accessControl.rowFilter(tinyTable("orders"), USER, new ViewExpression(USER, Optional.of(CATALOG), Optional.of("sf1"), "(SELECT count(*) FROM customer) = 150000"));
        Assertions.assertThat(assertions.query("SELECT count(*) FROM orders")).matches("VALUES BIGINT '15000'");
    }

    /**
     * A filter evaluated as a different identity ({@link #RUN_AS_USER}) sees that identity's
     * own filter on the same table, so it is not treated as recursive.
     */
    @Test
    public void testDifferentIdentity() {
        accessControl.reset();
        accessControl.rowFilter(tinyTable("orders"), RUN_AS_USER, new ViewExpression(RUN_AS_USER, Optional.of(CATALOG), Optional.of("tiny"), "orderkey = 1"));
        accessControl.rowFilter(tinyTable("orders"), USER, new ViewExpression(RUN_AS_USER, Optional.of(CATALOG), Optional.of("tiny"), "orderkey IN (SELECT orderkey FROM orders)"));
        Assertions.assertThat(assertions.query("SELECT count(*) FROM orders")).matches("VALUES BIGINT '1'");
    }

    /**
     * A filter that (directly or through another identity's filter) references its own
     * table is rejected as recursive.
     */
    @Test
    public void testRecursion() {
        accessControl.reset();
        accessControl.rowFilter(tinyTable("orders"), USER, new ViewExpression(USER, Optional.of(CATALOG), Optional.of("tiny"), "orderkey IN (SELECT orderkey FROM orders)"));
        assertQueryFails("SELECT count(*) FROM orders", ".*\\QRow filter for 'local.tiny.orders' is recursive\\E.*");

        // different reference style to the same table
        accessControl.reset();
        accessControl.rowFilter(tinyTable("orders"), USER, new ViewExpression(USER, Optional.of(CATALOG), Optional.of("tiny"), "orderkey IN (SELECT local.tiny.orderkey FROM orders)"));
        assertQueryFails("SELECT count(*) FROM orders", ".*\\QRow filter for 'local.tiny.orders' is recursive\\E.*");

        // recursion through another identity's filter
        accessControl.reset();
        accessControl.rowFilter(tinyTable("orders"), RUN_AS_USER, new ViewExpression(RUN_AS_USER, Optional.of(CATALOG), Optional.of("tiny"), "orderkey IN (SELECT orderkey FROM orders)"));
        accessControl.rowFilter(tinyTable("orders"), USER, new ViewExpression(RUN_AS_USER, Optional.of(CATALOG), Optional.of("tiny"), "orderkey IN (SELECT orderkey FROM orders)"));
        assertQueryFails("SELECT count(*) FROM orders", ".*\\QRow filter for 'local.tiny.orders' is recursive\\E.*");
    }

    /** A filter cannot see columns of the enclosing query (here {@code orders.orderkey}). */
    @Test
    public void testLimitedScope() {
        accessControl.reset();
        accessControl.rowFilter(tinyTable("customer"), USER, new ViewExpression(USER, Optional.of(CATALOG), Optional.of("tiny"), "orderkey = 1"));
        assertQueryFails(
                "SELECT (SELECT min(name) FROM customer WHERE customer.custkey = orders.custkey) FROM orders",
                "\\Qline 1:31: Invalid row filter for 'local.tiny.customer': Column 'orderkey' cannot be resolved\\E");
    }

    /** A user-defined WITH table named {@code region} must not shadow the table referenced by the filter. */
    @Test
    public void testSqlInjection() {
        accessControl.reset();
        accessControl.rowFilter(tinyTable("nation"), USER, new ViewExpression(USER, Optional.of(CATALOG), Optional.of("tiny"), "regionkey IN (SELECT regionkey FROM region WHERE name = 'ASIA')"));
        Assertions.assertThat(assertions.query(
                "WITH region(regionkey, name) AS (VALUES (0, 'ASIA'), (1, 'ASIA'), (2, 'ASIA'), (3, 'ASIA'), (4, 'ASIA'))" +
                        "SELECT name FROM nation ORDER BY name LIMIT 1"))
                .matches("VALUES CAST('CHINA' AS VARCHAR(25))");
    }

    /** Malformed, non-boolean, and aggregate/window/grouping filters are rejected with specific messages. */
    @Test
    public void testInvalidFilter() {
        accessControl.reset();
        accessControl.rowFilter(tinyTable("orders"), USER, new ViewExpression(RUN_AS_USER, Optional.of(CATALOG), Optional.of("tiny"), "$$$"));
        assertQueryFails("SELECT count(*) FROM orders", "\\Qline 1:22: Invalid row filter for 'local.tiny.orders': mismatched input '$'. Expecting: <expression>\\E");

        accessControl.reset();
        accessControl.rowFilter(tinyTable("orders"), USER, new ViewExpression(RUN_AS_USER, Optional.of(CATALOG), Optional.of("tiny"), "unknown_column"));
        assertQueryFails("SELECT count(*) FROM orders", "\\Qline 1:22: Invalid row filter for 'local.tiny.orders': Column 'unknown_column' cannot be resolved\\E");

        accessControl.reset();
        accessControl.rowFilter(tinyTable("orders"), USER, new ViewExpression(RUN_AS_USER, Optional.of(CATALOG), Optional.of("tiny"), "1"));
        assertQueryFails("SELECT count(*) FROM orders", "\\Qline 1:22: Expected row filter for 'local.tiny.orders' to be of type BOOLEAN, but was integer\\E");

        accessControl.reset();
        accessControl.rowFilter(tinyTable("orders"), USER, new ViewExpression(RUN_AS_USER, Optional.of(CATALOG), Optional.of("tiny"), "count(*) > 0"));
        assertQueryFails("SELECT count(*) FROM orders", "\\Qline 1:10: Row filter for 'local.tiny.orders' cannot contain aggregations, window functions or grouping operations: [count(*)]\\E");

        accessControl.reset();
        accessControl.rowFilter(tinyTable("orders"), USER, new ViewExpression(RUN_AS_USER, Optional.of(CATALOG), Optional.of("tiny"), "row_number() OVER () > 0"));
        assertQueryFails("SELECT count(*) FROM orders", "\\Qline 1:22: Row filter for 'local.tiny.orders' cannot contain aggregations, window functions or grouping operations: [row_number() OVER ()]\\E");

        accessControl.reset();
        accessControl.rowFilter(tinyTable("orders"), USER, new ViewExpression(RUN_AS_USER, Optional.of(CATALOG), Optional.of("tiny"), "grouping(orderkey) = 0"));
        assertQueryFails("SELECT count(*) FROM orders", "\\Qline 1:20: Row filter for 'local.tiny.orders' cannot contain aggregations, window functions or grouping operations: [GROUPING (orderkey)]\\E");
    }

    /** SHOW STATS is refused for a table that has a row filter. */
    @Test
    public void testShowStats() {
        accessControl.reset();
        accessControl.rowFilter(tinyTable("orders"), USER, new ViewExpression(RUN_AS_USER, Optional.of(CATALOG), Optional.of("tiny"), "orderkey = 0"));
        assertQueryFails("SHOW STATS FOR (SELECT * FROM tiny.orders)", "\\QSHOW STATS is not supported for a table with row filtering");
    }

    /** Returns the name of {@code tableName} in the {@code local.tiny} schema. */
    private static QualifiedObjectName tinyTable(String tableName) {
        return new QualifiedObjectName(CATALOG, "tiny", tableName);
    }

    /** Returns a copy of {@link #SESSION} whose identity is {@code user}. */
    private static Session sessionFor(String user) {
        return Session.builder(SESSION).setIdentity(Identity.forUser(user).build()).build();
    }

    /** Asserts that running {@code sql} with the default session throws with a message matching {@code messageRegex}. */
    private void assertQueryFails(String sql, String messageRegex) {
        Assertions.assertThatThrownBy(() -> assertions.query(sql)).hasMessageMatching(messageRegex);
    }
}
