package io.trino.plugin.jdbc;

import com.google.common.base.Preconditions;
import com.google.common.base.Strings;
import com.google.common.base.Verify;
import io.trino.Session;
import io.trino.sql.planner.assertions.PlanMatchPattern;
import io.trino.sql.planner.plan.JoinNode;
import io.trino.sql.planner.plan.TableScanNode;
import io.trino.sql.query.QueryAssertions;
import io.trino.testing.AbstractTestQueryFramework;
import io.trino.testing.QueryRunner;
import io.trino.testing.sql.SqlExecutor;
import io.trino.testing.sql.TestTable;
import java.util.Objects;
import org.assertj.core.api.Assertions;
import org.testng.annotations.Test;

/* loaded from: input_file:io/trino/plugin/jdbc/BaseAutomaticJoinPushdownTest.class */
/**
 * Base test for the {@code AUTOMATIC} join pushdown strategy of JDBC connectors.
 *
 * <p>Each test creates small tables derived from {@code tpch.tiny.orders}, gathers statistics
 * through the connector-specific {@link #gatherStats(String)}, and asserts whether a join is fully
 * pushed down into the connector or left as a {@link JoinNode} over two table scans.
 */
public abstract class BaseAutomaticJoinPushdownTest extends AbstractTestQueryFramework {
    /**
     * Without statistics on both sides, the automatic strategy must not push the join down, even with a
     * permissive join-to-tables ratio. Once both tables have statistics, the join is pushed down.
     */
    @Test
    public void testJoinPushdownWithEmptyStatsInitially() {
        Session session = joinPushdownAutomatic(getSession());
        try (TestTable left = joinTestTable("left", 2000, 500);
                TestTable right = joinTestTable("right", 1000, 1000)) {
            Session permissiveRatio = maxJoinToTablesRatio(session, 50.0);
            String joinQuery = joinQuery(left, right);

            // No statistics on either table
            Assertions.assertThat(query(permissiveRatio, joinQuery)).isNotFullyPushedDown(joinOverTableScans());

            // Statistics on the left table only
            gatherStats(left.getName());
            Assertions.assertThat(query(permissiveRatio, joinQuery)).isNotFullyPushedDown(joinOverTableScans());

            // Statistics on both tables
            gatherStats(right.getName());
            Assertions.assertThat(query(permissiveRatio, joinQuery)).isFullyPushedDown();
        }
    }

    /**
     * Both tables use {@code mod(orderkey, 1)} as the key, so every key is 0 and the equi-join
     * degenerates into a cross join. With a join-to-tables ratio of 5.0 it must not be pushed down.
     */
    @Test
    public void testCrossJoinNoPushdown() {
        Session session = joinPushdownAutomatic(getSession());
        try (TestTable left = joinTestTable("left", 1000, 1);
                TestTable right = joinTestTable("right", 100, 1)) {
            gatherStats(left.getName());
            gatherStats(right.getName());

            Assertions.assertThat(query(maxJoinToTablesRatio(session, 5.0), joinQuery(left, right)))
                    .isNotFullyPushedDown(joinOverTableScans());
        }
    }

    /**
     * Exercises the automatic pushdown decision under different ratio and table-size limits and
     * different projections.
     */
    @Test
    public void testJoinPushdownAutomatic() {
        Session session = joinPushdownAutomatic(getSession());
        try (TestTable left = joinTestTable("left", 6000, 750);
                TestTable right = joinTestTable("right", 1000, 1000)) {
            gatherStats(left.getName());
            gatherStats(right.getName());
            String joinQuery = joinQuery(left, right);
            Session ratioTwo = maxJoinToTablesRatio(session, 2.0);

            // Connector default ratio: full join is not pushed down
            Assertions.assertThat(query(session, joinQuery)).isNotFullyPushedDown(joinOverTableScans());

            // Ratio 2.0: pushed down
            Assertions.assertThat(query(ratioTwo, joinQuery)).isFullyPushedDown();

            // Ratio 2.0, but a 1kB max table size limit blocks pushdown
            Session ratioTwoSmallTables = Session.builder(ratioTwo)
                    .setCatalogSessionProperty(session.getCatalog().orElseThrow(), "join_pushdown_automatic_max_table_size", "1kB")
                    .build();
            Assertions.assertThat(query(ratioTwoSmallTables, joinQuery)).isNotFullyPushedDown(joinOverTableScans());

            // Projecting a single column makes the join cheap enough under the default ratio
            Assertions.assertThat(query(session, String.format("SELECT l.key FROM %s l JOIN %s r ON l.key = r.key", left.getName(), right.getName())))
                    .isFullyPushedDown();

            // Ratio 1.0: not pushed down
            Assertions.assertThat(query(maxJoinToTablesRatio(session, 1.0), joinQuery)).isNotFullyPushedDown(joinOverTableScans());
        }
    }

    /**
     * A join whose right side is an aggregation ({@code SELECT DISTINCT}) is pushed down
     * together with the aggregation.
     */
    @Test
    public void testAutomaticJoinPushdownOverAggregationPushdown() {
        Session session = joinPushdownAutomatic(getSession());
        try (TestTable left = joinTestTable("left", 1000, 100);
                TestTable right = joinTestTable("right", 100, 50)) {
            gatherStats(left.getName());
            gatherStats(right.getName());

            Assertions.assertThat(query(session, String.format("SELECT * FROM %s l JOIN (SELECT DISTINCT key FROM %s) r ON l.key = r.key", left.getName(), right.getName())))
                    .isFullyPushedDown();
        }
    }

