package io.prestosql.sql.planner;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableListMultimap;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import io.prestosql.SessionTestUtils;
import io.prestosql.execution.warnings.WarningCollector;
import io.prestosql.metadata.FunctionKind;
import io.prestosql.metadata.MetadataManager;
import io.prestosql.metadata.Signature;
import io.prestosql.spi.predicate.TupleDomain;
import io.prestosql.spi.type.BigintType;
import io.prestosql.spi.type.DateType;
import io.prestosql.spi.type.DoubleType;
import io.prestosql.spi.type.IntegerType;
import io.prestosql.spi.type.VarcharType;
import io.prestosql.sql.parser.SqlParser;
import io.prestosql.sql.planner.plan.AggregationNode;
import io.prestosql.sql.planner.plan.Assignments;
import io.prestosql.sql.planner.plan.PlanNode;
import io.prestosql.sql.planner.plan.PlanNodeId;
import io.prestosql.sql.planner.plan.ProjectNode;
import io.prestosql.sql.planner.plan.TableScanNode;
import io.prestosql.sql.planner.plan.UnionNode;
import io.prestosql.sql.planner.plan.WindowNode;
import io.prestosql.sql.planner.sanity.TypeValidator;
import io.prestosql.sql.tree.Cast;
import io.prestosql.sql.tree.FrameBound;
import io.prestosql.sql.tree.WindowFrame;
import io.prestosql.testing.TestingHandles;
import io.prestosql.testing.TestingMetadata;
import java.util.Optional;
import java.util.UUID;
import org.testng.annotations.BeforeMethod;
import org.testng.annotations.Test;

/**
 * Tests for {@link TypeValidator}.
 *
 * <p>Every test builds a small plan on top of a five-column table scan and runs the validator
 * against the symbol types recorded in the per-test {@link SymbolAllocator}. The "valid" tests
 * expect validation to complete normally; the "invalid" tests expect an
 * {@link IllegalArgumentException} whose message matches the regular expression declared on the
 * test annotation.
 */
@Test(singleThreaded = true)
public class TestTypeValidator {
    private static final SqlParser SQL_PARSER = new SqlParser();
    private static final TypeValidator TYPE_VALIDATOR = new TypeValidator();

    private SymbolAllocator symbolAllocator;
    private TableScanNode baseTableScan;
    private Symbol columnA;
    private Symbol columnB;
    private Symbol columnC;
    private Symbol columnD;
    private Symbol columnE;

    /**
     * Creates a fresh symbol allocator, registers the symbols a (bigint), b (integer),
     * c (double), d (date) and e (varchar(3)), and builds a table scan that outputs them
     * in that order.
     */
    @BeforeMethod
    public void setUp() {
        symbolAllocator = new SymbolAllocator();
        columnA = symbolAllocator.newSymbol("a", BigintType.BIGINT);
        columnB = symbolAllocator.newSymbol("b", IntegerType.INTEGER);
        columnC = symbolAllocator.newSymbol("c", DoubleType.DOUBLE);
        columnD = symbolAllocator.newSymbol("d", DateType.DATE);
        columnE = symbolAllocator.newSymbol("e", VarcharType.createVarcharType(3));

        // The assignment map is built directly in argument position so that its key/value types
        // are inferred from the constructor parameter instead of going through a raw local.
        // The output list repeats the map's insertion order.
        baseTableScan = new TableScanNode(
                newId(),
                TestingHandles.TEST_TABLE_HANDLE,
                ImmutableList.of(columnA, columnB, columnC, columnD, columnE),
                ImmutableMap.of(
                        columnA, new TestingMetadata.TestingColumnHandle("a"),
                        columnB, new TestingMetadata.TestingColumnHandle("b"),
                        columnC, new TestingMetadata.TestingColumnHandle("c"),
                        columnD, new TestingMetadata.TestingColumnHandle("d"),
                        columnE, new TestingMetadata.TestingColumnHandle("e")),
                TupleDomain.all());
    }

    /** Projects two casts to bigint into symbols that are both declared as bigint. */
    @Test
    public void testValidProject() {
        Cast castOfB = new Cast(columnB.toSymbolReference(), "bigint");
        Cast castOfC = new Cast(columnC.toSymbolReference(), "bigint");
        Assignments assignments = Assignments.builder()
                .put(symbolAllocator.newSymbol(castOfB, BigintType.BIGINT), castOfB)
                .put(symbolAllocator.newSymbol(castOfC, BigintType.BIGINT), castOfC)
                .build();
        assertTypesValid(new ProjectNode(newId(), baseTableScan, assignments));
    }

