package io.trino.sql.planner.iterative.rule;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import io.trino.Session;
import io.trino.matching.Captures;
import io.trino.matching.Pattern;
import io.trino.metadata.Metadata;
import io.trino.spi.StandardErrorCode;
import io.trino.spi.type.BigintType;
import io.trino.spi.type.BooleanType;
import io.trino.sql.ir.Booleans;
import io.trino.sql.ir.Cast;
import io.trino.sql.ir.Comparison;
import io.trino.sql.ir.Constant;
import io.trino.sql.ir.IrExpressions;
import io.trino.sql.ir.IsNull;
import io.trino.sql.planner.LogicalPlanner;
import io.trino.sql.planner.PlanNodeIdAllocator;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.SymbolAllocator;
import io.trino.sql.planner.SymbolsExtractor;
import io.trino.sql.planner.iterative.Lookup;
import io.trino.sql.planner.iterative.Rule;
import io.trino.sql.planner.optimizations.PlanNodeSearcher;
import io.trino.sql.planner.optimizations.QueryCardinalityUtil;
import io.trino.sql.planner.plan.AssignUniqueId;
import io.trino.sql.planner.plan.Assignments;
import io.trino.sql.planner.plan.CorrelatedJoinNode;
import io.trino.sql.planner.plan.DataOrganizationSpecification;
import io.trino.sql.planner.plan.EnforceSingleRowNode;
import io.trino.sql.planner.plan.FilterNode;
import io.trino.sql.planner.plan.JoinType;
import io.trino.sql.planner.plan.LimitNode;
import io.trino.sql.planner.plan.Patterns;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.planner.plan.PlanVisitor;
import io.trino.sql.planner.plan.ProjectNode;
import io.trino.sql.planner.plan.RowNumberNode;
import io.trino.sql.planner.plan.TopNNode;
import io.trino.sql.planner.plan.UnnestNode;
import io.trino.sql.planner.plan.WindowNode;
import java.util.Collection;
import java.util.List;
import java.util.Objects;
import java.util.Optional;

/**
 * Decorrelates a {@link CorrelatedJoinNode} whose subquery is a chain of nodes ending in a correlated
 * {@link UnnestNode}.
 * <p>
 * The rule matches INNER and LEFT correlated joins that have a non-empty correlation list and a {@code TRUE}
 * join filter. The subquery may be rooted at an {@link EnforceSingleRowNode}. Under that root, a supported unnest
 * (see {@link #isSupportedUnnest}) is searched for. The search descends only through {@link ProjectNode}s and
 * through {@link LimitNode}s / {@link TopNNode}s with a positive count.
 * <p>
 * The rewrite works in three steps:
 * <ol>
 * <li>It tags each input row with a unique id ({@link AssignUniqueId}).</li>
 * <li>It replaces the correlated unnest with an unnest over the tagged input.</li>
 * <li>It re-creates the nodes that sat above the unnest, using the unique id as the partitioning key.</li>
 * </ol>
 */
public class DecorrelateUnnest implements Rule<CorrelatedJoinNode> {
    /** INNER or LEFT correlated joins with a non-empty correlation list and a trivially true join filter. */
    private static final Pattern<CorrelatedJoinNode> PATTERN = Patterns.correlatedJoin()
            .with(Pattern.nonEmpty(Patterns.CorrelatedJoin.correlation()))
            .with(Patterns.CorrelatedJoin.filter().equalTo(Booleans.TRUE))
            .matching(node -> node.getType() == JoinType.INNER || node.getType() == JoinType.LEFT);

    private final Metadata metadata;

    /**
     * Outcome of rewriting one subquery node.
     * <p>
     * It holds the rewritten plan and, when one was produced, the symbol holding a row number partitioned by the
     * unique id. The Limit and EnforceSingleRow rewrites reuse that symbol when it is present.
     */
    public static class RewriteResult {
        final PlanNode plan;
        final Optional<Symbol> rowNumberSymbol;

        /**
         * @param plan rewritten plan; must not be null
         * @param rowNumberSymbol row-number symbol produced by the rewrite, if any; must not be null
         */
        public RewriteResult(PlanNode plan, Optional<Symbol> rowNumberSymbol) {
            this.plan = Objects.requireNonNull(plan, "plan is null");
            this.rowNumberSymbol = Objects.requireNonNull(rowNumberSymbol, "rowNumberSymbol is null");
        }

