package io.trino.sql.planner.optimizations;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import io.airlift.log.Logger;
import io.airlift.units.DataSize;
import io.trino.SystemSessionProperties;
import io.trino.cost.RuntimeInfoProvider;
import io.trino.sql.planner.PartitioningHandle;
import io.trino.sql.planner.PartitioningScheme;
import io.trino.sql.planner.PlanFragment;
import io.trino.sql.planner.PlanNodeIdAllocator;
import io.trino.sql.planner.SystemPartitioningHandle;
import io.trino.sql.planner.optimizations.AdaptivePlanOptimizer;
import io.trino.sql.planner.optimizations.PlanOptimizer;
import io.trino.sql.planner.plan.ExchangeNode;
import io.trino.sql.planner.plan.PlanFragmentId;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.planner.plan.PlanNodeId;
import io.trino.sql.planner.plan.RemoteSourceNode;
import io.trino.sql.planner.plan.SimplePlanRewriter;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Stream;

/* loaded from: input_file:io/trino/sql/planner/optimizations/AdaptivePartitioning.class */
/**
 * Adaptive plan optimizer that raises the partition count of hash-partitioned remote exchanges
 * when a stage's estimated input size exceeds what its current partition count is allowed to handle.
 *
 * <p>The optimizer is driven by three fault-tolerant-execution session properties read through
 * {@link SystemSessionProperties}: the max partition count (used as the default when a fragment
 * declares none), the runtime adaptive partitioning partition count (the target count), and the
 * runtime adaptive partitioning max task size (the per-partition size threshold).
 */
public class AdaptivePartitioning implements AdaptivePlanOptimizer {
    private static final Logger log = Logger.get(AdaptivePartitioning.class);

    /**
     * Decides whether the plan should be re-partitioned and, if so, rewrites it.
     *
     * <p>The plan is returned untouched, with an empty set of changed ids, when any of the following holds:
     * <ul>
     *   <li>runtime adaptive partitioning is disabled for the session;</li>
     *   <li>some fragment already has a partition count (or, when unset, the session max partition count)
     *       greater than or equal to the adaptive target count;</li>
     *   <li>no fragment that consumes hash-partitioned input and has non-accurate runtime output stats
     *       has an estimated size above {@code maxTaskSize * partitionCount}.</li>
     * </ul>
     * Otherwise the whole plan is rewritten once, using the first fragment that exceeds the threshold as the trigger.
     *
     * @param planNode root of the plan to inspect and possibly rewrite
     * @param context optimizer context supplying the session, id allocator and runtime info provider
     * @return the (possibly rewritten) plan together with the ids of the plan nodes that were changed or added
     */
    @Override // io.trino.sql.planner.optimizations.AdaptivePlanOptimizer
    public AdaptivePlanOptimizer.Result optimizeAndMarkPlanChanges(PlanNode planNode, PlanOptimizer.Context context) {
        if (!SystemSessionProperties.isFaultTolerantExecutionRuntimeAdaptivePartitioningEnabled(context.session())) {
            return new AdaptivePlanOptimizer.Result(planNode, ImmutableSet.of());
        }

        int maxPartitionCount = SystemSessionProperties.getFaultTolerantExecutionMaxPartitionCount(context.session());
        int adaptivePartitionCount = SystemSessionProperties.getFaultTolerantExecutionRuntimeAdaptivePartitioningPartitionCount(context.session());
        long maxTaskSizeInBytes = SystemSessionProperties.getFaultTolerantExecutionRuntimeAdaptivePartitioningMaxTaskSize(context.session()).toBytes();
        RuntimeInfoProvider runtimeInfoProvider = context.runtimeInfoProvider();
        List<PlanFragment> fragments = runtimeInfoProvider.getAllPlanFragments();

        // Nothing to gain if any fragment is already at or above the adaptive target partition count.
        boolean alreadyAtTarget = fragments.stream()
                .anyMatch(fragment -> fragment.getPartitionCount().orElse(maxPartitionCount) >= adaptivePartitionCount);
        if (alreadyAtTarget) {
            return new AdaptivePlanOptimizer.Result(planNode, ImmutableSet.of());
        }

        for (PlanFragment fragment : fragments) {
            // Only fragments fed by hash-partitioned input whose runtime output stats are not yet accurate are considered.
            if (!consumesHashPartitionedInput(fragment) || runtimeInfoProvider.getRuntimeOutputStats(fragment.getId()).isAccurate()) {
                continue;
            }

            int currentPartitionCount = fragment.getPartitionCount().orElse(maxPartitionCount);
            long estimatedMemoryConsumptionInBytes = estimateMemoryConsumptionInBytes(fragment, runtimeInfoProvider);
            if (estimatedMemoryConsumptionInBytes > maxTaskSizeInBytes * currentPartitionCount) {
                log.info(
                        "Stage %s has an estimated memory consumption of %s, changing partition count from %s to %s",
                        fragment.getId(),
                        DataSize.succinctBytes(estimatedMemoryConsumptionInBytes),
                        currentPartitionCount,
                        adaptivePartitionCount);
                Rewriter rewriter = new Rewriter(adaptivePartitionCount, context.idAllocator(), runtimeInfoProvider);
                PlanNode rewrittenPlan = SimplePlanRewriter.rewriteWith(rewriter, planNode);
                return new AdaptivePlanOptimizer.Result(rewrittenPlan, rewriter.getChangedPlanIds());
            }
        }
        return new AdaptivePlanOptimizer.Result(planNode, ImmutableSet.of());
    }