    /** Unions the date column d with itself into an output symbol declared as date. */
    @Test
    public void testValidUnion() {
        Symbol output = symbolAllocator.newSymbol("output", DateType.DATE);
        ImmutableListMultimap<Symbol, Symbol> mappings = ImmutableListMultimap.<Symbol, Symbol>builder()
                .put(output, columnD)
                .put(output, columnD)
                .build();
        assertTypesValid(unionOfBaseScans(mappings));
    }

    /** Window {@code sum} declared as double(double), applied to the double column c. */
    @Test
    public void testValidWindow() {
        assertTypesValid(windowSum(doubleSumSignature(FunctionKind.WINDOW), columnC));
    }

    /** Aggregate {@code sum} declared as double(double), applied to the double column c. */
    @Test
    public void testValidAggregation() {
        assertTypesValid(aggregationSum(doubleSumSignature(FunctionKind.AGGREGATE), columnC));
    }

    /**
     * Projects the varchar(3) column e, unchanged, into a symbol declared with the
     * {@code VarcharType.VARCHAR} type, next to a regular cast to bigint.
     */
    @Test
    public void testValidTypeOnlyCoercion() {
        Cast castOfB = new Cast(columnB.toSymbolReference(), "bigint");
        Assignments assignments = Assignments.builder()
                .put(symbolAllocator.newSymbol(castOfB, BigintType.BIGINT), castOfB)
                .put(symbolAllocator.newSymbol(columnE.toSymbolReference(), VarcharType.VARCHAR), columnE.toSymbolReference())
                .build();
        assertTypesValid(new ProjectNode(newId(), baseTableScan, assignments));
    }

    /** The first output symbol is declared as bigint but is assigned a cast to integer. */
    @Test(expectedExceptions = {IllegalArgumentException.class}, expectedExceptionsMessageRegExp = "type of symbol 'expr(_[0-9]+)?' is expected to be bigint, but the actual type is integer")
    public void testInvalidProject() {
        Cast castOfB = new Cast(columnB.toSymbolReference(), "integer");
        Cast castOfA = new Cast(columnA.toSymbolReference(), "integer");
        Assignments assignments = Assignments.builder()
                // Deliberate mismatch: the symbol is registered as bigint, the expression is a cast to integer.
                .put(symbolAllocator.newSymbol(castOfB, BigintType.BIGINT), castOfB)
                .put(symbolAllocator.newSymbol(castOfB, IntegerType.INTEGER), castOfA)
                .build();
        assertTypesValid(new ProjectNode(newId(), baseTableScan, assignments));
    }

    /** The signature declares a double argument, but the aggregation is applied to the bigint column a. */
    @Test(expectedExceptions = {IllegalArgumentException.class}, expectedExceptionsMessageRegExp = "type of symbol 'sum(_[0-9]+)?' is expected to be double, but the actual type is bigint")
    public void testInvalidAggregationFunctionCall() {
        assertTypesValid(aggregationSum(doubleSumSignature(FunctionKind.AGGREGATE), columnA));
    }

    /** The aggregation signature declares a bigint result, but the output symbol is declared as double. */
    @Test(expectedExceptions = {IllegalArgumentException.class}, expectedExceptionsMessageRegExp = "type of symbol 'sum(_[0-9]+)?' is expected to be double, but the actual type is bigint")
    public void testInvalidAggregationFunctionSignature() {
        assertTypesValid(aggregationSum(bigintReturningSumSignature(FunctionKind.AGGREGATE), columnC));
    }

    /** The signature declares a double argument, but the window function is applied to the bigint column a. */
    @Test(expectedExceptions = {IllegalArgumentException.class}, expectedExceptionsMessageRegExp = "type of symbol 'sum(_[0-9]+)?' is expected to be double, but the actual type is bigint")
    public void testInvalidWindowFunctionCall() {
        assertTypesValid(windowSum(doubleSumSignature(FunctionKind.WINDOW), columnA));
    }

    /** The window function signature declares a bigint result, but the output symbol is declared as double. */
    @Test(expectedExceptions = {IllegalArgumentException.class}, expectedExceptionsMessageRegExp = "type of symbol 'sum(_[0-9]+)?' is expected to be double, but the actual type is bigint")
    public void testInvalidWindowFunctionSignature() {
        assertTypesValid(windowSum(bigintReturningSumSignature(FunctionKind.WINDOW), columnC));
    }

