package io.prestosql.sql.planner.sanity;

import com.google.common.base.Preconditions;
import com.google.common.collect.ListMultimap;
import io.prestosql.Session;
import io.prestosql.execution.warnings.WarningCollector;
import io.prestosql.metadata.Metadata;
import io.prestosql.metadata.Signature;
import io.prestosql.spi.type.Type;
import io.prestosql.spi.type.TypeSignature;
import io.prestosql.sql.planner.SimplePlanVisitor;
import io.prestosql.sql.planner.Symbol;
import io.prestosql.sql.planner.TypeAnalyzer;
import io.prestosql.sql.planner.TypeProvider;
import io.prestosql.sql.planner.plan.AggregationNode;
import io.prestosql.sql.planner.plan.PlanNode;
import io.prestosql.sql.planner.plan.ProjectNode;
import io.prestosql.sql.planner.plan.UnionNode;
import io.prestosql.sql.planner.plan.WindowNode;
import io.prestosql.sql.planner.sanity.PlanSanityChecker;
import io.prestosql.sql.tree.Expression;
import io.prestosql.sql.tree.FunctionCall;
import io.prestosql.sql.tree.QualifiedName;
import io.prestosql.sql.tree.SymbolReference;
import io.prestosql.type.TypeCoercion;
import io.prestosql.type.UnknownType;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;

/**
 * Plan sanity check that verifies the type recorded for each symbol in the {@link TypeProvider}
 * agrees with the type actually produced by the plan node that defines that symbol.
 *
 * <p>Checked nodes:
 * <ul>
 *   <li>{@link ProjectNode}: the type of every assignment expression;</li>
 *   <li>{@link AggregationNode}: the aggregation signature's return type (SINGLE and FINAL steps)
 *       and the resolved call's return type (SINGLE step only);</li>
 *   <li>{@link WindowNode}: the window function signature's return type and the resolved call's
 *       return type;</li>
 *   <li>{@link UnionNode}: every input symbol against the output symbol it maps to.</li>
 * </ul>
 *
 * <p>A mismatch is reported through Guava's {@code Preconditions.checkArgument}, i.e. as an
 * {@link IllegalArgumentException}.
 */
public final class TypeValidator implements PlanSanityChecker.Checker {

    /**
     * Plan visitor performing the per-node type checks. Each override first delegates to
     * {@code visitPlan} (presumably so {@link SimplePlanVisitor} keeps traversing the node's
     * sources) and then validates the node itself.
     */
    private static class Visitor extends SimplePlanVisitor<Void> {
        private final Session session;
        private final Metadata metadata;
        private final TypeCoercion typeCoercion;
        private final TypeAnalyzer typeAnalyzer;
        private final TypeProvider types;
        // Not read by any check in this visitor; kept so the constructor contract stays the same.
        private final WarningCollector warningCollector;

        public Visitor(Session session, Metadata metadata, TypeAnalyzer typeAnalyzer, TypeProvider typeProvider, WarningCollector warningCollector) {
            this.session = Objects.requireNonNull(session, "session is null");
            this.metadata = Objects.requireNonNull(metadata, "metadata is null");
            this.typeCoercion = new TypeCoercion(metadata::getType);
            this.typeAnalyzer = Objects.requireNonNull(typeAnalyzer, "typeAnalyzer is null");
            this.types = Objects.requireNonNull(typeProvider, "types is null");
            this.warningCollector = Objects.requireNonNull(warningCollector, "warningCollector is null");
        }

        /**
         * Checks aggregation output symbols. SINGLE steps are checked against both the signature
         * and the resolved call; FINAL steps only against the signature.
         */
        @Override
        public Void visitAggregation(AggregationNode node, Void context) {
            visitPlan(node, context);

            AggregationNode.Step step = node.getStep();
            for (Map.Entry<Symbol, AggregationNode.Aggregation> entry : node.getAggregations().entrySet()) {
                Symbol symbol = entry.getKey();
                AggregationNode.Aggregation aggregation = entry.getValue();
                switch (step) {
                    case SINGLE:
                        checkSignature(symbol, aggregation.getSignature());
                        checkCall(symbol, aggregation.getSignature().getName(), aggregation.getArguments());
                        break;
                    case FINAL:
                        checkSignature(symbol, aggregation.getSignature());
                        break;
                    default:
                        // NOTE(review): other steps are intentionally not checked — presumably
                        // because their outputs hold intermediate state rather than the
                        // function's final return type. Confirm against AggregationNode.Step.
                        break;
                }
            }
            return null;
        }