    /**
     * Estimates, from runtime output stats of the source fragments, how many bytes the given fragment
     * has to handle across its non-replicated remote sources.
     *
     * <p>With a single such source the estimate is that source's total size. With several, the smallest
     * one is subtracted from the sum (NOTE(review): presumably because the smallest input is assumed to be
     * streamed rather than retained - confirm against the upstream rationale). With none, the estimate is 0.
     */
    private static long estimateMemoryConsumptionInBytes(PlanFragment fragment, RuntimeInfoProvider runtimeInfoProvider) {
        List<Long> partitionedInputBytes = fragment.getRemoteSourceNodes().stream()
                // REPLICATE exchanges are excluded from the estimate.
                .filter(remoteSource -> remoteSource.getExchangeType() != ExchangeNode.Type.REPLICATE)
                .map(remoteSource -> remoteSource.getSourceFragmentIds().stream()
                        .mapToLong(sourceFragmentId -> runtimeInfoProvider.getRuntimeOutputStats(sourceFragmentId).outputDataSizeEstimate().getTotalSizeInBytes())
                        .sum())
                .collect(ImmutableList.toImmutableList());

        if (partitionedInputBytes.isEmpty()) {
            // Collections.min would throw on an empty list; no partitioned input means nothing to account for.
            return 0;
        }
        if (partitionedInputBytes.size() == 1) {
            return partitionedInputBytes.get(0);
        }
        long totalBytes = partitionedInputBytes.stream().mapToLong(Long::longValue).sum();
        return totalBytes - Collections.min(partitionedInputBytes);
    }

    /**
     * Returns whether the fragment's own partitioning is one of the hash distributions
     * recognised by {@link #isPartitioned(PartitioningHandle)}.
     */
    public static boolean consumesHashPartitionedInput(PlanFragment planFragment) {
        return isPartitioned(planFragment.getPartitioning());
    }

    /**
     * Returns whether the handle equals {@code FIXED_HASH_DISTRIBUTION} or {@code SCALED_WRITER_HASH_DISTRIBUTION}.
     */
    public static boolean isPartitioned(PartitioningHandle partitioningHandle) {
        return partitioningHandle.equals(SystemPartitioningHandle.FIXED_HASH_DISTRIBUTION)
                || partitioningHandle.equals(SystemPartitioningHandle.SCALED_WRITER_HASH_DISTRIBUTION);
    }

    /**
     * Plan rewriter that applies a new partition count to remote hash exchanges and wraps
     * repartitioning remote sources in a new remote repartition exchange. It records the id of
     * every exchange node it modifies or creates.
     */
    private static class Rewriter extends SimplePlanRewriter<Void> {
        private final int partitionCount;
        private final PlanNodeIdAllocator idAllocator;
        private final RuntimeInfoProvider runtimeInfoProvider;
        private final Set<PlanNodeId> changedPlanIds = new HashSet<>();