        public PlanNode getPlan() {
            return plan;
        }

        public Optional<Symbol> getRowNumberSymbol() {
            return rowNumberSymbol;
        }
    }

    /**
     * Rebuilds the subquery of a correlated join on top of a decorrelated node sequence.
     * <p>
     * The {@link UnnestNode} and everything below it are replaced by {@code sequenceSource}. The Project, Limit,
     * TopN and EnforceSingleRow nodes above it are re-created so that they operate per value of
     * {@code uniqueSymbol}. Other node types are expected to reach {@link #visitPlan}, which throws
     * {@link IllegalStateException}.
     */
    public static class Rewriter extends PlanVisitor<RewriteResult, Void> {
        private final List<Symbol> leftOutputs;
        private final Symbol ordinalitySymbol;
        private final Symbol uniqueSymbol;
        private final PlanNode sequenceSource;
        private final Session session;
        private final Metadata metadata;
        private final Lookup lookup;
        private final PlanNodeIdAllocator idAllocator;
        private final SymbolAllocator symbolAllocator;

        private Rewriter(
                List<Symbol> leftOutputs,
                Symbol ordinalitySymbol,
                Symbol uniqueSymbol,
                PlanNode sequenceSource,
                Session session,
                Metadata metadata,
                Lookup lookup,
                PlanNodeIdAllocator idAllocator,
                SymbolAllocator symbolAllocator) {
            this.leftOutputs = ImmutableList.copyOf(Objects.requireNonNull(leftOutputs, "leftOutputs is null"));
            this.ordinalitySymbol = Objects.requireNonNull(ordinalitySymbol, "ordinalitySymbol is null");
            this.uniqueSymbol = Objects.requireNonNull(uniqueSymbol, "uniqueSymbol is null");
            this.sequenceSource = Objects.requireNonNull(sequenceSource, "sequenceSource is null");
            this.session = Objects.requireNonNull(session, "session is null");
            this.metadata = Objects.requireNonNull(metadata, "metadata is null");
            this.lookup = Objects.requireNonNull(lookup, "lookup is null");
            this.idAllocator = Objects.requireNonNull(idAllocator, "idAllocator is null");
            this.symbolAllocator = Objects.requireNonNull(symbolAllocator, "symbolAllocator is null");
        }

        /**
         * Rewrites the subquery rooted at {@code root} on top of {@code sequenceSource}.
         *
         * @param root subquery root to rewrite
         * @param leftOutputs symbols of the decorrelated input; every rewritten projection passes them through
         * @param ordinalitySymbol ordinality symbol of the rewritten unnest; every rewritten projection passes it through
         * @param uniqueSymbol per-input-row unique id, used as the partitioning key of row numbering and windows
         * @param sequenceSource node that replaces the original unnest
         * @param session current session, used when rewriting LIMIT WITH TIES
         * @param metadata used to resolve functions
         * @param lookup used to resolve group references
         * @param idAllocator allocator for new plan node ids
         * @param symbolAllocator allocator for new symbols
         * @return the rewritten plan
         * @throws IllegalStateException if the subquery contains an unsupported node type
         */
        public static PlanNode rewriteNodeSequence(
                PlanNode root,
                List<Symbol> leftOutputs,
                Symbol ordinalitySymbol,
                Symbol uniqueSymbol,
                PlanNode sequenceSource,
                Session session,
                Metadata metadata,
                Lookup lookup,
                PlanNodeIdAllocator idAllocator,
                SymbolAllocator symbolAllocator) {
            Rewriter rewriter = new Rewriter(leftOutputs, ordinalitySymbol, uniqueSymbol, sequenceSource, session, metadata, lookup, idAllocator, symbolAllocator);
            return rewriter.rewrite(root).getPlan();
        }

        private RewriteResult rewrite(PlanNode node) {
            return lookup.resolve(node).accept(this, null);
        }

