package io.trino.sql.planner.iterative.rule;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import io.airlift.slice.Slices;
import io.trino.metadata.MetadataManager;
import io.trino.metadata.ResolvedFunction;
import io.trino.metadata.TestingFunctionResolution;
import io.trino.spi.Plugin;
import io.trino.spi.function.OperatorType;
import io.trino.spi.type.BigintType;
import io.trino.spi.type.Type;
import io.trino.spi.type.VarcharType;
import io.trino.sql.analyzer.TypeSignatureProvider;
import io.trino.sql.ir.Booleans;
import io.trino.sql.ir.Call;
import io.trino.sql.ir.Comparison;
import io.trino.sql.ir.Constant;
import io.trino.sql.ir.Reference;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.assertions.PlanMatchPattern;
import io.trino.sql.planner.iterative.rule.test.BaseRuleTest;
import io.trino.sql.planner.rowpattern.AggregatedSetDescriptor;
import io.trino.sql.planner.rowpattern.AggregationValuePointer;
import io.trino.sql.planner.rowpattern.ClassifierValuePointer;
import io.trino.sql.planner.rowpattern.LogicalIndexPointer;
import io.trino.sql.planner.rowpattern.MatchNumberValuePointer;
import io.trino.sql.planner.rowpattern.ir.IrLabel;
import java.util.Optional;
import org.junit.jupiter.api.Test;

/**
 * Tests for {@link PushDownProjectionsFromPatternRecognition}.
 *
 * <p>The rule is expected to pre-project non-trivial aggregation arguments into a project node
 * below the pattern recognition node. It should leave the plan untouched when there are no
 * aggregations, when arguments are plain symbol references, or when arguments depend on
 * classifier / match-number value pointers.
 */
public class TestPushDownProjectionsFromPatternRecognition extends BaseRuleTest {
    private static final TestingFunctionResolution FUNCTIONS = new TestingFunctionResolution();

    // One shared test metadata manager is used for every max_by resolution below, instead of
    // constructing a fresh manager per constant. Must be declared before the constants that use it
    // (static initializers run in textual order).
    private static final MetadataManager METADATA = MetadataManager.createTestMetadataManager();

    private static final ResolvedFunction ADD_BIGINT = FUNCTIONS.resolveOperator(OperatorType.ADD, ImmutableList.of(BigintType.BIGINT, BigintType.BIGINT));
    private static final ResolvedFunction MULTIPLY_BIGINT = FUNCTIONS.resolveOperator(OperatorType.MULTIPLY, ImmutableList.of(BigintType.BIGINT, BigintType.BIGINT));
    private static final ResolvedFunction CONCAT = FUNCTIONS.resolveFunction("concat", TypeSignatureProvider.fromTypes(VarcharType.VARCHAR, VarcharType.VARCHAR));
    private static final ResolvedFunction MAX_BY = resolveMaxBy(BigintType.BIGINT, BigintType.BIGINT);
    private static final ResolvedFunction MAX_BY_BIGINT_VARCHAR = resolveMaxBy(BigintType.BIGINT, VarcharType.VARCHAR);

    public TestPushDownProjectionsFromPatternRecognition() {
        // These tests register no additional plugins.
        super(new Plugin[0]);
    }

    /**
     * Resolves the builtin {@code max_by} function for the given argument types against the
     * shared test metadata manager.
     *
     * @param firstArgumentType type of the first {@code max_by} argument
     * @param secondArgumentType type of the second {@code max_by} argument
     * @return the resolved {@code max_by} function
     */
    private static ResolvedFunction resolveMaxBy(Type firstArgumentType, Type secondArgumentType) {
        return METADATA.resolveBuiltinFunction("max_by", TypeSignatureProvider.fromTypes(firstArgumentType, secondArgumentType));
    }

    /**
     * A single variable whose definition is the literal {@code TRUE} has nothing to pre-project,
     * so the rule must not fire.
     */
    @Test
    public void testNoAggregations() {
        IrLabel x = new IrLabel("X");

        tester().assertThat(new PushDownProjectionsFromPatternRecognition())
                .on(planBuilder -> planBuilder.patternRecognition(builder -> {
                    builder.pattern(x)
                            .addVariableDefinition(x, Booleans.TRUE)
                            .source(planBuilder.values(planBuilder.symbol("a")));
                }))
                .doesNotFire();
    }

