package io.trino.plugin.jdbc.expression;

import com.google.common.collect.MoreCollectors;
import io.trino.matching.Match;
import io.trino.spi.expression.Call;
import io.trino.spi.expression.FunctionName;
import io.trino.spi.expression.Variable;
import io.trino.spi.type.BigintType;
import io.trino.spi.type.DecimalType;
import io.trino.spi.type.IntegerType;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.assertj.core.api.Assertions;
import org.testng.annotations.Test;

/* loaded from: input_file:io/trino/plugin/jdbc/expression/TestExpressionMatching.class */
/**
 * Tests for patterns built by {@link ExpressionMappingParser}: a pattern is matched against a concrete
 * type or expression, and the test checks which names the single resulting match binds into a
 * {@link MatchContext} and what values they receive.
 */
public class TestExpressionMatching {
    // Type classes handed to the parser: the pattern name "integer_class" stands for any of these type names.
    private static final Map<String, Set<String>> TYPE_CLASSES = Map.of("integer_class", Set.of("tinyint", "smallint", "integer", "bigint"));

    @Test
    public void testMatchType() {
        TypePattern pattern = typePattern("decimal(p, s)");
        DecimalType decimal = DecimalType.createDecimalType(10, 2);

        Match onlyMatch = (Match) pattern.getPattern().match(decimal).collect(MoreCollectors.onlyElement());
        MatchContext context = new MatchContext();
        pattern.resolve(onlyMatch.captures(), context);

        // The pattern's parameters are bound to the precision and scale of decimal(10, 2).
        Assertions.assertThat(context.keys()).containsExactlyInAnyOrder("p", "s");
        Assertions.assertThat(context.get("p")).isEqualTo(10L);
        Assertions.assertThat(context.get("s")).isEqualTo(2L);
    }

    @Test
    public void testExpressionCapture() {
        Call addition = addDecimalAndBigint();

        MatchContext context = resolveSingleMatch(expressionPattern("foo: decimal(p, s)"), addition);

        // A top-level capture binds the whole call, and the type parameters come from the call's result type.
        Assertions.assertThat(context.keys()).containsExactlyInAnyOrder("p", "s", "foo");
        Assertions.assertThat(context.get("p")).isEqualTo(21L);
        Assertions.assertThat(context.get("s")).isEqualTo(2L);
        Assertions.assertThat(context.get("foo")).isSameAs(addition);
    }

    @Test
    public void testMatchCall() {
        MatchContext context = resolveSingleMatch(
                expressionPattern("add(foo: decimal(p, s), bar: bigint)"),
                addDecimalAndBigint());

        // Argument captures bind the individual arguments, and the type parameters come from the first argument.
        Assertions.assertThat(context.keys()).containsExactlyInAnyOrder("p", "s", "foo", "bar");
        Assertions.assertThat(context.get("p")).isEqualTo(10L);
        Assertions.assertThat(context.get("s")).isEqualTo(2L);
        Assertions.assertThat(context.get("foo")).isEqualTo(new Variable("first", DecimalType.createDecimalType(10, 2)));
        Assertions.assertThat(context.get("bar")).isEqualTo(new Variable("second", BigintType.BIGINT));
    }

    @Test
    public void testMatchCallWithTypeClass() {
        Variable first = new Variable("first", IntegerType.INTEGER);
        Variable second = new Variable("second", BigintType.BIGINT);
        Call addition = new Call(BigintType.BIGINT, new FunctionName("add"), List.of(first, second));

        MatchContext context = resolveSingleMatch(
                expressionPattern("add(foo: integer_class, bar: integer_class): integer_class"),
                addition);

        // A type class binds no type parameters, so only the named argument captures appear.
        Assertions.assertThat(context.keys()).containsExactlyInAnyOrder("foo", "bar");
        Assertions.assertThat(context.get("foo")).isEqualTo(new Variable("first", IntegerType.INTEGER));
        Assertions.assertThat(context.get("bar")).isEqualTo(new Variable("second", BigintType.BIGINT));
    }

    /**
     * Matches {@code pattern} against {@code call} and resolves the captures of the single match.
     * {@code onlyElement()} fails the test unless exactly one match is produced.
     *
     * @return the context populated from that match's captures
     */
    private static MatchContext resolveSingleMatch(ExpressionPattern pattern, Call call) {
        Match onlyMatch = (Match) pattern.getPattern().match(call).collect(MoreCollectors.onlyElement());
        MatchContext context = new MatchContext();
        pattern.resolve(onlyMatch.captures(), context);
        return context;
    }

    /** Builds the call {@code add(first: decimal(10, 2), second: bigint)} with result type {@code decimal(21, 2)}. */
    private static Call addDecimalAndBigint() {
        Variable first = new Variable("first", DecimalType.createDecimalType(10, 2));
        Variable second = new Variable("second", BigintType.BIGINT);
        return new Call(DecimalType.createDecimalType(21, 2), new FunctionName("add"), List.of(first, second));
    }

    /** Parses {@code pattern} as an expression pattern, using {@link #TYPE_CLASSES}. */
    private static ExpressionPattern expressionPattern(String pattern) {
        return new ExpressionMappingParser(TYPE_CLASSES).createExpressionPattern(pattern);
    }

    /** Parses {@code pattern} as a type pattern, using {@link #TYPE_CLASSES}. */
    private static TypePattern typePattern(String pattern) {
        return new ExpressionMappingParser(TYPE_CLASSES).createTypePattern(pattern);
    }
}
