package io.trino.sql.planner.iterative.rule;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Sets;
import io.trino.matching.Capture;
import io.trino.matching.Captures;
import io.trino.matching.Pattern;
import io.trino.spi.type.RowType;
import io.trino.sql.ir.Constant;
import io.trino.sql.ir.Expression;
import io.trino.sql.ir.FieldReference;
import io.trino.sql.ir.Reference;
import io.trino.sql.planner.ExpressionSymbolInliner;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.SymbolsExtractor;
import io.trino.sql.planner.iterative.Rule;
import io.trino.sql.planner.plan.Assignments;
import io.trino.sql.planner.plan.Patterns;
import io.trino.sql.planner.plan.ProjectNode;
import java.util.Iterator;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.Stream;

/* loaded from: input_file:io/trino/sql/planner/iterative/rule/InlineProjections.class */
/**
 * Optimizer rule that matches a {@link ProjectNode} whose source is another {@link ProjectNode}
 * and inlines selected child assignments into the parent's expressions.
 *
 * <p>A child output symbol referenced by the parent is an inlining target when it is not an
 * identity assignment and either
 * <ul>
 *     <li>its child expression is a {@link Constant} or a {@link Reference} (inlined regardless of
 *     how many times the parent uses it), or</li>
 *     <li>it is referenced exactly once across all parent expressions and its child expression is
 *     not a {@link FieldReference} whose base has a {@link RowType}.</li>
 * </ul>
 * Two stacked identity projections are collapsed by dropping the child.
 */
public class InlineProjections implements Rule<ProjectNode> {
    private static final Capture<ProjectNode> CHILD = Capture.newCapture();
    private static final Pattern<ProjectNode> PATTERN = Patterns.project()
            .with(Patterns.source().matching(Patterns.project().capturedAs(CHILD)));

    @Override
    public Pattern<ProjectNode> getPattern() {
        return PATTERN;
    }

    @Override
    public Rule.Result apply(ProjectNode projectNode, Captures captures, Rule.Context context) {
        return inlineProjections(projectNode, captures.get(CHILD))
                .map(Rule.Result::ofPlanNode)
                .orElse(Rule.Result.empty());
    }

    /**
     * Inlines the eligible assignments of {@code child} into the expressions of {@code parent}.
     *
     * @param parent the upper projection, whose expressions are rewritten
     * @param child the projection directly below {@code parent}
     * @return the rewritten parent projection, or {@link Optional#empty()} when nothing can be
     *         inlined. The child projection is dropped entirely when the remaining child
     *         assignments are all identities.
     */
    public static Optional<ProjectNode> inlineProjections(ProjectNode parent, ProjectNode child) {
        // Two identity projections on top of each other: simply skip the child.
        if (parent.isIdentity() && child.isIdentity()) {
            return Optional.of((ProjectNode) parent.replaceChildren(ImmutableList.of(child.getSource())));
        }

        Set<Symbol> targets = extractInliningTargets(parent, child);
        if (targets.isEmpty()) {
            return Optional.empty();
        }

        // Rewrite every parent expression, substituting each target symbol with its child expression.
        Assignments inlinedAssignments = child.getAssignments().filter(targets::contains);
        Assignments.Builder parentAssignments = Assignments.builder();
        for (Map.Entry<Symbol, Expression> entry : parent.getAssignments().entrySet()) {
            parentAssignments.put(entry.getKey(), inlineReferences(entry.getValue(), inlinedAssignments));
        }

        // Symbols read by the inlined expressions; the child must keep producing them.
        Set<Symbol> inputs = child.getAssignments().entrySet().stream()
                .filter(entry -> targets.contains(entry.getKey()))
                .map(Map.Entry::getValue)
                .flatMap(expression -> SymbolsExtractor.extractAll(expression).stream())
                .collect(Collectors.toSet());

        // Keep every child assignment that was not inlined. Remember which of them bind a symbol to
        // something other than itself: such symbols must not also receive an identity assignment.
        Assignments.Builder childAssignments = Assignments.builder();
        ImmutableSet.Builder<Symbol> redefinedBuilder = ImmutableSet.builder();
        for (Map.Entry<Symbol, Expression> entry : child.getAssignments().entrySet()) {
            if (!targets.contains(entry.getKey())) {
                childAssignments.put(entry);
                if (!isSymbolReference(entry.getKey(), entry.getValue())) {
                    redefinedBuilder.add(entry.getKey());
                }
            }
        }
        Set<Symbol> redefined = redefinedBuilder.build();
        for (Symbol input : inputs) {
            if (!redefined.contains(input)) {
                childAssignments.putIdentity(input);
            }
        }

        Assignments newChildAssignments = childAssignments.build();
        return Optional.of(new ProjectNode(
                parent.getId(),
                // An all-identity child projection is redundant; connect directly to its source.
                newChildAssignments.isIdentity()
                        ? child.getSource()
                        : new ProjectNode(child.getId(), child.getSource(), newChildAssignments),
                parentAssignments.build()));
    }

