/*
 * Decompiled with CFR 0.152.
 */
package io.trino.sql.planner.iterative.rule;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableList;
import io.trino.Session;
import io.trino.SystemSessionProperties;
import io.trino.cost.StatsProvider;
import io.trino.matching.Capture;
import io.trino.matching.Captures;
import io.trino.matching.Pattern;
import io.trino.sql.planner.Partitioning;
import io.trino.sql.planner.PartitioningScheme;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.SystemPartitioningHandle;
import io.trino.sql.planner.iterative.Lookup;
import io.trino.sql.planner.iterative.Rule;
import io.trino.sql.planner.optimizations.PlanNodeSearcher;
import io.trino.sql.planner.plan.ExchangeNode;
import io.trino.sql.planner.plan.JoinNode;
import io.trino.sql.planner.plan.Patterns;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.planner.plan.TableScanNode;
import io.trino.sql.planner.plan.UnnestNode;
import io.trino.sql.planner.plan.ValuesNode;
import java.util.List;
import java.util.Optional;

/**
 * Planner rule that replaces the local exchange feeding the build (right) side of a join with a
 * single-stream local {@code GATHER} exchange when the build side is estimated to be small.
 *
 * <p>The rule only matches joins whose right child is an {@link ExchangeNode} that
 * {@link #canBeTranslatedToLocalGather(ExchangeNode)} accepts. It fires only when the estimated
 * build-side row count is known and strictly below
 * {@link SystemSessionProperties#getJoinPartitionedBuildMinRowCount(Session)}. The rule is
 * disabled entirely when that threshold is not positive. As the rule name suggests, the intent is
 * presumably to let small build sides use a non-partitioned lookup source. Verify against the
 * join operator if relying on this.
 */
