package io.trino.plugin.geospatial;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import io.trino.metadata.ResolvedFunction;
import io.trino.metadata.TestingFunctionResolution;
import io.trino.spi.Plugin;
import io.trino.spi.type.BigintType;
import io.trino.spi.type.DoubleType;
import io.trino.spi.type.Type;
import io.trino.sql.analyzer.TypeSignatureProvider;
import io.trino.sql.ir.Call;
import io.trino.sql.ir.Comparison;
import io.trino.sql.ir.Reference;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.assertions.PlanMatchPattern;
import io.trino.sql.planner.iterative.rule.PruneSpatialJoinChildrenColumns;
import io.trino.sql.planner.iterative.rule.test.BaseRuleTest;
import io.trino.sql.planner.plan.SpatialJoinNode;
import java.util.Optional;
import org.junit.jupiter.api.Test;

/**
 * Tests for {@link PruneSpatialJoinChildrenColumns} applied to a {@link SpatialJoinNode}.
 *
 * <p>Every plan below joins on the filter {@code st_distance(a, b) <= r}. As exercised by these cases,
 * a child column is retained when it is a join output, is referenced by the filter, or is a partition
 * symbol. Any other child column is dropped by a {@code strictProject} placed directly above that child.
 * When nothing can be dropped, the rule must not fire.
 */
public class TestPruneSpatialJoinChildrenColumns extends BaseRuleTest {
    private static final TestingFunctionResolution FUNCTIONS = new TestingFunctionResolution(new GeoPlugin());

    // Declared before TEST_ST_DISTANCE_FUNCTION because its initializer reads it (no forward references).
    // It is typed as Type so that the symbol and Reference factories below take it directly.
    private static final Type GEOMETRY = GeometryType.GEOMETRY;

    private static final ResolvedFunction TEST_ST_DISTANCE_FUNCTION =
            FUNCTIONS.resolveFunction("st_distance", TypeSignatureProvider.fromTypes(GEOMETRY, GEOMETRY));

    public TestPruneSpatialJoinChildrenColumns() {
        // The rule tester is created without additional plugins. The geospatial function used in the
        // filters is resolved separately through FUNCTIONS above.
        super(new Plugin[0]);
    }

    /** Only the left child has an unreferenced column, so only the left side gains a pruning projection. */
    @Test
    public void testPruneOneChild() {
        tester().assertThat(new PruneSpatialJoinChildrenColumns())
                .on(p -> {
                    Symbol a = p.symbol("a", GEOMETRY);
                    Symbol b = p.symbol("b", GEOMETRY);
                    Symbol r = p.symbol("r", DoubleType.DOUBLE);
                    return p.spatialJoin(
                            SpatialJoinNode.Type.INNER,
                            p.values(a, p.symbol("unused", BigintType.BIGINT)),
                            p.values(b, r),
                            ImmutableList.of(a, b, r),
                            distanceFilter("a", "b", "r"));
                })
                .matches(PlanMatchPattern.spatialJoin(
                        distanceFilter("a", "b", "r"),
                        Optional.empty(),
                        Optional.of(ImmutableList.of("a", "b", "r")),
                        PlanMatchPattern.strictProject(
                                ImmutableMap.of("a", PlanMatchPattern.expression(new Reference(GEOMETRY, "a"))),
                                PlanMatchPattern.values("a", "unused")),
                        PlanMatchPattern.values("b", "r")));
    }

