package io.prestosql.sql.planner.iterative.rule;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import io.prestosql.matching.Capture;
import io.prestosql.matching.Captures;
import io.prestosql.matching.Pattern;
import io.prestosql.metadata.Metadata;
import io.prestosql.metadata.TableHandle;
import io.prestosql.spi.connector.Assignment;
import io.prestosql.spi.connector.ColumnHandle;
import io.prestosql.spi.connector.ProjectionApplicationResult;
import io.prestosql.spi.expression.ConnectorExpression;
import io.prestosql.sql.planner.ConnectorExpressionTranslator;
import io.prestosql.sql.planner.LiteralEncoder;
import io.prestosql.sql.planner.PartialTranslator;
import io.prestosql.sql.planner.ReferenceAwareExpressionNodeInliner;
import io.prestosql.sql.planner.Symbol;
import io.prestosql.sql.planner.TypeAnalyzer;
import io.prestosql.sql.planner.iterative.Rule;
import io.prestosql.sql.planner.plan.Assignments;
import io.prestosql.sql.planner.plan.Patterns;
import io.prestosql.sql.planner.plan.ProjectNode;
import io.prestosql.sql.planner.plan.TableScanNode;
import io.prestosql.sql.tree.Expression;
import io.prestosql.sql.tree.NodeRef;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;

/* loaded from: input_file:io/prestosql/sql/planner/iterative/rule/PushProjectionIntoTableScan.class */
/**
 * Optimizer rule matching a {@link ProjectNode} directly on top of a {@link TableScanNode}.
 *
 * <p>The sub-expressions of the projection that {@link PartialTranslator} can express as
 * {@link ConnectorExpression}s are offered to the connector via {@link Metadata#applyProjection}.
 * If the connector accepts them, the rule replaces the plan with a new table scan over the returned
 * table handle and column assignments. On top of that scan sits a project whose assignments are the
 * original ones, with each translated sub-expression replaced by the connector's corresponding
 * returned projection. If the connector declines (empty result), the plan is left unchanged.
 */
public class PushProjectionIntoTableScan implements Rule<ProjectNode> {
    private static final Capture<TableScanNode> TABLE_SCAN = Capture.newCapture();
    // Matches: Project -> TableScan, capturing the scan.
    private static final Pattern<ProjectNode> PATTERN = Patterns.project().with(Patterns.source().matching(Patterns.tableScan().capturedAs(TABLE_SCAN)));

    private final Metadata metadata;
    private final TypeAnalyzer typeAnalyzer;

    /**
     * @param metadata used to ask the connector to apply projections and to build literal encoders
     * @param typeAnalyzer passed to {@link PartialTranslator} when extracting translatable sub-expressions
     */
    public PushProjectionIntoTableScan(Metadata metadata, TypeAnalyzer typeAnalyzer) {
        this.metadata = metadata;
        this.typeAnalyzer = typeAnalyzer;
    }

    @Override // io.prestosql.sql.planner.iterative.Rule
    public Pattern<ProjectNode> getPattern() {
        return PATTERN;
    }