        /**
         * Checks every window function output symbol against its signature and resolved call.
         */
        @Override
        public Void visitWindow(WindowNode node, Void context) {
            visitPlan(node, context);
            checkWindowFunctions(node.getWindowFunctions());
            return null;
        }

        /**
         * Checks every projection: the recorded type of the assigned symbol must match the type
         * of its expression.
         */
        @Override
        public Void visitProject(ProjectNode node, Void context) {
            visitPlan(node, context);

            for (Map.Entry<Symbol, Expression> entry : node.getAssignments().entrySet()) {
                Symbol symbol = entry.getKey();
                Expression expression = entry.getValue();
                TypeSignature expected = types.get(symbol).getTypeSignature();

                Type actualType;
                if (expression instanceof SymbolReference) {
                    // Pass-through of another symbol: compare with that symbol's recorded type
                    // instead of re-analyzing the expression.
                    actualType = types.get(Symbol.from(expression));
                }
                else {
                    actualType = typeAnalyzer.getType(session, types, expression);
                }
                verifyTypeSignature(symbol, expected, actualType.getTypeSignature());
            }
            return null;
        }

        /**
         * Checks that every input symbol feeding a union output has a type compatible with that
         * output symbol's recorded type.
         */
        @Override
        public Void visitUnion(UnionNode node, Void context) {
            visitPlan(node, context);

            ListMultimap<Symbol, Symbol> symbolMapping = node.getSymbolMapping();
            for (Symbol outputSymbol : symbolMapping.keySet()) {
                List<Symbol> inputSymbols = symbolMapping.get(outputSymbol);
                Type expectedType = types.get(outputSymbol);
                for (Symbol inputSymbol : inputSymbols) {
                    verifyTypeSignature(outputSymbol, expectedType.getTypeSignature(), types.get(inputSymbol).getTypeSignature());
                }
            }
            return null;
        }

        /** Checks each window function output symbol against its signature and resolved call. */
        private void checkWindowFunctions(Map<Symbol, WindowNode.Function> functions) {
            functions.forEach((symbol, function) -> {
                checkSignature(symbol, function.getSignature());
                checkCall(symbol, function.getSignature().getName(), function.getArguments());
            });
        }

        /** Verifies {@code symbol}'s recorded type against the return type declared by {@code signature}. */
        private void checkSignature(Symbol symbol, Signature signature) {
            verifyTypeSignature(symbol, types.get(symbol).getTypeSignature(), signature.getReturnType());
        }

        /**
         * Re-resolves the function {@code functionName} for the analyzed types of {@code arguments}
         * and verifies {@code symbol}'s recorded type against the resolved return type.
         */
        private void checkCall(Symbol symbol, String functionName, List<Expression> arguments) {
            TypeSignature expected = types.get(symbol).getTypeSignature();
            TypeSignature actual = metadata.getType(
                    metadata.resolveFunction(QualifiedName.of(functionName), typeAnalyzer.getCallArgumentTypes(session, types, arguments))
                            .getReturnType())
                    .getTypeSignature();
            verifyTypeSignature(symbol, expected, actual);
        }

        /**
         * Accepts {@code actual} when it is UNKNOWN (treated as a wildcard), when
         * {@link TypeCoercion#isTypeOnlyCoercion} reports a type-only coercion from {@code actual}
         * to {@code expected}, or when the two signatures are equal.
         *
         * @throws IllegalArgumentException if none of the above holds
         */
        private void verifyTypeSignature(Symbol symbol, TypeSignature expected, TypeSignature actual) {
            if (actual.equals(UnknownType.UNKNOWN.getTypeSignature())) {
                return;
            }
            if (typeCoercion.isTypeOnlyCoercion(metadata.getType(actual), metadata.getType(expected))) {
                return;
            }
            Preconditions.checkArgument(expected.equals(actual), "type of symbol '%s' is expected to be %s, but the actual type is %s", symbol, expected, actual);
        }
    }

    /** Runs the type checks over the whole plan rooted at {@code planNode}. */
    @Override
    public void validate(PlanNode planNode, Session session, Metadata metadata, TypeAnalyzer typeAnalyzer, TypeProvider typeProvider, WarningCollector warningCollector) {
        planNode.accept(new Visitor(session, metadata, typeAnalyzer, typeProvider, warningCollector), null);
    }
}
