/*
 * Decompiled with CFR 0.152.
 */
package io.trino.sql.planner.iterative.rule.test;

import com.google.common.collect.ImmutableSet;
import com.google.common.collect.MoreCollectors;
import io.trino.Session;
import io.trino.cost.CachingCostProvider;
import io.trino.cost.CachingStatsProvider;
import io.trino.cost.CachingTableStatsProvider;
import io.trino.cost.CostCalculator;
import io.trino.cost.CostProvider;
import io.trino.cost.RuntimeInfoProvider;
import io.trino.cost.StatsAndCosts;
import io.trino.cost.StatsCalculator;
import io.trino.cost.StatsProvider;
import io.trino.cost.TableStatsProvider;
import io.trino.execution.warnings.WarningCollector;
import io.trino.matching.Capture;
import io.trino.matching.Match;
import io.trino.matching.Pattern;
import io.trino.metadata.FunctionManager;
import io.trino.metadata.Metadata;
import io.trino.sql.planner.Plan;
import io.trino.sql.planner.PlanNodeIdAllocator;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.SymbolAllocator;
import io.trino.sql.planner.assertions.PlanAssert;
import io.trino.sql.planner.assertions.PlanMatchPattern;
import io.trino.sql.planner.iterative.Lookup;
import io.trino.sql.planner.iterative.Memo;
import io.trino.sql.planner.iterative.Rule;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.planner.planprinter.PlanPrinter;
import io.trino.testing.PlanTester;
import java.util.Collection;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Stream;
import org.assertj.core.api.Fail;

/**
 * Test helper that applies a single optimizer {@link Rule} to a plan and asserts on the outcome.
 *
 * <p>Each assertion method ({@link #doesNotFire()} and {@link #matches(PlanMatchPattern)}) applies
 * the rule once. It then cleans up query metadata and aborts the session's transaction, whether or
 * not the assertion passed.
 */
public class RuleAssert {
    private final Rule<?> rule;
    private final PlanTester planTester;
    private final StatsCalculator statsCalculator;
    private final Session session;
    private final PlanNode plan;
    private final Set<Symbol> symbols;
    private final PlanNodeIdAllocator idAllocator;

    RuleAssert(Rule<?> rule, PlanTester planTester, StatsCalculator statsCalculator, Session session, PlanNodeIdAllocator idAllocator, PlanNode plan, Collection<Symbol> symbols) {
        this.rule = Objects.requireNonNull(rule, "rule is null");
        this.planTester = Objects.requireNonNull(planTester, "planTester is null");
        this.statsCalculator = Objects.requireNonNull(statsCalculator, "statsCalculator is null");
        this.session = Objects.requireNonNull(session, "session is null");
        // Called eagerly so that a session without a transaction is rejected at construction time
        // rather than during cleanup (presumably this throws when no transaction id is present).
        session.getRequiredTransactionId();
        this.idAllocator = Objects.requireNonNull(idAllocator, "idAllocator is null");
        this.plan = Objects.requireNonNull(plan, "plan is null");
        this.symbols = ImmutableSet.copyOf(symbols);
    }

    /**
     * Asserts that the rule does not produce a transformed plan for the plan under test.
     * On failure, the message contains both the original and the transformed plan.
     */
    public void doesNotFire() {
        try {
            RuleApplication ruleApplication = applyRule();
            if (ruleApplication.wasRuleApplied()) {
                Fail.fail(String.format(
                        "Expected %s to not fire for:\n\n%s\n\n==>\n\n%s\n",
                        rule,
                        printPlan(plan, StatsAndCosts.empty()),
                        printPlan(ruleApplication.getTransformedPlan(), StatsAndCosts.empty())));
            }
        }
        finally {
            releaseQueryResources();
        }
    }

    /**
     * Asserts that the rule fires, that it returns a node other than the original root, that the
     * transformed plan exposes the same set of output symbols, and that the transformed plan
     * matches {@code pattern}.
     *
     * @param pattern expected shape of the transformed plan
     */
    public void matches(PlanMatchPattern pattern) {
        try {
            RuleApplication ruleApplication = applyRule();
            if (!ruleApplication.wasRuleApplied()) {
                Fail.fail(String.format("%s did not fire for:\n%s", rule, formatPlan(plan)));
            }

            PlanNode actual = ruleApplication.getTransformedPlan();
            // Identity check is deliberate: a rule must not report success while returning the same node.
            if (actual == plan) {
                Fail.fail(String.format("%s: rule fired but return the original plan:\n%s\n", rule, formatPlan(plan)));
            }

            // Compared as sets: only symbol membership matters, not order.
            if (!ImmutableSet.copyOf(plan.getOutputSymbols()).equals(ImmutableSet.copyOf(actual.getOutputSymbols()))) {
                Fail.fail(String.format(
                        "%s: output schema of transformed and original plans are not equivalent\n\texpected: %s\n\tactual:   %s\n",
                        rule,
                        plan.getOutputSymbols(),
                        actual.getOutputSymbols()));
            }

            PlanAssert.assertPlan(
                    session,
                    planTester.getPlannerContext().getMetadata(),
                    planTester.getPlannerContext().getFunctionManager(),
                    ruleApplication.statsProvider(),
                    new Plan(actual, StatsAndCosts.empty()),
                    ruleApplication.lookup(),
                    pattern);
        }
        finally {
            releaseQueryResources();
        }
    }

    /**
     * Cleans up query metadata and aborts the session transaction. The abort runs even if the
     * metadata cleanup throws, so the transaction is never leaked.
     */
    private void releaseQueryResources() {
        try {
            planTester.getPlannerContext().getMetadata().cleanupQuery(session);
        }
        finally {
            planTester.getTransactionManager().asyncAbort(session.getRequiredTransactionId());
        }
    }

