package io.prestosql.sql.planner.iterative.rule;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Iterables;
import io.prestosql.matching.Captures;
import io.prestosql.matching.Pattern;
import io.prestosql.metadata.FunctionRegistry;
import io.prestosql.metadata.Signature;
import io.prestosql.spi.type.BigintType;
import io.prestosql.spi.type.BooleanType;
import io.prestosql.spi.type.Type;
import io.prestosql.sql.planner.Symbol;
import io.prestosql.sql.planner.iterative.Rule;
import io.prestosql.sql.planner.optimizations.PlanNodeDecorrelator;
import io.prestosql.sql.planner.plan.AggregationNode;
import io.prestosql.sql.planner.plan.ApplyNode;
import io.prestosql.sql.planner.plan.Assignments;
import io.prestosql.sql.planner.plan.LateralJoinNode;
import io.prestosql.sql.planner.plan.LimitNode;
import io.prestosql.sql.planner.plan.Patterns;
import io.prestosql.sql.planner.plan.PlanNode;
import io.prestosql.sql.planner.plan.ProjectNode;
import io.prestosql.sql.tree.BooleanLiteral;
import io.prestosql.sql.tree.Cast;
import io.prestosql.sql.tree.CoalesceExpression;
import io.prestosql.sql.tree.ComparisonExpression;
import io.prestosql.sql.tree.ExistsPredicate;
import io.prestosql.sql.tree.Expression;
import io.prestosql.sql.tree.FunctionCall;
import io.prestosql.sql.tree.LongLiteral;
import io.prestosql.sql.tree.QualifiedName;
import java.util.Objects;
import java.util.Optional;

/* loaded from: input_file:io/prestosql/sql/planner/iterative/rule/TransformExistsApplyToLateralNode.class */
/**
 * Planner rule that replaces an {@link ApplyNode} whose only subquery assignment is an
 * {@link ExistsPredicate} with a {@link LateralJoinNode}.
 *
 * <p>Two rewrites are attempted, in order:
 * <ol>
 *   <li>a {@code LEFT} lateral join against the subquery limited to one row and projected to a
 *       constant {@code TRUE}, with the result coalesced to {@code FALSE};</li>
 *   <li>if {@link PlanNodeDecorrelator#decorrelateFilters} yields no result for that subquery,
 *       an {@code INNER} lateral join against a global {@code count} aggregation compared
 *       with zero.</li>
 * </ol>
 */
public class TransformExistsApplyToLateralNode implements Rule<ApplyNode> {
    private static final Pattern<ApplyNode> PATTERN = Patterns.applyNode();
    private static final QualifiedName COUNT = QualifiedName.of("count");
    private static final FunctionCall COUNT_CALL = new FunctionCall(COUNT, ImmutableList.of());

    /** Signature of the zero-argument {@code count} function, resolved once at construction. */
    private final Signature countSignature;

    /**
     * @param functionRegistry registry used to resolve the zero-argument {@code count} function
     * @throws NullPointerException if {@code functionRegistry} is null
     */
    public TransformExistsApplyToLateralNode(FunctionRegistry functionRegistry) {
        this.countSignature = Objects
                .requireNonNull(functionRegistry, "functionRegistry is null")
                .resolveFunction(COUNT, ImmutableList.of());
    }

    /** Returns the pattern this rule matches: any {@link ApplyNode}. */
    @Override
    public Pattern<ApplyNode> getPattern() {
        return PATTERN;
    }

    /**
     * Rewrites the apply node when it carries exactly one subquery assignment and that
     * assignment is an {@link ExistsPredicate}; otherwise returns an empty result.
     */
    @Override
    public Rule.Result apply(ApplyNode applyNode, Captures captures, Rule.Context context) {
        Assignments subqueryAssignments = applyNode.getSubqueryAssignments();
        if (subqueryAssignments.size() != 1) {
            return Rule.Result.empty();
        }

        Expression soleExpression = Iterables.getOnlyElement(subqueryAssignments.getExpressions());
        if (!(soleExpression instanceof ExistsPredicate)) {
            return Rule.Result.empty();
        }

        // Try the LEFT-join form first; the count-based form is only built when that fails.
        PlanNode rewritten = rewriteToNonDefaultAggregation(applyNode, context)
                .orElseGet(() -> rewriteToDefaultAggregation(applyNode, context));
        return Rule.Result.ofPlanNode(rewritten);
    }