    /** Both children have an unreferenced column, so each side gains its own pruning projection. */
    @Test
    public void testPruneBothChildren() {
        tester().assertThat(new PruneSpatialJoinChildrenColumns())
                .on(p -> {
                    Symbol a = p.symbol("a", GEOMETRY);
                    Symbol b = p.symbol("b", GEOMETRY);
                    Symbol r = p.symbol("r", DoubleType.DOUBLE);
                    return p.spatialJoin(
                            SpatialJoinNode.Type.INNER,
                            p.values(a, p.symbol("unused_left", BigintType.BIGINT)),
                            p.values(b, r, p.symbol("unused_right", BigintType.BIGINT)),
                            ImmutableList.of(a, b, r),
                            distanceFilter("a", "b", "r"));
                })
                .matches(PlanMatchPattern.spatialJoin(
                        distanceFilter("a", "b", "r"),
                        Optional.empty(),
                        Optional.of(ImmutableList.of("a", "b", "r")),
                        PlanMatchPattern.strictProject(
                                ImmutableMap.of("a", PlanMatchPattern.expression(new Reference(GEOMETRY, "a"))),
                                PlanMatchPattern.values("a", "unused_left")),
                        PlanMatchPattern.strictProject(
                                ImmutableMap.of(
                                        "b", PlanMatchPattern.expression(new Reference(GEOMETRY, "b")),
                                        "r", PlanMatchPattern.expression(new Reference(DoubleType.DOUBLE, "r"))),
                                PlanMatchPattern.values("b", "r", "unused_right"))));
    }

    /**
     * Every child column is either the join output ({@code output}) or referenced by the filter
     * ({@code a}, {@code b}, {@code r}). Nothing can be pruned, so the rule must not fire.
     */
    @Test
    public void testDoNotPruneOneOutputOrFilterSymbols() {
        tester().assertThat(new PruneSpatialJoinChildrenColumns())
                .on(p -> {
                    Symbol a = p.symbol("a", GEOMETRY);
                    Symbol b = p.symbol("b", GEOMETRY);
                    Symbol r = p.symbol("r", DoubleType.DOUBLE);
                    Symbol output = p.symbol("output", BigintType.BIGINT);
                    return p.spatialJoin(
                            SpatialJoinNode.Type.INNER,
                            p.values(a),
                            p.values(b, r, output),
                            ImmutableList.of(output),
                            distanceFilter("a", "b", "r"));
                })
                .doesNotFire();
    }

    /**
     * Partition symbols are not outputs and do not appear in the filter. They are still needed by the
     * join, so they must not be pruned and the rule must not fire.
     */
    @Test
    public void testDoNotPrunePartitionSymbols() {
        tester().assertThat(new PruneSpatialJoinChildrenColumns())
                .on(p -> {
                    Symbol a = p.symbol("a", GEOMETRY);
                    Symbol b = p.symbol("b", GEOMETRY);
                    Symbol r = p.symbol("r", DoubleType.DOUBLE);
                    Symbol leftPartition = p.symbol("left_partition_symbol", BigintType.BIGINT);
                    Symbol rightPartition = p.symbol("right_partition_symbol", BigintType.BIGINT);
                    return p.spatialJoin(
                            SpatialJoinNode.Type.INNER,
                            p.values(a, leftPartition),
                            p.values(b, r, rightPartition),
                            ImmutableList.of(a, b, r),
                            distanceFilter("a", "b", "r"),
                            Optional.of(leftPartition),
                            Optional.of(rightPartition),
                            Optional.of("some nice kdb tree"));
                })
                .doesNotFire();
    }

    /**
     * Builds the spatial join filter {@code st_distance(left, right) <= radius}.
     *
     * <p>The same expression serves as the input plan's filter and as the expected filter, so the two
     * cannot drift apart.
     *
     * @param left name of the left geometry symbol
     * @param right name of the right geometry symbol
     * @param radius name of the double-typed distance symbol
     * @return the comparison expression used as the join filter
     */
    private static Comparison distanceFilter(String left, String right, String radius) {
        return new Comparison(
                Comparison.Operator.LESS_THAN_OR_EQUAL,
                new Call(TEST_ST_DISTANCE_FUNCTION, ImmutableList.of(new Reference(GEOMETRY, left), new Reference(GEOMETRY, right))),
                new Reference(DoubleType.DOUBLE, radius));
    }
}