        /**
         * Returns {@code source} unchanged when it already carries a row-number symbol.
         * <p>
         * Otherwise, wraps its plan in a {@link RowNumberNode} partitioned by the unique id. That node is created
         * with the given row-count bound.
         */
        private RewriteResult withRowNumber(RewriteResult source, Optional<Integer> maxRowCountPerPartition) {
            if (source.getRowNumberSymbol().isPresent()) {
                return source;
            }
            Symbol rowNumberSymbol = symbolAllocator.newSymbol("row_number", BigintType.BIGINT);
            PlanNode rowNumberNode = new RowNumberNode(
                    idAllocator.getNextId(),
                    source.getPlan(),
                    ImmutableList.of(uniqueSymbol),
                    false,
                    rowNumberSymbol,
                    maxRowCountPerPartition,
                    Optional.empty());
            return new RewriteResult(rowNumberNode, Optional.of(rowNumberSymbol));
        }

        @Override
        public RewriteResult visitPlan(PlanNode node, Void context) {
            throw new IllegalStateException("Unexpected node type: " + node.getClass().getSimpleName());
        }

        @Override
        public RewriteResult visitUnnest(UnnestNode node, Void context) {
            // The correlated unnest (and its scalar source) is replaced by the decorrelated sequence.
            return new RewriteResult(sequenceSource, Optional.empty());
        }

        @Override
        public RewriteResult visitEnforceSingleRow(EnforceSingleRowNode node, Void context) {
            RewriteResult source = rewrite(node.getSource());
            // A source that is already scalar needs no enforcement.
            if (QueryCardinalityUtil.isScalar(source.getPlan(), lookup)) {
                return source;
            }
            // The bound of 2 is enough for the filter below to detect a second row for the same unique id.
            RewriteResult numbered = withRowNumber(source, Optional.of(2));
            Symbol rowNumberSymbol = numbered.getRowNumberSymbol().orElseThrow();
            // Filter predicate: if(row_number > 1, CAST(fail(...) AS boolean), TRUE).
            return new RewriteResult(
                    new FilterNode(
                            idAllocator.getNextId(),
                            numbered.getPlan(),
                            IrExpressions.ifExpression(
                                    new Comparison(Comparison.Operator.GREATER_THAN, rowNumberSymbol.toSymbolReference(), new Constant(BigintType.BIGINT, 1L)),
                                    new Cast(
                                            LogicalPlanner.failFunction(metadata, StandardErrorCode.SUBQUERY_MULTIPLE_ROWS, "Scalar sub-query has returned multiple rows"),
                                            BooleanType.BOOLEAN),
                                    Booleans.TRUE)),
                    Optional.of(rowNumberSymbol));
        }

        @Override
        public RewriteResult visitLimit(LimitNode node, Void context) {
            RewriteResult source = rewrite(node.getSource());
            if (node.isWithTies()) {
                return new RewriteResult(
                        ImplementLimitWithTies.rewriteLimitWithTiesWithPartitioning(node, source.getPlan(), session, metadata, idAllocator, symbolAllocator, ImmutableList.of(uniqueSymbol)),
                        Optional.empty());
            }
            // Keep at most getCount() rows per unique id.
            RewriteResult numbered = withRowNumber(source, Optional.empty());
            Symbol rowNumberSymbol = numbered.getRowNumberSymbol().orElseThrow();
            return new RewriteResult(
                    new FilterNode(
                            idAllocator.getNextId(),
                            numbered.getPlan(),
                            new Comparison(Comparison.Operator.LESS_THAN_OR_EQUAL, rowNumberSymbol.toSymbolReference(), new Constant(BigintType.BIGINT, Long.valueOf(node.getCount())))),
                    Optional.of(rowNumberSymbol));
        }

        @Override
        public RewriteResult visitTopN(TopNNode node, Void context) {
            RewriteResult source = rewrite(node.getSource());
            // Always compute a fresh row number. The window orders rows by this TopN's ordering scheme, and a
            // reused row number would not carry that guarantee.
            Symbol rowNumberSymbol = symbolAllocator.newSymbol("row_number", BigintType.BIGINT);
            WindowNode.Function rowNumberFunction = new WindowNode.Function(
                    metadata.resolveBuiltinFunction("row_number", ImmutableList.of()),
                    ImmutableList.of(),
                    WindowNode.Frame.DEFAULT_FRAME,
                    false);
            return new RewriteResult(
                    new FilterNode(
                            idAllocator.getNextId(),
                            new WindowNode(
                                    idAllocator.getNextId(),
                                    source.getPlan(),
                                    new DataOrganizationSpecification(ImmutableList.of(uniqueSymbol), Optional.of(node.getOrderingScheme())),
                                    ImmutableMap.of(rowNumberSymbol, rowNumberFunction),
                                    Optional.empty(),
                                    ImmutableSet.of(),
                                    0),
                            new Comparison(Comparison.Operator.LESS_THAN_OR_EQUAL, rowNumberSymbol.toSymbolReference(), new Constant(BigintType.BIGINT, Long.valueOf(node.getCount())))),
                    Optional.of(rowNumberSymbol));
        }