    /**
     * Builds the {@code LEFT} lateral join form: the subquery is limited to one row and projected
     * to a constant {@code TRUE}; the exists symbol becomes {@code COALESCE(subqueryTrue, FALSE)}.
     *
     * @return the rewritten plan, or empty when {@code decorrelateFilters} returns no result
     *         for the limited subquery
     * @throws IllegalStateException if the subquery still has output symbols
     */
    private Optional<PlanNode> rewriteToNonDefaultAggregation(ApplyNode applyNode, Rule.Context context) {
        Preconditions.checkState(
                applyNode.getSubquery().getOutputSymbols().isEmpty(),
                "Expected subquery output symbols to be pruned");

        Symbol existsSymbol = Iterables.getOnlyElement(applyNode.getSubqueryAssignments().getSymbols());
        Symbol subqueryTrue = context.getSymbolAllocator().newSymbol("subqueryTrue", BooleanType.BOOLEAN);

        // Output keeps every input symbol and maps the exists symbol to the coalesced marker.
        Assignments.Builder outputAssignments = Assignments.builder();
        outputAssignments.putIdentities(applyNode.getInput().getOutputSymbols());
        outputAssignments.put(
                existsSymbol,
                new CoalesceExpression(ImmutableList.of(subqueryTrue.toSymbolReference(), BooleanLiteral.FALSE_LITERAL)));

        // Ids are allocated outermost-first: the projection id is taken before the limit id.
        ProjectNode markerSubquery = new ProjectNode(
                context.getIdAllocator().getNextId(),
                new LimitNode(
                        context.getIdAllocator().getNextId(),
                        applyNode.getSubquery(),
                        1L,
                        false),
                Assignments.of(subqueryTrue, BooleanLiteral.TRUE_LITERAL));

        PlanNodeDecorrelator decorrelator = new PlanNodeDecorrelator(context.getIdAllocator(), context.getLookup());
        if (!decorrelator.decorrelateFilters(markerSubquery, applyNode.getCorrelation()).isPresent()) {
            return Optional.empty();
        }

        return Optional.of(new ProjectNode(
                context.getIdAllocator().getNextId(),
                new LateralJoinNode(
                        applyNode.getId(),
                        applyNode.getInput(),
                        markerSubquery,
                        applyNode.getCorrelation(),
                        LateralJoinNode.Type.LEFT,
                        BooleanLiteral.TRUE_LITERAL,
                        applyNode.getOriginSubquery()),
                outputAssignments.build()));
    }

    /**
     * Builds the {@code INNER} lateral join form: the subquery is wrapped in a global
     * {@code count} aggregation and the exists symbol becomes {@code count > CAST(0 AS bigint)}.
     */
    private PlanNode rewriteToDefaultAggregation(ApplyNode applyNode, Rule.Context context) {
        Symbol countSymbol = context.getSymbolAllocator().newSymbol(COUNT.toString(), BigintType.BIGINT);
        Symbol existsSymbol = Iterables.getOnlyElement(applyNode.getSubqueryAssignments().getSymbols());

        ImmutableMap<Symbol, AggregationNode.Aggregation> countAggregation = ImmutableMap.of(
                countSymbol,
                new AggregationNode.Aggregation(COUNT_CALL, this.countSignature, Optional.empty()));

        ComparisonExpression countIsPositive = new ComparisonExpression(
                ComparisonExpression.Operator.GREATER_THAN,
                countSymbol.toSymbolReference(),
                new Cast(new LongLiteral("0"), BigintType.BIGINT.toString()));

        // Ids are allocated outermost-first: the projection id is taken before the aggregation id.
        ProjectNode existsProjection = new ProjectNode(
                context.getIdAllocator().getNextId(),
                new AggregationNode(
                        context.getIdAllocator().getNextId(),
                        applyNode.getSubquery(),
                        countAggregation,
                        AggregationNode.globalAggregation(),
                        ImmutableList.of(),
                        AggregationNode.Step.SINGLE,
                        Optional.empty(),
                        Optional.empty()),
                Assignments.of(existsSymbol, countIsPositive));

        return new LateralJoinNode(
                applyNode.getId(),
                applyNode.getInput(),
                existsProjection,
                applyNode.getCorrelation(),
                LateralJoinNode.Type.INNER,
                BooleanLiteral.TRUE_LITERAL,
                applyNode.getOriginSubquery());
    }
}