public class UseNonPartitionedJoinLookupSource
implements Rule<JoinNode> {
    private static final Capture<ExchangeNode> RIGHT_EXCHANGE_NODE = Capture.newCapture();
    private static final Pattern<JoinNode> JOIN_PATTERN = Patterns.join()
            .with(Patterns.Join.right().matching(
                    Patterns.exchange()
                            .matching(UseNonPartitionedJoinLookupSource::canBeTranslatedToLocalGather)
                            .capturedAs(RIGHT_EXCHANGE_NODE)));

    @Override
    public Pattern<JoinNode> getPattern() {
        return JOIN_PATTERN;
    }

    /**
     * Enables the rule only when the session's partitioned-build minimum row count is positive.
     */
    @Override
    public boolean isEnabled(Session session) {
        return SystemSessionProperties.getJoinPartitionedBuildMinRowCount(session) > 0L;
    }

    /**
     * Swaps the captured build-side exchange for a local single-distribution gather when the
     * build side's estimated row count is below the session threshold.
     *
     * @return the join with its right child replaced, or an empty result when the estimate is
     *         unavailable (NaN) or not below the threshold
     */
    @Override
    public Rule.Result apply(JoinNode node, Captures captures, Rule.Context context) {
        double buildSideRowCount = getSourceTablesRowCount(node.getRight(), context);
        // NaN means no usable estimate. Either the build side contains a row-expanding node, or
        // some source reported a NaN row count (presumably unknown stats).
        if (Double.isNaN(buildSideRowCount)) {
            return Rule.Result.empty();
        }
        double minRowCount = SystemSessionProperties.getJoinPartitionedBuildMinRowCount(context.getSession());
        if (buildSideRowCount >= minRowCount) {
            return Rule.Result.empty();
        }

        ExchangeNode rightSideExchange = captures.get(RIGHT_EXCHANGE_NODE);
        ExchangeNode singleThreadedExchange = toGatheringExchange(rightSideExchange);
        // Let invocation-context inference pick List<PlanNode>. The decompiled
        // (List<PlanNode>) ImmutableList.of((Object) ..., (Object) ...) cast was an inconvertible
        // generic cast.
        return Rule.Result.ofPlanNode(node.replaceChildren(ImmutableList.of(node.getLeft(), singleThreadedExchange)));
    }

    /**
     * Builds a local GATHER exchange with single-distribution partitioning from an existing exchange.
     *
     * <p>The new exchange keeps the original node id, output layout, sources and inputs. It has no
     * ordering scheme.
     */
    private static ExchangeNode toGatheringExchange(ExchangeNode exchangeNode) {
        // Explicit type witness: a cast does not drive generic inference, and an explicit witness
        // stays unambiguous even if Partitioning.create is overloaded.
        PartitioningScheme singleDistribution = new PartitioningScheme(
                Partitioning.create(SystemPartitioningHandle.SINGLE_DISTRIBUTION, ImmutableList.<Symbol>of()),
                exchangeNode.getPartitioningScheme().getOutputLayout());
        return new ExchangeNode(
                exchangeNode.getId(),
                ExchangeNode.Type.GATHER,
                ExchangeNode.Scope.LOCAL,
                singleDistribution,
                exchangeNode.getSources(),
                exchangeNode.getInputs(),
                Optional.empty());
    }

    /**
     * Reports whether a build-side exchange may be rewritten into a local single-stream gather.
     *
     * <p>It must be a LOCAL exchange that is not already a single-distribution gather. It must have
     * no ordering scheme, no bucket-to-partition mapping, and must not replicate nulls-and-any.
     */
    private static boolean canBeTranslatedToLocalGather(ExchangeNode exchangeNode) {
        PartitioningScheme partitioningScheme = exchangeNode.getPartitioningScheme();
        return exchangeNode.getScope() == ExchangeNode.Scope.LOCAL
                && !isSingleGather(exchangeNode)
                && exchangeNode.getOrderingScheme().isEmpty()
                && partitioningScheme.getBucketToPartition().isEmpty()
                && !partitioningScheme.isReplicateNullsAndAny();
    }

    /**
     * Reports whether the exchange is already a GATHER over single-distribution partitioning,
     * in which case rewriting it would be a no-op.
     */
    private static boolean isSingleGather(ExchangeNode exchangeNode) {
        // equals() instead of identity so an equal-but-distinct handle instance is recognized too.
        return exchangeNode.getType() == ExchangeNode.Type.GATHER
                && SystemPartitioningHandle.SINGLE_DISTRIBUTION.equals(
                        exchangeNode.getPartitioningScheme().getPartitioning().getHandle());
    }

    /**
     * Context-based convenience overload of
     * {@link #getSourceTablesRowCount(PlanNode, Lookup, StatsProvider)}.
     */
    private static double getSourceTablesRowCount(PlanNode node, Rule.Context context) {
        return getSourceTablesRowCount(node, context.getLookup(), context.getStatsProvider());
    }

    /**
     * Estimates the row count produced by the leaf sources of a plan subtree.
     *
     * @param node root of the subtree to inspect
     * @param lookup resolves plan references while searching the subtree
     * @param statsProvider supplies per-node output row-count estimates
     * @return {@link Double#NaN} if the subtree contains a {@link JoinNode} or {@link UnnestNode},
     *         because such nodes can emit more rows than they read. Otherwise returns the sum of
     *         the estimated output row counts of every {@link TableScanNode} and {@link ValuesNode}
     *         found. This is {@code 0.0} when none are found, and NaN if any estimate is NaN.
     */
    @VisibleForTesting
    static double getSourceTablesRowCount(PlanNode node, Lookup lookup, StatsProvider statsProvider) {
        boolean hasExpandingNodes = PlanNodeSearcher.searchFrom(node, lookup)
                .whereIsInstanceOfAny(JoinNode.class, UnnestNode.class)
                .matches();
        if (hasExpandingNodes) {
            return Double.NaN;
        }
        List<PlanNode> sourceNodes = PlanNodeSearcher.searchFrom(node, lookup)
                .whereIsInstanceOfAny(TableScanNode.class, ValuesNode.class)
                .findAll();
        return sourceNodes.stream()
                .mapToDouble(sourceNode -> statsProvider.getStats(sourceNode).getOutputRowCount())
                .sum();
    }
}