    /**
     * A three-way join with a selective filter on the third table is fully pushed down, which
     * requires two consecutive join pushdowns.
     */
    @Test
    public void testAutomaticJoinPushdownTwice() {
        Session session = joinPushdownAutomatic(getSession());
        try (TestTable first = joinTestTable("first", 1000, 1000);
                TestTable second = joinTestTable("second", 1000, 1000);
                TestTable third = joinTestTable("third", 1000, 1000)) {
            gatherStats(first.getName());
            gatherStats(second.getName());
            gatherStats(third.getName());

            Assertions.assertThat(query(session, String.format(
                            "SELECT * FROM %s first, %s second, %s third WHERE first.key = second.key AND second.key = third.key AND third.intpadding = 42",
                            first.getName(),
                            second.getName(),
                            third.getName())))
                    .isFullyPushedDown();
        }
    }

    /**
     * Plan pattern for a join that was not pushed down: a {@link JoinNode} whose two sides each
     * contain a {@link TableScanNode}.
     */
    protected static PlanMatchPattern joinOverTableScans() {
        return PlanMatchPattern.node(
                JoinNode.class,
                PlanMatchPattern.anyTree(PlanMatchPattern.node(TableScanNode.class)),
                PlanMatchPattern.anyTree(PlanMatchPattern.node(TableScanNode.class)));
    }

    /**
     * Builds the equi-join query on {@code key} used by most tests.
     */
    private static String joinQuery(TestTable left, TestTable right) {
        return String.format("SELECT * FROM %s l JOIN %s r ON l.key = r.key", left.getName(), right.getName());
    }

    /**
     * Creates a test table from the first {@code rowsCount} rows of {@code tpch.tiny.orders} (ordered by orderkey).
     *
     * <p>The table has three columns:
     * <ul>
     *   <li>{@code key}: {@code mod(orderkey, keyModulus)}</li>
     *   <li>{@code padding}: a constant 50-character string</li>
     *   <li>{@code intpadding}: the orderkey</li>
     * </ul>
     *
     * @param name table name prefix passed to {@link TestTable}
     * @param rowsCount number of rows to copy; must be below the row count of the source table
     * @param keyModulus modulus applied to orderkey to derive {@code key}
     * @throws IllegalArgumentException if {@code rowsCount} is not below the source row count
     */
    private TestTable joinTestTable(String name, long rowsCount, int keyModulus) {
        String sourceTable = "tpch.tiny.orders";
        long sourceRows = (Long) computeScalar("SELECT count(*) FROM " + sourceTable);
        Preconditions.checkArgument(rowsCount < sourceRows, "rowsCount too high: %s", rowsCount);
        return new TestTable(
                tableCreator(),
                name,
                String.format(
                        "(key, padding, intpadding) AS SELECT mod(orderkey, %s), '%s', orderkey FROM %s ORDER BY orderkey LIMIT %s",
                        keyModulus,
                        Strings.repeat("x", 50),
                        sourceTable,
                        rowsCount));
    }

    /**
     * Returns the executor used to create test tables; by default it runs SQL through the
     * test query runner. Subclasses may override it to create tables directly in the remote database.
     */
    protected SqlExecutor tableCreator() {
        QueryRunner queryRunner = Objects.requireNonNull(getQueryRunner());
        return queryRunner::execute;
    }

    /**
     * Collects statistics for the given table so the automatic join pushdown strategy can use them.
     *
     * @param tableName name of the table to analyze
     */
    protected abstract void gatherStats(String tableName);

    /**
     * Returns a copy of {@code session} with join pushdown enabled and the catalog's
     * {@code join_pushdown_strategy} set to {@code AUTOMATIC}.
     */
    protected Session joinPushdownAutomatic(Session session) {
        return Session.builder(joinPushdownEnabled(session))
                .setCatalogSessionProperty(session.getCatalog().orElseThrow(), "join_pushdown_strategy", "AUTOMATIC")
                .build();
    }

    /**
     * Returns a copy of {@code session} with the catalog's {@code join_pushdown_enabled} property set to true.
     */
    protected Session joinPushdownEnabled(Session session) {
        // Guards the test's assumption that join pushdown is off by default in JdbcMetadataConfig,
        // so enabling it here is what actually turns it on
        Verify.verify(!new JdbcMetadataConfig().isJoinPushdownEnabled());
        return Session.builder(session)
                .setCatalogSessionProperty(session.getCatalog().orElseThrow(), "join_pushdown_enabled", "true")
                .build();
    }

    /**
     * Returns a copy of {@code session} with {@code join_pushdown_automatic_max_join_to_tables_ratio}
     * set to {@code ratio}.
     */
    private Session maxJoinToTablesRatio(Session session, double ratio) {
        return Session.builder(session)
                .setCatalogSessionProperty(session.getCatalog().orElseThrow(), "join_pushdown_automatic_max_join_to_tables_ratio", String.valueOf(ratio))
                .build();
    }
}