        @Override
        public RewriteResult visitProject(ProjectNode node, Void context) {
            RewriteResult source = rewrite(node.getSource());
            // Pass through the input symbols, the ordinality and any row number alongside the original
            // assignments. Enclosing rewrites and the final INNER-restoring projection can then reference them.
            Assignments.Builder assignments = Assignments.builder()
                    .putAll(node.getAssignments())
                    .putIdentities(leftOutputs)
                    .putIdentity(ordinalitySymbol);
            source.getRowNumberSymbol().ifPresent(assignments::putIdentity);
            return new RewriteResult(new ProjectNode(node.getId(), source.getPlan(), assignments.build()), source.getRowNumberSymbol());
        }
    }

    public DecorrelateUnnest(Metadata metadata) {
        this.metadata = Objects.requireNonNull(metadata, "metadata is null");
    }

    @Override
    public Pattern<CorrelatedJoinNode> getPattern() {
        return PATTERN;
    }

    @Override
    public Rule.Result apply(CorrelatedJoinNode correlatedJoinNode, Captures captures, Rule.Context context) {
        Lookup lookup = context.getLookup();

        // 1. An EnforceSingleRowNode is only recognized at the root of the subquery (no recursion). If one is
        //    present, the unnest is searched for below it.
        PlanNode searchRoot = correlatedJoinNode.getSubquery();
        Optional<PlanNode> enforceSingleRow = PlanNodeSearcher.searchFrom(searchRoot, lookup)
                .where(EnforceSingleRowNode.class::isInstance)
                .recurseOnlyWhen(node -> false)
                .findFirst();
        if (enforceSingleRow.isPresent()) {
            searchRoot = ((EnforceSingleRowNode) enforceSingleRow.get()).getSource();
        }

        // 2. Find a supported correlated unnest. The search descends only through projections and through
        //    Limit/TopN nodes with a positive count.
        Optional<PlanNode> correlatedUnnest = PlanNodeSearcher.searchFrom(searchRoot, lookup)
                .where(node -> isSupportedUnnest(node, correlatedJoinNode.getCorrelation(), lookup))
                .recurseOnlyWhen(node -> node instanceof ProjectNode
                        || (node instanceof LimitNode && ((LimitNode) node).getCount() > 0)
                        || (node instanceof TopNNode && ((TopNNode) node).getCount() > 0))
                .findFirst();
        if (correlatedUnnest.isEmpty()) {
            return Rule.Result.empty();
        }
        UnnestNode unnestNode = (UnnestNode) correlatedUnnest.get();

        // 3. Tag every input row with a unique id. It partitions all per-row operations restored later.
        Symbol uniqueSymbol = context.getSymbolAllocator().newSymbol("unique", BigintType.BIGINT);
        PlanNode input = new AssignUniqueId(context.getIdAllocator().getNextId(), correlatedJoinNode.getInput(), uniqueSymbol);

        // If the correlated unnest reads from a projection, re-apply that projection on top of the tagged input.
        // All input symbols are kept, so the unnest inputs it computes become available.
        PlanNode unnestSource = lookup.resolve(unnestNode.getSource());
        if (unnestSource instanceof ProjectNode) {
            ProjectNode sourceProjection = (ProjectNode) unnestSource;
            input = new ProjectNode(
                    sourceProjection.getId(),
                    input,
                    Assignments.builder()
                            .putIdentities(input.getOutputSymbols())
                            .putAll(sourceProjection.getAssignments())
                            .build());
        }

        // 4. The rewritten unnest is INNER only when all three hold: no EnforceSingleRowNode was found, the
        //    correlated join is INNER, and the original unnest is INNER. Otherwise it is LEFT, which keeps
        //    every input row.
        JoinType unnestJoinType = JoinType.LEFT;
        if (enforceSingleRow.isEmpty() && correlatedJoinNode.getType() == JoinType.INNER && unnestNode.getJoinType() == JoinType.INNER) {
            unnestJoinType = JoinType.INNER;
        }

        // 5. Always give the rewritten unnest an ordinality symbol. Step 7 tests it for null.
        Symbol ordinalitySymbol = unnestNode.getOrdinalitySymbol()
                .orElseGet(() -> context.getSymbolAllocator().newSymbol("ordinality", BigintType.BIGINT));
        UnnestNode rewrittenUnnest = new UnnestNode(
                context.getIdAllocator().getNextId(),
                input,
                input.getOutputSymbols(),
                unnestNode.getMappings(),
                Optional.of(ordinalitySymbol),
                unnestJoinType);

        // 6. Re-create the nodes that sat above the original unnest, on top of the rewritten one.
        PlanNode rewrittenPlan = Rewriter.rewriteNodeSequence(
                correlatedJoinNode.getSubquery(),
                input.getOutputSymbols(),
                ordinalitySymbol,
                uniqueSymbol,
                rewrittenUnnest,
                context.getSession(),
                metadata,
                lookup,
                context.getIdAllocator(),
                context.getSymbolAllocator());

        // 7. If an INNER unnest was rewritten as LEFT, replace every subquery output with NULL on rows whose
        //    ordinality is null. Those are presumably the rows the LEFT unnest padded for inputs that produced
        //    no elements.
        if (unnestNode.getJoinType() == JoinType.INNER && rewrittenUnnest.getJoinType() == JoinType.LEFT) {
            Assignments.Builder assignments = Assignments.builder()
                    .putIdentities(correlatedJoinNode.getInput().getOutputSymbols());
            for (Symbol subquerySymbol : correlatedJoinNode.getSubquery().getOutputSymbols()) {
                assignments.put(
                        subquerySymbol,
                        IrExpressions.ifExpression(
                                new IsNull(ordinalitySymbol.toSymbolReference()),
                                new Constant(subquerySymbol.type(), null),
                                subquerySymbol.toSymbolReference()));
            }
            rewrittenPlan = new ProjectNode(context.getIdAllocator().getNextId(), rewrittenPlan, assignments.build());
        }

        // 8. Prune the outputs down to the correlated join's output symbols.
        return Rule.Result.ofPlanNode(
                Util.restrictOutputs(context.getIdAllocator(), rewrittenPlan, ImmutableSet.copyOf(correlatedJoinNode.getOutputSymbols()))
                        .orElse(rewrittenPlan));
    }

