package io.trino.sql.planner.iterative.rule;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Sets;
import io.trino.sql.ir.Bind;
import io.trino.sql.ir.Expression;
import io.trino.sql.ir.ExpressionRewriter;
import io.trino.sql.ir.ExpressionTreeRewriter;
import io.trino.sql.ir.Lambda;
import io.trino.sql.ir.Reference;
import io.trino.sql.planner.ExpressionSymbolInliner;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.SymbolAllocator;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Objects;
import java.util.function.Function;

/**
 * Desugars lambda captures. Every lambda whose body references symbols other than its own
 * arguments is rewritten so that those captured symbols become extra leading lambda parameters.
 * The captured values are then supplied through a {@link Bind} expression.
 *
 * <p>For example, {@code x -> f(x, c)} is rewritten into {@code Bind([c], (c', x) -> f(x, c'))},
 * where {@code c'} is a fresh symbol obtained from the {@link SymbolAllocator}.
 */
public final class LambdaCaptureDesugaringRewriter {

    private LambdaCaptureDesugaringRewriter() {
    }

    /**
     * Rewrites the lambdas reached while traversing {@code expression} so that captured symbols are
     * passed explicitly through {@link Bind} instead of being referenced implicitly.
     *
     * @param expression the expression to rewrite
     * @param symbolAllocator allocator used to create the fresh parameter symbols that replace
     *     captured symbols inside lambda bodies
     * @return the rewritten expression
     */
    public static Expression rewrite(Expression expression, SymbolAllocator symbolAllocator) {
        return ExpressionTreeRewriter.rewriteWith(new Visitor(symbolAllocator), expression, new Context());
    }

    /**
     * Traversal state: the set into which every symbol reference encountered is recorded.
     */
    public static class Context {
        // A LinkedHashSet keeps insertion order, so the fresh parameter symbols and the Bind
        // arguments derived from it are produced in a deterministic order.
        final LinkedHashSet<Symbol> referencedSymbols;

        public Context() {
            this(new LinkedHashSet<>());
        }

        private Context(LinkedHashSet<Symbol> referencedSymbols) {
            this.referencedSymbols = referencedSymbols;
        }

        /** Returns the (mutable) set collecting referenced symbols for this context. */
        public LinkedHashSet<Symbol> getReferencedSymbols() {
            return referencedSymbols;
        }

        /** Returns a new context that records references into {@code referencedSymbols}. */
        public Context withReferencedSymbols(LinkedHashSet<Symbol> referencedSymbols) {
            return new Context(referencedSymbols);
        }
    }

    private static class Visitor extends ExpressionRewriter<Context> {
        private final SymbolAllocator symbolAllocator;

        public Visitor(SymbolAllocator symbolAllocator) {
            this.symbolAllocator = Objects.requireNonNull(symbolAllocator, "symbolAllocator is null");
        }

        /**
         * Rewrites {@code lambda} so it no longer captures symbols implicitly. Symbols referenced in
         * the body that are not lambda arguments get fresh parameter symbols, prepended to the argument
         * list. The original symbols are passed as {@link Bind} values.
         */
        @Override
        public Expression rewriteLambda(Lambda lambda, Context context, ExpressionTreeRewriter<Context> treeRewriter) {
            // Record the body's references in a fresh set so that only this lambda's body is inspected.
            LinkedHashSet<Symbol> bodyReferences = new LinkedHashSet<>();
            Expression rewrittenBody = treeRewriter.rewrite(lambda.body(), context.withReferencedSymbols(bodyReferences));

            // Snapshot of symbols referenced by the body that are not arguments of this lambda.
            // ImmutableSet.copyOf preserves the LinkedHashSet iteration order.
            ImmutableSet<Symbol> capturedSymbols = ImmutableSet.copyOf(
                    Sets.difference(bodyReferences, ImmutableSet.copyOf(lambda.arguments())));

            ImmutableMap.Builder<Symbol, Symbol> capturedToParameter = ImmutableMap.builder();
            ImmutableList.Builder<Symbol> newArguments = ImmutableList.builder();
            for (Symbol captured : capturedSymbols) {
                Symbol parameter = symbolAllocator.newSymbol(captured.name(), captured.type());
                capturedToParameter.put(captured, parameter);
                newArguments.add(parameter);
            }
            // The new parameters for captured symbols come first, followed by the original arguments.
            newArguments.addAll(lambda.arguments());

            ImmutableMap<Symbol, Symbol> replacements = capturedToParameter.buildOrThrow();
            // Captured symbols are redirected to their new parameters; every other symbol maps to itself.
            Function<Symbol, Expression> symbolMapping =
                    symbol -> replacements.getOrDefault(symbol, symbol).toSymbolReference();
            Lambda desugared = new Lambda(
                    newArguments.build(),
                    ExpressionSymbolInliner.inlineSymbols(symbolMapping, rewrittenBody));

            Expression result = desugared;
            if (!capturedSymbols.isEmpty()) {
                List<Expression> capturedValues = capturedSymbols.stream()
                        .map(symbol -> new Reference(symbol.type(), symbol.name()))
                        .collect(ImmutableList.toImmutableList());
                result = new Bind(capturedValues, desugared);
            }

            // The Bind values reference the captured symbols in the enclosing scope, so report them
            // there too; an enclosing lambda will then capture them as well.
            context.getReferencedSymbols().addAll(capturedSymbols);
            return result;
        }

        /**
         * Records the referenced symbol in the current context. It returns {@code null}, presumably
         * to let {@link ExpressionTreeRewriter} apply its default handling of the node. Confirm
         * against ExpressionRewriter.
         */
        @Override
        public Expression rewriteReference(Reference reference, Context context, ExpressionTreeRewriter<Context> treeRewriter) {
            context.getReferencedSymbols().add(new Symbol(reference.type(), reference.name()));
            return null;
        }
    }
}