    /**
     * Replaces every symbol in {@code expression} that has an entry in {@code assignments} with the
     * assigned expression; all other symbols are left as plain symbol references.
     */
    private static Expression inlineReferences(Expression expression, Assignments assignments) {
        Function<Symbol, Expression> mapping = symbol -> {
            Expression replacement = assignments.get(symbol);
            return replacement != null ? replacement : symbol.toSymbolReference();
        };
        return ExpressionSymbolInliner.inlineSymbols(mapping, expression);
    }

    /**
     * Computes the child output symbols that should be inlined into the parent. See the class
     * Javadoc for the selection criteria.
     */
    private static Set<Symbol> extractInliningTargets(ProjectNode parent, ProjectNode child) {
        Set<Symbol> childOutputs = ImmutableSet.copyOf(child.getOutputSymbols());
        Assignments childAssignments = child.getAssignments();

        // Number of references to each child output across all parent expressions.
        Map<Symbol, Long> referenceCounts = parent.getAssignments().getExpressions().stream()
                .flatMap(expression -> SymbolsExtractor.extractAll(expression).stream())
                .filter(childOutputs::contains)
                .collect(Collectors.groupingBy(Function.identity(), Collectors.counting()));

        // Identity assignments are excluded below. Inlining "x := x" would leave the parent
        // expressions unchanged.
        Set<Symbol> singletons = referenceCounts.entrySet().stream()
                .filter(entry -> entry.getValue() == 1)
                .map(Map.Entry::getKey)
                .filter(symbol -> !childAssignments.isIdentity(symbol))
                // NOTE(review): row-typed field dereferences are deliberately kept in the child;
                // presumably so other dereference-related rules can see them — confirm intent.
                .filter(symbol -> !isRowFieldReference(childAssignments.get(symbol)))
                .collect(Collectors.toSet());

        // Constants and plain symbol references are cheap to duplicate, so they are inlined
        // regardless of how many times they are referenced.
        Set<Symbol> basicReferences = referenceCounts.keySet().stream()
                .filter(symbol -> {
                    Expression expression = childAssignments.get(symbol);
                    return expression instanceof Constant || expression instanceof Reference;
                })
                .filter(symbol -> !childAssignments.isIdentity(symbol))
                .collect(Collectors.toSet());

        return Sets.union(singletons, basicReferences);
    }

    /** Returns whether {@code expression} is a field access on a {@link RowType}-typed base. */
    private static boolean isRowFieldReference(Expression expression) {
        return expression instanceof FieldReference
                && ((FieldReference) expression).base().type() instanceof RowType;
    }

    /** Returns whether {@code expression} is a reference to {@code symbol} itself (an identity). */
    private static boolean isSymbolReference(Symbol symbol, Expression expression) {
        return expression instanceof Reference && ((Reference) expression).name().equals(symbol.name());
    }
}