    /** The date output symbol is fed by the date column d and by the bigint column a. */
    @Test(expectedExceptions = {IllegalArgumentException.class}, expectedExceptionsMessageRegExp = "type of symbol 'output(_[0-9]+)?' is expected to be date, but the actual type is bigint")
    public void testInvalidUnion() {
        Symbol output = symbolAllocator.newSymbol("output", DateType.DATE);
        ImmutableListMultimap<Symbol, Symbol> mappings = ImmutableListMultimap.<Symbol, Symbol>builder()
                .put(output, columnD)
                .put(output, columnA)
                .build();
        assertTypesValid(unionOfBaseScans(mappings));
    }

    /** Builds a {@code sum} signature of the given kind that takes one double argument and returns double. */
    private static Signature doubleSumSignature(FunctionKind kind) {
        return new Signature("sum", kind, ImmutableList.of(), ImmutableList.of(), DoubleType.DOUBLE.getTypeSignature(), ImmutableList.of(DoubleType.DOUBLE.getTypeSignature()), false);
    }

    /** Builds a {@code sum} signature of the given kind that takes one double argument but returns bigint. */
    private static Signature bigintReturningSumSignature(FunctionKind kind) {
        return new Signature("sum", kind, ImmutableList.of(), ImmutableList.of(), BigintType.BIGINT.getTypeSignature(), ImmutableList.of(DoubleType.DOUBLE.getTypeSignature()), false);
    }

    /**
     * Builds a window node over the base table scan whose single function, described by
     * {@code signature}, is applied to {@code argument} and written to a new symbol named
     * "sum" that is declared as double.
     */
    private PlanNode windowSum(Signature signature, Symbol argument) {
        Symbol output = symbolAllocator.newSymbol("sum", DoubleType.DOUBLE);
        WindowNode.Frame frame = new WindowNode.Frame(
                WindowFrame.Type.RANGE,
                FrameBound.Type.UNBOUNDED_PRECEDING,
                Optional.empty(),
                FrameBound.Type.UNBOUNDED_FOLLOWING,
                Optional.empty(),
                Optional.empty(),
                Optional.empty());
        WindowNode.Function function = new WindowNode.Function(signature, ImmutableList.of(argument.toSymbolReference()), frame);
        WindowNode.Specification specification = new WindowNode.Specification(ImmutableList.of(), Optional.empty());
        return new WindowNode(newId(), baseTableScan, specification, ImmutableMap.of(output, function), Optional.empty(), ImmutableSet.of(), 0);
    }

    /**
     * Builds a single-step aggregation node over the base table scan, grouped by columns a and b,
     * whose single aggregation, described by {@code signature}, is applied to {@code argument}
     * and written to a new symbol named "sum" that is declared as double.
     */
    private PlanNode aggregationSum(Signature signature, Symbol argument) {
        Symbol output = symbolAllocator.newSymbol("sum", DoubleType.DOUBLE);
        AggregationNode.Aggregation aggregation = new AggregationNode.Aggregation(
                signature,
                ImmutableList.of(argument.toSymbolReference()),
                false,
                Optional.empty(),
                Optional.empty(),
                Optional.empty());
        return new AggregationNode(
                newId(),
                baseTableScan,
                ImmutableMap.of(output, aggregation),
                AggregationNode.singleGroupingSet(ImmutableList.of(columnA, columnB)),
                ImmutableList.of(),
                AggregationNode.Step.SINGLE,
                Optional.empty(),
                Optional.empty());
    }

    /**
     * Builds a union of the base table scan with itself using the given output-to-input
     * symbol mappings; the union's outputs are the distinct keys of {@code mappings}.
     */
    private PlanNode unionOfBaseScans(ImmutableListMultimap<Symbol, Symbol> mappings) {
        return new UnionNode(newId(), ImmutableList.of(baseTableScan, baseTableScan), mappings, ImmutableList.copyOf(mappings.keySet()));
    }

    /**
     * Runs the type validator on {@code planNode} with a freshly created test metadata manager
     * and the types currently registered in the symbol allocator. Any exception thrown by the
     * validator propagates to the calling test.
     */
    private void assertTypesValid(PlanNode planNode) {
        MetadataManager metadata = MetadataManager.createTestMetadataManager();
        TypeAnalyzer typeAnalyzer = new TypeAnalyzer(SQL_PARSER, metadata);
        TYPE_VALIDATOR.validate(planNode, SessionTestUtils.TEST_SESSION, metadata, typeAnalyzer, symbolAllocator.getTypes(), WarningCollector.NOOP);
    }

    /** Returns a plan node id backed by a random UUID string. */
    private static PlanNodeId newId() {
        return new PlanNodeId(UUID.randomUUID().toString());
    }
}