    /**
     * Attempts to push the translatable parts of {@code projectNode} into the captured table scan.
     *
     * @return {@link Rule.Result#empty()} if the connector does not apply the projection; otherwise a
     *         new project over a new table scan
     * @throws IllegalStateException if the connector returns a different number of projections
     *         than it was given
     */
    @Override // io.prestosql.sql.planner.iterative.Rule
    public Rule.Result apply(ProjectNode projectNode, Captures captures, Rule.Context context) {
        TableScanNode tableScan = captures.get(TABLE_SCAN);
        Map<Symbol, Expression> inputProjections = projectNode.getAssignments().getMap();

        Map<NodeRef<Expression>, ConnectorExpression> partialTranslations = collectPartialTranslations(inputProjections, context);
        // Index-aligned: the i-th node is later replaced by the i-th projection returned by the connector.
        List<NodeRef<Expression>> translatedNodes = ImmutableList.copyOf(partialTranslations.keySet());
        List<ConnectorExpression> connectorProjections = ImmutableList.copyOf(partialTranslations.values());

        // The connector sees the scan's columns keyed by symbol name.
        Map<String, ColumnHandle> inputAssignments = tableScan.getAssignments().entrySet().stream()
                .collect(ImmutableMap.toImmutableMap(entry -> entry.getKey().getName(), Map.Entry::getValue));

        Optional<ProjectionApplicationResult<TableHandle>> applied = this.metadata.applyProjection(context.getSession(), tableScan.getTable(), connectorProjections, inputAssignments);
        if (applied.isEmpty()) {
            return Rule.Result.empty();
        }
        ProjectionApplicationResult<TableHandle> result = applied.get();

        List<ConnectorExpression> newConnectorProjections = result.getProjections();
        Preconditions.checkState(newConnectorProjections.size() == connectorProjections.size(), "Mismatch between input and output projections from the connector: expected %s but got %s", connectorProjections.size(), newConnectorProjections.size());

        // Allocate a fresh symbol for every column the connector says the new scan produces.
        List<Symbol> newScanOutputs = new ArrayList<>();
        Map<Symbol, ColumnHandle> newScanAssignments = new HashMap<>();
        Map<String, Symbol> variableMappings = new HashMap<>();
        for (Assignment assignment : result.getAssignments()) {
            Symbol symbol = context.getSymbolAllocator().newSymbol(assignment.getVariable(), assignment.getType());
            newScanOutputs.add(symbol);
            newScanAssignments.put(symbol, assignment.getColumn());
            variableMappings.put(assignment.getVariable(), symbol);
        }

        // Translate the connector's projections back into engine expressions over the new symbols,
        // keyed by the original sub-expression node each one replaces.
        LiteralEncoder literalEncoder = new LiteralEncoder(this.metadata);
        ImmutableMap.Builder<NodeRef<Expression>, Expression> replacementsBuilder = ImmutableMap.builder();
        for (int i = 0; i < translatedNodes.size(); i++) {
            Expression rewritten = ConnectorExpressionTranslator.translate(newConnectorProjections.get(i), variableMappings, literalEncoder);
            replacementsBuilder.put(translatedNodes.get(i), rewritten);
        }
        Map<NodeRef<Expression>, Expression> replacements = replacementsBuilder.build();

        // Rebuild every original assignment with its translated sub-expressions swapped out.
        Assignments.Builder newProjections = Assignments.builder();
        for (Map.Entry<Symbol, Expression> entry : inputProjections.entrySet()) {
            newProjections.put(entry.getKey(), ReferenceAwareExpressionNodeInliner.replaceExpression(entry.getValue(), replacements));
        }

        TableScanNode newScan = TableScanNode.newInstance(tableScan.getId(), result.getHandle(), newScanOutputs, newScanAssignments);
        return Rule.Result.ofPlanNode(new ProjectNode(context.getIdAllocator().getNextId(), newScan, newProjections.build()));
    }

    /**
     * Gathers the partial translations of every projection expression, keyed by node identity.
     *
     * <p>The same expression instance may occur in several assignments. Each node is therefore
     * recorded only once, keeping its first translation. This stops duplicate nodes from being sent
     * to the connector twice, and from making the node-to-replacement {@link ImmutableMap} fail with
     * a duplicate key. Insertion order is preserved so the keys and values can be used as
     * index-aligned lists.
     */
    private Map<NodeRef<Expression>, ConnectorExpression> collectPartialTranslations(Map<Symbol, Expression> projections, Rule.Context context) {
        Map<NodeRef<Expression>, ConnectorExpression> translations = new LinkedHashMap<>();
        for (Expression expression : projections.values()) {
            PartialTranslator.extractPartialTranslations(expression, context.getSession(), this.typeAnalyzer, context.getSymbolAllocator().getTypes())
                    .forEach(translations::putIfAbsent);
        }
        return translations;
    }
}