    /**
     * Wraps the plan under test in a fresh {@link Memo} and applies the rule to the memo's root node.
     */
    private RuleApplication applyRule() {
        SymbolAllocator symbolAllocator = new SymbolAllocator(symbols);
        Memo memo = new Memo(idAllocator, plan);
        // Group references in the plan are resolved through the memo.
        Lookup lookup = Lookup.from(planNode -> Stream.of(memo.resolve(planNode)));
        PlanNode memoRoot = memo.getNode(memo.getRootGroup());
        Rule.Context context = ruleContext(
                statsCalculator,
                planTester.getEstimatedExchangesCostCalculator(),
                symbolAllocator,
                memo,
                lookup,
                session);
        return applyRule(rule, memoRoot, context);
    }

    /**
     * Matches the rule's pattern against {@code planNode} and invokes the rule on a match.
     * Returns an empty result if the rule is disabled for the session or the pattern does not match.
     */
    private static <T> RuleApplication applyRule(Rule<T> rule, PlanNode planNode, Rule.Context context) {
        Capture<T> planNodeCapture = Capture.newCapture();
        Pattern<T> pattern = rule.getPattern().capturedAs(planNodeCapture);
        Optional<Match> match = pattern.match(planNode, context.getLookup())
                .collect(MoreCollectors.toOptional());

        Rule.Result result;
        if (!rule.isEnabled(context.getSession()) || match.isEmpty()) {
            result = Rule.Result.empty();
        }
        else {
            result = rule.apply(match.get().capture(planNodeCapture), match.get().captures(), context);
        }
        return new RuleApplication(context.getLookup(), context.getStatsProvider(), result);
    }

    /**
     * Renders {@code plan} as text, annotated with stats and costs computed by this assert's
     * stats calculator and the plan tester's cost calculator.
     */
    private String formatPlan(PlanNode plan) {
        StatsProvider statsProvider = new CachingStatsProvider(
                statsCalculator,
                session,
                new CachingTableStatsProvider(planTester.getPlannerContext().getMetadata(), session));
        CostProvider costProvider = new CachingCostProvider(planTester.getCostCalculator(), statsProvider, session);
        return printPlan(plan, StatsAndCosts.create(plan, statsProvider, costProvider));
    }

    /**
     * Renders {@code planNode} as a logical-plan string with the given stats and costs.
     * The trailing {@code 2, false} arguments are passed through unchanged from the original call sites.
     */
    private String printPlan(PlanNode planNode, StatsAndCosts statsAndCosts) {
        return PlanPrinter.textLogicalPlan(
                planNode,
                planTester.getPlannerContext().getMetadata(),
                planTester.getPlannerContext().getFunctionManager(),
                statsAndCosts,
                session,
                2,
                false);
    }

    /**
     * Builds the {@link Rule.Context} handed to the rule. Stats and costs are memo-aware, and the
     * context uses no-op warning collection and no timeout checking.
     */
    private Rule.Context ruleContext(StatsCalculator statsCalculator, CostCalculator costCalculator, SymbolAllocator symbolAllocator, Memo memo, Lookup lookup, Session session) {
        StatsProvider statsProvider = new CachingStatsProvider(
                statsCalculator,
                Optional.of(memo),
                lookup,
                session,
                new CachingTableStatsProvider(planTester.getPlannerContext().getMetadata(), session),
                RuntimeInfoProvider.noImplementation());
        CostProvider costProvider = new CachingCostProvider(costCalculator, statsProvider, Optional.of(memo), session);
        PlanNodeIdAllocator contextIdAllocator = idAllocator;

        // Captures effectively-final locals directly instead of the decompiler's synthetic
        // val$/this$0 fields, which did not compile and would have left getIdAllocator() null.
        return new Rule.Context() {
            @Override
            public Lookup getLookup() {
                return lookup;
            }

            @Override
            public PlanNodeIdAllocator getIdAllocator() {
                return contextIdAllocator;
            }

            @Override
            public SymbolAllocator getSymbolAllocator() {
                return symbolAllocator;
            }

            @Override
            public Session getSession() {
                return session;
            }

            @Override
            public StatsProvider getStatsProvider() {
                return statsProvider;
            }

            @Override
            public CostProvider getCostProvider() {
                return costProvider;
            }

            @Override
            public void checkTimeoutNotExhausted() {
                // Intentionally a no-op: rule tests do not enforce an optimizer timeout.
            }

            @Override
            public WarningCollector getWarningCollector() {
                return WarningCollector.NOOP;
            }
        };
    }

    /**
     * Outcome of applying a rule once: the lookup and stats provider from the rule context,
     * plus the rule's result.
     */
    private record RuleApplication(Lookup lookup, StatsProvider statsProvider, Rule.Result result) {
        private RuleApplication {
            Objects.requireNonNull(lookup, "lookup is null");
            Objects.requireNonNull(statsProvider, "statsProvider is null");
            Objects.requireNonNull(result, "result is null");
        }

        /** Returns true when the rule produced a non-empty result. */
        private boolean wasRuleApplied() {
            return !result.isEmpty();
        }

        /**
         * Returns the plan produced by the rule.
         *
         * @throws IllegalStateException if the rule did not produce a transformed plan
         */
        public PlanNode getTransformedPlan() {
            return result.getTransformedPlan()
                    .orElseThrow(() -> new IllegalStateException("Rule did not produce transformed plan"));
        }
    }
}

