/*
 * Decompiled with CFR 0.152.
 */
package io.prestosql.cost;

import io.prestosql.Session;
import io.prestosql.cost.CostCalculator;
import io.prestosql.cost.CostProvider;
import io.prestosql.cost.LocalCostEstimate;
import io.prestosql.cost.PlanCostEstimate;
import io.prestosql.cost.PlanNodeStatsEstimate;
import io.prestosql.cost.StatsProvider;
import io.prestosql.cost.TaskCountEstimator;
import io.prestosql.sql.planner.TypeProvider;
import io.prestosql.sql.planner.iterative.GroupReference;
import io.prestosql.sql.planner.plan.AggregationNode;
import io.prestosql.sql.planner.plan.JoinNode;
import io.prestosql.sql.planner.plan.PlanNode;
import io.prestosql.sql.planner.plan.PlanVisitor;
import io.prestosql.sql.planner.plan.SemiJoinNode;
import io.prestosql.sql.planner.plan.SpatialJoinNode;
import io.prestosql.sql.planner.plan.UnionNode;
import java.util.Objects;
import java.util.Optional;
import javax.annotation.concurrent.ThreadSafe;
import javax.inject.Inject;

/**
 * A {@link CostCalculator} that wraps another calculator and adds, on top of
 * the wrapped calculator's estimate, the cost of exchanges estimated for the
 * node being costed.
 *
 * <p>The exchange cost is produced by a {@link PlanVisitor}: aggregation,
 * join, semi join, spatial join and union nodes get a non-trivial estimate,
 * group references are rejected, and every other node type contributes zero.
 */
