package io.prestosql.cost;

import com.google.common.base.Verify;
import io.airlift.log.Logger;
import io.prestosql.Session;
import io.prestosql.SystemSessionProperties;
import io.prestosql.sql.planner.TypeProvider;
import io.prestosql.sql.planner.iterative.GroupReference;
import io.prestosql.sql.planner.iterative.Memo;
import io.prestosql.sql.planner.plan.PlanNode;
import java.util.IdentityHashMap;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;

/* loaded from: input_file:io/prestosql/cost/CachingCostProvider.class */
/**
 * A {@link CostProvider} that memoizes the estimates produced by a {@link CostCalculator}.
 *
 * <p>Estimates for ordinary plan nodes are cached in an {@link IdentityHashMap}. The map is keyed
 * by reference identity rather than {@code equals()}, so distinct node instances are costed and
 * cached separately even if they compare equal. Estimates for {@link GroupReference} nodes are
 * read from and stored into the {@link Memo} instead. Such nodes can only be handled when a memo
 * was supplied at construction time.
 *
 * <p>This class performs no synchronization of its own; the cache is a plain
 * {@link IdentityHashMap}.
 */
public class CachingCostProvider implements CostProvider {
    private static final Logger log = Logger.get(CachingCostProvider.class);

    private final CostCalculator costCalculator;
    private final StatsProvider statsProvider;
    private final Optional<Memo> memo;
    private final Session session;
    private final TypeProvider types;
    // Identity-keyed on purpose: entries belong to specific node instances, not equal-looking ones.
    private final Map<PlanNode, PlanCostEstimate> cache = new IdentityHashMap<>();

    /**
     * Creates a provider without a memo.
     *
     * <p>A provider built this way cannot cost {@link GroupReference} nodes. For such nodes the
     * lookup fails with an {@link IllegalStateException}. That exception goes through the same
     * failure handling as any other error in {@link #getCost(PlanNode)}.
     *
     * @param costCalculator calculator used to compute estimates that are not yet cached
     * @param statsProvider statistics source handed to the calculator
     * @param session session whose properties control whether costs are computed at all
     * @param typeProvider type information handed to the calculator
     * @throws NullPointerException if any argument is null
     */
    public CachingCostProvider(CostCalculator costCalculator, StatsProvider statsProvider, Session session, TypeProvider typeProvider) {
        this(costCalculator, statsProvider, Optional.empty(), session, typeProvider);
    }

    /**
     * Creates a provider that optionally caches group costs in a {@link Memo}.
     *
     * @param costCalculator calculator used to compute estimates that are not yet cached
     * @param statsProvider statistics source handed to the calculator
     * @param memo memo used to resolve {@link GroupReference} nodes and store their costs; may be
     *     empty, in which case group references cannot be costed
     * @param session session whose properties control whether costs are computed at all
     * @param typeProvider type information handed to the calculator
     * @throws NullPointerException if any argument (including the {@code Optional} itself) is null
     */
    public CachingCostProvider(CostCalculator costCalculator, StatsProvider statsProvider, Optional<Memo> memo, Session session, TypeProvider typeProvider) {
        this.costCalculator = Objects.requireNonNull(costCalculator, "costCalculator is null");
        this.statsProvider = Objects.requireNonNull(statsProvider, "statsProvider is null");
        this.memo = Objects.requireNonNull(memo, "memo is null");
        this.session = Objects.requireNonNull(session, "session is null");
        this.types = Objects.requireNonNull(typeProvider, "types is null");
    }

    /**
     * Returns the cost estimate for {@code planNode}, computing and caching it on first request.
     *
     * <p>If the stats calculator is disabled for the session, this returns
     * {@link PlanCostEstimate#unknown()} immediately, without inspecting the node. If computing
     * the estimate throws a {@link RuntimeException}, the result depends on the session's
     * "ignore stats calculator failures" setting. When the setting is on, the error is logged and
     * {@link PlanCostEstimate#unknown()} is returned; nothing is cached in that case. When it is
     * off, the exception is rethrown.
     *
     * @param planNode node to cost
     * @return the cached or freshly computed estimate, or unknown as described above
     * @throws NullPointerException if {@code planNode} is null and the stats calculator is enabled
     */
    @Override // io.prestosql.cost.CostProvider
    public PlanCostEstimate getCost(PlanNode planNode) {
        if (!SystemSessionProperties.isEnableStatsCalculator(session)) {
            return PlanCostEstimate.unknown();
        }
        // Checked outside the try so a null node is never converted into an "unknown" estimate.
        Objects.requireNonNull(planNode, "node is null");

        try {
            if (planNode instanceof GroupReference) {
                return getGroupCost((GroupReference) planNode);
            }

            PlanCostEstimate cached = cache.get(planNode);
            if (cached != null) {
                return cached;
            }

            PlanCostEstimate cost = calculateCost(planNode);
            // calculateCost hands this provider to the calculator, which may call back into it.
            // Sanity check that this node was not cached in the meantime.
            Verify.verify(cache.put(planNode, cost) == null, "Cost already set");
            return cost;
        } catch (RuntimeException e) {
            if (!SystemSessionProperties.isIgnoreStatsCalculatorFailures(session)) {
                throw e;
            }
            log.error(e, "Error occurred when computing cost for query %s", session.getQueryId());
            return PlanCostEstimate.unknown();
        }
    }

    /**
     * Returns the cost of the group referenced by {@code groupReference}, using the memo as the
     * cache. If no cost is stored yet, the group's node is costed and the result is stored.
     *
     * @throws IllegalStateException if this provider was created without a memo
     */
    private PlanCostEstimate getGroupCost(GroupReference groupReference) {
        int groupId = groupReference.getGroupId();
        Memo groupMemo = memo.orElseThrow(
                () -> new IllegalStateException("CachingCostProvider without memo cannot handle GroupReferences"));

        Optional<PlanCostEstimate> knownCost = groupMemo.getCost(groupId);
        if (knownCost.isPresent()) {
            return knownCost.get();
        }

        PlanCostEstimate cost = calculateCost(groupMemo.getNode(groupId));
        // Same re-entrancy sanity check as in getCost: nothing may have stored this group's cost
        // while it was being computed.
        Verify.verify(groupMemo.getCost(groupId).isEmpty(), "Group cost already set");
        groupMemo.storeCost(groupId, cost);
        return cost;
    }

    /**
     * Delegates to the cost calculator and bypasses both caches. The calculator receives
     * {@code this}, presumably so that it can look up the costs of child nodes through (and have
     * them cached by) this provider. Confirm this against the CostCalculator implementation.
     */
    private PlanCostEstimate calculateCost(PlanNode planNode) {
        return costCalculator.calculateCost(planNode, statsProvider, this, session, types);
    }
}