        private Rewriter(int partitionCount, PlanNodeIdAllocator idAllocator, RuntimeInfoProvider runtimeInfoProvider) {
            this.partitionCount = partitionCount;
            this.idAllocator = Objects.requireNonNull(idAllocator, "idAllocator is null");
            this.runtimeInfoProvider = Objects.requireNonNull(runtimeInfoProvider, "runtimeInfoProvider is null");
        }

        /**
         * Leaves broadcast exchanges untouched. For every other exchange the sources are rewritten, and
         * when the exchange is REMOTE with FIXED_HASH_DISTRIBUTION its partition count is replaced by
         * the target count and its id is recorded as changed.
         */
        @Override // io.trino.sql.planner.plan.PlanVisitor
        public PlanNode visitExchange(ExchangeNode node, SimplePlanRewriter.RewriteContext<Void> context) {
            PartitioningHandle handle = node.getPartitioningScheme().getPartitioning().getHandle();
            if (handle.equals(SystemPartitioningHandle.FIXED_BROADCAST_DISTRIBUTION)) {
                return node;
            }

            List<PlanNode> rewrittenSources = node.getSources().stream()
                    .map(context::rewrite)
                    .collect(ImmutableList.toImmutableList());

            PartitioningScheme partitioningScheme = node.getPartitioningScheme();
            // equals() rather than ==, consistent with the other handle comparisons in this class:
            // an equal but non-identical handle instance must be treated the same way.
            if (node.getScope() == ExchangeNode.Scope.REMOTE
                    && handle.equals(SystemPartitioningHandle.FIXED_HASH_DISTRIBUTION)) {
                partitioningScheme = partitioningScheme.withPartitionCount(Optional.of(partitionCount));
                changedPlanIds.add(node.getId());
            }

            return new ExchangeNode(
                    node.getId(),
                    node.getType(),
                    node.getScope(),
                    partitioningScheme,
                    rewrittenSources,
                    node.getInputs(),
                    node.getOrderingScheme());
        }

        /**
         * Returns non-REPARTITION remote sources unchanged. For a REPARTITION remote source, looks up the
         * output partitioning scheme of its source fragments and takes the first one that is hash
         * partitioned; if none is, the node is returned unchanged. Otherwise the remote source is wrapped
         * in a new REMOTE REPARTITION exchange using that scheme with the target partition count and
         * FIXED_HASH_DISTRIBUTION, and the new exchange's id is recorded as changed.
         */
        @Override // io.trino.sql.planner.plan.PlanVisitor
        public PlanNode visitRemoteSource(RemoteSourceNode node, SimplePlanRewriter.RewriteContext<Void> context) {
            if (node.getExchangeType() != ExchangeNode.Type.REPARTITION) {
                return node;
            }

            Optional<PartitioningScheme> sourcePartitioningScheme = node.getSourceFragmentIds().stream()
                    .map(runtimeInfoProvider::getPlanFragment)
                    .map(PlanFragment::getOutputPartitioningScheme)
                    .filter(scheme -> AdaptivePartitioning.isPartitioned(scheme.getPartitioning().getHandle()))
                    .findFirst();
            if (sourcePartitioningScheme.isEmpty()) {
                return node;
            }

            PartitioningScheme newPartitioningScheme = sourcePartitioningScheme.get()
                    .withPartitionCount(Optional.of(partitionCount))
                    .withPartitioningHandle(SystemPartitioningHandle.FIXED_HASH_DISTRIBUTION);

            PlanNodeId exchangeId = idAllocator.getNextId();
            changedPlanIds.add(exchangeId);
            return new ExchangeNode(
                    exchangeId,
                    ExchangeNode.Type.REPARTITION,
                    ExchangeNode.Scope.REMOTE,
                    newPartitioningScheme,
                    ImmutableList.of(node),
                    ImmutableList.of(node.getOutputSymbols()),
                    node.getOrderingScheme());
        }

        /** Returns an immutable snapshot of the ids of all exchange nodes modified or created so far. */
        public Set<PlanNodeId> getChangedPlanIds() {
            return ImmutableSet.copyOf(changedPlanIds);
        }
    }
}