    /**
     * The definition {@code max_by(1 + match, 'x' || classifier) > 5} references symbols bound to
     * match-number and classifier value pointers. Those arguments cannot be computed below the
     * pattern recognition node, so the rule must not fire.
     */
    @Test
    public void testDoNotPushRuntimeEvaluatedArguments() {
        IrLabel x = new IrLabel("X");
        Call onePlusMatch = new Call(ADD_BIGINT, ImmutableList.of(new Constant(BigintType.BIGINT, 1L), new Reference(BigintType.BIGINT, "match")));
        Call xConcatClassifier = new Call(CONCAT, ImmutableList.of(new Constant(VarcharType.VARCHAR, Slices.utf8Slice("x")), new Reference(VarcharType.VARCHAR, "classifier")));
        Comparison definition = new Comparison(
                Comparison.Operator.GREATER_THAN,
                new Call(MAX_BY_BIGINT_VARCHAR, ImmutableList.of(onePlusMatch, xConcatClassifier)),
                new Constant(BigintType.BIGINT, 5L));

        tester().assertThat(new PushDownProjectionsFromPatternRecognition())
                .on(planBuilder -> planBuilder.patternRecognition(builder -> {
                    builder.pattern(x)
                            .addVariableDefinition(
                                    x,
                                    definition,
                                    ImmutableMap.of(
                                            new Symbol(VarcharType.VARCHAR, "classifier"), new ClassifierValuePointer(new LogicalIndexPointer(ImmutableSet.of(), true, true, 0, 0)),
                                            new Symbol(BigintType.BIGINT, "match"), new MatchNumberValuePointer()))
                            .source(planBuilder.values(planBuilder.symbol("a")));
                }))
                .doesNotFire();
    }

    /**
     * When {@code max_by} is applied directly to input symbols {@code a} and {@code b}, there is no
     * expression to pre-project, so the rule must not fire.
     */
    @Test
    public void testDoNotPushSymbolReferences() {
        IrLabel x = new IrLabel("X");
        Comparison definition = new Comparison(
                Comparison.Operator.GREATER_THAN,
                new Call(MAX_BY, ImmutableList.of(new Reference(BigintType.BIGINT, "a"), new Reference(BigintType.BIGINT, "b"))),
                new Constant(BigintType.BIGINT, 5L));

        tester().assertThat(new PushDownProjectionsFromPatternRecognition())
                .on(planBuilder -> planBuilder.patternRecognition(builder -> {
                    builder.pattern(x)
                            .addVariableDefinition(x, definition)
                            .source(planBuilder.values(planBuilder.symbol("a"), planBuilder.symbol("b")));
                }))
                .doesNotFire();
    }

    /**
     * The aggregation pointer for {@code agg} has arguments {@code a + 1} and {@code b * 2}.
     *
     * <p>The rule is expected to move both expressions into a project node under the pattern
     * recognition node as {@code expr_1} and {@code expr_2}, passing {@code a} and {@code b}
     * through. The pointer then refers to the projected symbols.
     */
    @Test
    public void testPreProjectArguments() {
        ResolvedFunction maxBy = tester().getMetadata().resolveBuiltinFunction("max_by", TypeSignatureProvider.fromTypes(BigintType.BIGINT, BigintType.BIGINT));
        IrLabel x = new IrLabel("X");
        Call aPlusOne = new Call(ADD_BIGINT, ImmutableList.of(new Reference(BigintType.BIGINT, "a"), new Constant(BigintType.BIGINT, 1L)));
        Call bTimesTwo = new Call(MULTIPLY_BIGINT, ImmutableList.of(new Reference(BigintType.BIGINT, "b"), new Constant(BigintType.BIGINT, 2L)));
        Comparison aggLessThanFive = new Comparison(Comparison.Operator.LESS_THAN, new Reference(BigintType.BIGINT, "agg"), new Constant(BigintType.BIGINT, 5L));
        AggregatedSetDescriptor setDescriptor = new AggregatedSetDescriptor(ImmutableSet.of(), true);

        tester().assertThat(new PushDownProjectionsFromPatternRecognition())
                .on(planBuilder -> planBuilder.patternRecognition(builder -> {
                    builder.pattern(x)
                            .addVariableDefinition(
                                    x,
                                    aggLessThanFive,
                                    ImmutableMap.of(
                                            new Symbol(BigintType.BIGINT, "agg"),
                                            new AggregationValuePointer(maxBy, setDescriptor, ImmutableList.of(aPlusOne, bTimesTwo), Optional.empty(), Optional.empty())))
                            .source(planBuilder.values(planBuilder.symbol("a", BigintType.BIGINT), planBuilder.symbol("b", BigintType.BIGINT)));
                }))
                .matches(PlanMatchPattern.patternRecognition(
                        builder -> {
                            builder.pattern(x)
                                    .addVariableDefinition(
                                            x,
                                            aggLessThanFive,
                                            ImmutableMap.of(
                                                    "agg",
                                                    new AggregationValuePointer(
                                                            maxBy,
                                                            setDescriptor,
                                                            ImmutableList.of(new Reference(BigintType.BIGINT, "expr_1"), new Reference(BigintType.BIGINT, "expr_2")),
                                                            Optional.empty(),
                                                            Optional.empty())));
                        },
                        PlanMatchPattern.project(
                                ImmutableMap.of(
                                        "expr_1", PlanMatchPattern.expression(aPlusOne),
                                        "expr_2", PlanMatchPattern.expression(bTimesTwo),
                                        "a", PlanMatchPattern.expression(new Reference(BigintType.BIGINT, "a")),
                                        "b", PlanMatchPattern.expression(new Reference(BigintType.BIGINT, "b"))),
                                PlanMatchPattern.values("a", "b"))));
    }
}