    /**
     * Checks whether {@code planNode} is an unnest that this rule can decorrelate.
     * <p>
     * Supported means the node is an {@link UnnestNode} that meets all of these conditions:
     * <ul>
     * <li>its source is scalar;</li>
     * <li>it replicates no symbols;</li>
     * <li>its join type is INNER or LEFT;</li>
     * <li>either every unnest input is a correlation symbol, or its source is a {@link ProjectNode} whose
     * expressions reference only correlation symbols.</li>
     * </ul>
     *
     * @param planNode candidate node
     * @param correlation correlation symbols of the enclosing correlated join
     * @param lookup used to resolve the unnest's source
     * @return {@code true} if the node is a supported correlated unnest
     */
    public static boolean isSupportedUnnest(PlanNode planNode, List<Symbol> correlation, Lookup lookup) {
        if (!(planNode instanceof UnnestNode)) {
            return false;
        }
        UnnestNode unnestNode = (UnnestNode) planNode;
        if (!QueryCardinalityUtil.isScalar(unnestNode.getSource(), lookup) || !unnestNode.getReplicateSymbols().isEmpty()) {
            return false;
        }
        if (unnestNode.getJoinType() != JoinType.INNER && unnestNode.getJoinType() != JoinType.LEFT) {
            return false;
        }

        ImmutableSet<Symbol> correlationSet = ImmutableSet.copyOf(correlation);
        Collection<Symbol> unnestInputs = unnestNode.getMappings().stream()
                .map(mapping -> mapping.getInput())
                .collect(ImmutableList.toImmutableList());
        if (correlationSet.containsAll(unnestInputs)) {
            return true;
        }
        PlanNode unnestSource = lookup.resolve(unnestNode.getSource());
        return unnestSource instanceof ProjectNode
                && correlationSet.containsAll(SymbolsExtractor.extractUnique(((ProjectNode) unnestSource).getAssignments().getExpressions()));
    }
}