@ThreadSafe
public class CostCalculatorWithEstimatedExchanges
implements CostCalculator {
    private final CostCalculator costCalculator;
    private final TaskCountEstimator taskCountEstimator;

    /**
     * @param costCalculator calculator whose estimate is used as the base; must not be null
     * @param taskCountEstimator supplies the source-distributed task count used for join exchanges; must not be null
     * @throws NullPointerException if either argument is null
     */
    @Inject
    public CostCalculatorWithEstimatedExchanges(CostCalculator costCalculator, TaskCountEstimator taskCountEstimator) {
        this.costCalculator = Objects.requireNonNull(costCalculator, "costCalculator is null");
        this.taskCountEstimator = Objects.requireNonNull(taskCountEstimator, "taskCountEstimator is null");
    }

    /**
     * Returns the wrapped calculator's estimate for {@code node} with the
     * estimated exchange cost of {@code node} added to it.
     */
    @Override
    public PlanCostEstimate calculateCost(PlanNode node, StatsProvider stats, CostProvider sourcesCosts, Session session, TypeProvider types) {
        // The visitor is created before delegating so that its null checks run first.
        ExchangeCostEstimator exchangeVisitor = new ExchangeCostEstimator(stats, types, taskCountEstimator, session);
        PlanCostEstimate delegateEstimate = costCalculator.calculateCost(node, stats, sourcesCosts, session, types);
        LocalCostEstimate exchangeCost = node.accept(exchangeVisitor, null);
        return addExchangeCost(delegateEstimate, exchangeCost);
    }

    /**
     * Builds a new plan estimate by adding the exchange cost component-wise to
     * {@code costEstimate}. The exchange's max memory is added to both the max
     * memory and the max-memory-when-outputting figures.
     */
    private static PlanCostEstimate addExchangeCost(PlanCostEstimate costEstimate, LocalCostEstimate estimatedExchangeCost) {
        return new PlanCostEstimate(
                costEstimate.getCpuCost() + estimatedExchangeCost.getCpuCost(),
                costEstimate.getMaxMemory() + estimatedExchangeCost.getMaxMemory(),
                // the exchange's max memory is also counted while the node is outputting
                costEstimate.getMaxMemoryWhenOutputting() + estimatedExchangeCost.getMaxMemory(),
                costEstimate.getNetworkCost() + estimatedExchangeCost.getNetworkCost(),
                addPair(costEstimate.getRootNodeLocalCostEstimate(), estimatedExchangeCost));
    }

    /**
     * Adds exactly two local estimates via
     * {@code LocalCostEstimate.addPartialComponents}, passing an empty array of
     * further components.
     */
    private static LocalCostEstimate addPair(LocalCostEstimate first, LocalCostEstimate second) {
        return LocalCostEstimate.addPartialComponents(first, second, new LocalCostEstimate[0]);
    }

    /**
     * Cost of a remote gather: a network-only estimate equal to the input size.
     */
    public static LocalCostEstimate calculateRemoteGatherCost(double inputSizeInBytes) {
        return LocalCostEstimate.ofNetwork(inputSizeInBytes);
    }

    /**
     * Cost of a remote repartition: the input size is passed as the first and
     * third component and zero as the second.
     * NOTE(review): component order is presumably (cpu, memory, network),
     * matching the getters used elsewhere in this class - confirm against
     * {@code LocalCostEstimate.of}.
     */
    public static LocalCostEstimate calculateRemoteRepartitionCost(double inputSizeInBytes) {
        double firstComponent = inputSizeInBytes;
        double secondComponent = 0.0;
        double thirdComponent = inputSizeInBytes;
        return LocalCostEstimate.of(firstComponent, secondComponent, thirdComponent);
    }

    /**
     * Cost of a local repartition: a CPU-only estimate equal to the input size.
     */
    public static LocalCostEstimate calculateLocalRepartitionCost(double inputSizeInBytes) {
        return LocalCostEstimate.ofCpu(inputSizeInBytes);
    }

    /**
     * Cost of replicating the input to {@code destinationTaskCount} tasks: a
     * network-only estimate of the input size multiplied by the task count.
     */
    public static LocalCostEstimate calculateRemoteReplicateCost(double inputSizeInBytes, int destinationTaskCount) {
        double replicatedBytes = inputSizeInBytes * destinationTaskCount;
        return LocalCostEstimate.ofNetwork(replicatedBytes);
    }

    /**
     * Sum of the join's exchange cost and its input cost, as computed by
     * {@code calculateJoinExchangeCost} and {@link #calculateJoinInputCost}.
     */
    public static LocalCostEstimate calculateJoinCostWithoutOutput(PlanNode probe, PlanNode build, StatsProvider stats, TypeProvider types, boolean replicated, int estimatedSourceDistributedTaskCount) {
        LocalCostEstimate exchangePart = calculateJoinExchangeCost(probe, build, stats, types, replicated, estimatedSourceDistributedTaskCount);
        LocalCostEstimate inputPart = calculateJoinInputCost(probe, build, stats, types, replicated, estimatedSourceDistributedTaskCount);
        return addPair(exchangePart, inputPart);
    }

    /**
     * Estimates the exchanges needed to feed a join.
     *
     * <ul>
     * <li>Not replicated: remote repartition of the probe side, remote
     * repartition of the build side and local repartition of the build side.</li>
     * <li>Replicated: remote replication of the build side to
     * {@code estimatedSourceDistributedTaskCount} tasks plus a local
     * repartition of the build side; the probe side contributes nothing.</li>
     * </ul>
     */
    private static LocalCostEstimate calculateJoinExchangeCost(PlanNode probe, PlanNode build, StatsProvider stats, TypeProvider types, boolean replicated, int estimatedSourceDistributedTaskCount) {
        // Both sizes are looked up up front, probe first, regardless of the branch taken.
        double probeBytes = stats.getStats(probe).getOutputSizeInBytes(probe.getOutputSymbols(), types);
        double buildBytes = stats.getStats(build).getOutputSizeInBytes(build.getOutputSymbols(), types);
        if (!replicated) {
            LocalCostEstimate probeRemote = calculateRemoteRepartitionCost(probeBytes);
            LocalCostEstimate buildRemote = calculateRemoteRepartitionCost(buildBytes);
            LocalCostEstimate buildLocal = calculateLocalRepartitionCost(buildBytes);
            return LocalCostEstimate.addPartialComponents(probeRemote, buildRemote, buildLocal);
        }
        LocalCostEstimate buildReplication = calculateRemoteReplicateCost(buildBytes, estimatedSourceDistributedTaskCount);
        LocalCostEstimate buildLocal = calculateLocalRepartitionCost(buildBytes);
        return addPair(buildReplication, buildLocal);
    }

    /**
     * Estimates the cost of consuming the join inputs.
     *
     * <p>With {@code m} equal to {@code estimatedSourceDistributedTaskCount}
     * for a replicated join and {@code 1} otherwise:
     * <ul>
     * <li>CPU is {@code probeSize + buildSize * m}, plus
     * {@code buildSize * (m - 1)} when replicated;</li>
     * <li>memory is {@code buildSize * m};</li>
     * <li>the third component is zero.</li>
     * </ul>
     * NOTE(review): the extra replicated CPU term presumably covers work done
     * on the additional build-side copies - confirm.
     */
    public static LocalCostEstimate calculateJoinInputCost(PlanNode probe, PlanNode build, StatsProvider stats, TypeProvider types, boolean replicated, int estimatedSourceDistributedTaskCount) {
        int buildCopies;
        if (replicated) {
            buildCopies = estimatedSourceDistributedTaskCount;
        } else {
            buildCopies = 1;
        }

        PlanNodeStatsEstimate probeEstimate = stats.getStats(probe);
        PlanNodeStatsEstimate buildEstimate = stats.getStats(build);
        double buildBytes = buildEstimate.getOutputSizeInBytes(build.getOutputSymbols(), types);
        double probeBytes = probeEstimate.getOutputSizeInBytes(probe.getOutputSymbols(), types);

        double allBuildCopiesBytes = buildBytes * buildCopies;
        double baseCpu = probeBytes + allBuildCopiesBytes;
        double cpu = replicated ? baseCpu + buildBytes * (buildCopies - 1) : baseCpu;

        return LocalCostEstimate.of(cpu, allBuildCopiesBytes, 0.0);
    }

    /**
     * Visitor that returns the estimated exchange cost for a single plan node.
     */
    private static class ExchangeCostEstimator
    extends PlanVisitor<LocalCostEstimate, Void> {
        private final StatsProvider stats;
        private final TypeProvider types;
        private final TaskCountEstimator taskCountEstimator;
        private final Session session;

        /**
         * @throws NullPointerException if any argument is null
         */
        ExchangeCostEstimator(StatsProvider stats, TypeProvider types, TaskCountEstimator taskCountEstimator, Session session) {
            this.stats = Objects.requireNonNull(stats, "stats is null");
            this.types = Objects.requireNonNull(types, "types is null");
            this.taskCountEstimator = Objects.requireNonNull(taskCountEstimator, "taskCountEstimator is null");
            this.session = Objects.requireNonNull(session, "session is null");
        }

        /** Fallback for node types without a dedicated visit method: no exchange cost. */
        @Override
        protected LocalCostEstimate visitPlan(PlanNode node, Void context) {
            return LocalCostEstimate.zero();
        }

        /** Group references are not supported by this visitor. */
        @Override
        public LocalCostEstimate visitGroupReference(GroupReference node, Void context) {
            throw new UnsupportedOperationException();
        }

        /** Remote repartition plus local repartition of the aggregation's source output. */
        @Override
        public LocalCostEstimate visitAggregation(AggregationNode node, Void context) {
            double sourceBytes = outputSizeInBytes(node.getSource());
            LocalCostEstimate remotePart = calculateRemoteRepartitionCost(sourceBytes);
            LocalCostEstimate localPart = calculateLocalRepartitionCost(sourceBytes);
            return addPair(remotePart, localPart);
        }

        /** Join exchange cost with left as probe and right as build. */
        @Override
        public LocalCostEstimate visitJoin(JoinNode node, Void context) {
            PlanNode probe = node.getLeft();
            PlanNode build = node.getRight();
            boolean replicated = Objects.equals(node.getDistributionType(), Optional.of(JoinNode.DistributionType.REPLICATED));
            return joinExchangeCost(probe, build, replicated);
        }

        /** Join exchange cost with the source as probe and the filtering source as build. */
        @Override
        public LocalCostEstimate visitSemiJoin(SemiJoinNode node, Void context) {
            PlanNode probe = node.getSource();
            PlanNode build = node.getFilteringSource();
            boolean replicated = Objects.equals(node.getDistributionType(), Optional.of(SemiJoinNode.DistributionType.REPLICATED));
            return joinExchangeCost(probe, build, replicated);
        }

        /** Join exchange cost with left as probe and right as build. */
        @Override
        public LocalCostEstimate visitSpatialJoin(SpatialJoinNode node, Void context) {
            PlanNode probe = node.getLeft();
            PlanNode build = node.getRight();
            boolean replicated = node.getDistributionType() == SpatialJoinNode.DistributionType.REPLICATED;
            return joinExchangeCost(probe, build, replicated);
        }

        /** Remote gather cost sized by the union node's own output. */
        @Override
        public LocalCostEstimate visitUnion(UnionNode node, Void context) {
            return calculateRemoteGatherCost(outputSizeInBytes(node));
        }

        /**
         * Shared tail of the join visitors: asks the task count estimator for
         * the source-distributed task count of this session and delegates to
         * the enclosing class's join exchange estimate.
         */
        private LocalCostEstimate joinExchangeCost(PlanNode probe, PlanNode build, boolean replicated) {
            int sourceTaskCount = taskCountEstimator.estimateSourceDistributedTaskCount(session);
            return calculateJoinExchangeCost(probe, build, stats, types, replicated, sourceTaskCount);
        }

        /** Output size in bytes of {@code node}, over all of its output symbols. */
        private double outputSizeInBytes(PlanNode node) {
            PlanNodeStatsEstimate estimate = stats.getStats(node);
            return estimate.getOutputSizeInBytes(node.getOutputSymbols(), types);
        }
    }
}

