package org.apache.iotdb.db.queryengine.plan.relational.planner.optimizations;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.apache.iotdb.db.queryengine.plan.planner.plan.node.PlanNode;
import org.apache.iotdb.db.queryengine.plan.planner.plan.node.PlanNodeId;
import org.apache.iotdb.db.queryengine.plan.relational.planner.OrderingScheme;
import org.apache.iotdb.db.queryengine.plan.relational.planner.Symbol;
import org.apache.iotdb.db.queryengine.plan.relational.planner.SymbolAllocator;
import org.apache.iotdb.db.queryengine.plan.relational.planner.ir.ExpressionRewriter;
import org.apache.iotdb.db.queryengine.plan.relational.planner.ir.ExpressionTreeRewriter;
import org.apache.iotdb.db.queryengine.plan.relational.planner.node.AggregationNode;
import org.apache.iotdb.db.queryengine.plan.relational.planner.node.ApplyNode;
import org.apache.iotdb.db.queryengine.plan.relational.planner.node.LimitNode;
import org.apache.iotdb.db.queryengine.plan.relational.planner.node.TopKNode;
import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Expression;
import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.SymbolReference;

/* loaded from: input_file:org/apache/iotdb/db/queryengine/plan/relational/planner/optimizations/SymbolMapper.class */
/**
 * Rewrites planner structures by substituting every {@link Symbol} they reference through a
 * symbol-to-symbol mapping function.
 *
 * <p>Instances are created from an arbitrary function, from a fixed map via {@link
 * #symbolMapper(Map)} or {@link #builder()}, or from a mutable map that allocates fresh symbols on
 * demand via {@link #symbolReallocator(Map, SymbolAllocator)}.
 */
public class SymbolMapper {
    private final Function<Symbol, Symbol> mappingFunction;

    /** Collects symbol substitutions and turns them into a {@link SymbolMapper}. */
    public static class Builder {
        private final ImmutableMap.Builder<Symbol, Symbol> mappings = ImmutableMap.builder();

        /**
         * Records that {@code source} should be mapped to {@code target}.
         *
         * <p>Registering the same {@code source} more than once makes {@link #build()} fail, because
         * the mappings are collected with {@link ImmutableMap.Builder#buildOrThrow()}.
         */
        public void put(Symbol source, Symbol target) {
            mappings.put(source, target);
        }

        /** Builds a mapper over the collected substitutions, as {@link SymbolMapper#symbolMapper(Map)} does. */
        public SymbolMapper build() {
            return SymbolMapper.symbolMapper(mappings.buildOrThrow());
        }
    }

    /**
     * Creates a mapper that applies {@code mappingFunction} to every symbol it encounters.
     *
     * @param mappingFunction translation applied to each symbol
     * @throws NullPointerException if {@code mappingFunction} is null
     */
    public SymbolMapper(Function<Symbol, Symbol> mappingFunction) {
        this.mappingFunction = Objects.requireNonNull(mappingFunction, "mappingFunction is null");
    }

    /**
     * Creates a mapper backed by {@code mapping}.
     *
     * <p>Symbols are resolved transitively. The mapping is followed until a symbol is reached that
     * is either not a key or maps to itself. Symbols that are not keys are returned unchanged. The
     * returned mapper only reads {@code mapping}.
     *
     * <p>A lookup whose chain contains a cycle fails with {@link IllegalStateException} instead of
     * looping forever.
     *
     * @param mapping symbol substitutions
     * @return a mapper that resolves symbols through {@code mapping}
     */
    public static SymbolMapper symbolMapper(Map<Symbol, Symbol> mapping) {
        return new SymbolMapper(symbol -> resolve(mapping, symbol));
    }

    /**
     * Creates a mapper that reallocates symbols on first sight.
     *
     * <p>A symbol that is not yet a key of {@code mapping} receives a fresh symbol from {@code
     * symbolAllocator}. Both {@code symbol -> fresh} and {@code fresh -> fresh} are recorded in
     * {@code mapping}. A symbol that is already a key is resolved transitively, as in {@link
     * #symbolMapper(Map)}, and the result is recorded as mapping to itself.
     *
     * <p>{@code mapping} is therefore mutated by the returned mapper and must support {@link
     * Map#put}.
     *
     * @param mapping mutable symbol substitutions, shared with the caller
     * @param symbolAllocator source of fresh symbols for previously unseen symbols
     * @return a mapper that reallocates unseen symbols and resolves known ones
     */
    public static SymbolMapper symbolReallocator(Map<Symbol, Symbol> mapping, SymbolAllocator symbolAllocator) {
        return new SymbolMapper(symbol -> {
            if (!mapping.containsKey(symbol)) {
                Symbol newSymbol = symbolAllocator.newSymbol(symbol);
                mapping.put(symbol, newSymbol);
                // Pin the fresh symbol so that looking it up later does not allocate yet another one.
                mapping.put(newSymbol, newSymbol);
                return newSymbol;
            }
            Symbol canonical = resolve(mapping, symbol);
            // Pin the resolved symbol so it is not remapped by later lookups.
            mapping.put(canonical, canonical);
            return canonical;
        });
    }

    /**
     * Follows {@code mapping} from {@code symbol} until reaching a symbol that is not mapped, or
     * that maps to itself, and returns it.
     *
     * @throws IllegalStateException if the chain starting at {@code symbol} contains a cycle
     */
    private static Symbol resolve(Map<Symbol, Symbol> mapping, Symbol symbol) {
        Symbol current = symbol;
        // An acyclic chain visits each key at most once, so needing more than size() hops means a cycle.
        int remainingHops = mapping.size();
        while (true) {
            Symbol next = mapping.get(current);
            if (next == null || next.equals(current)) {
                return current;
            }
            if (remainingHops-- == 0) {
                throw new IllegalStateException("Cyclic symbol mapping detected starting from " + symbol);
            }
            current = next;
        }
    }

    /** Returns the symbol that {@code symbol} maps to. */
    public Symbol map(Symbol symbol) {
        return mappingFunction.apply(symbol);
    }

    /**
     * Maps the symbols referenced by a subquery set expression.
     *
     * <p>{@link ApplyNode.Exists} is returned as is. {@link ApplyNode.In} and {@link
     * ApplyNode.QuantifiedComparison} are rebuilt with their value and reference mapped. The
     * operator and quantifier of a quantified comparison are kept.
     *
     * @throws IllegalArgumentException for any other kind of set expression, including null
     */
    public ApplyNode.SetExpression map(ApplyNode.SetExpression setExpression) {
        if (setExpression instanceof ApplyNode.Exists) {
            return setExpression;
        }
        if (setExpression instanceof ApplyNode.In) {
            ApplyNode.In in = (ApplyNode.In) setExpression;
            return new ApplyNode.In(map(in.getValue()), map(in.getReference()));
        }
        if (setExpression instanceof ApplyNode.QuantifiedComparison) {
            ApplyNode.QuantifiedComparison comparison = (ApplyNode.QuantifiedComparison) setExpression;
            return new ApplyNode.QuantifiedComparison(
                    comparison.getOperator(),
                    comparison.getQuantifier(),
                    map(comparison.getValue()),
                    map(comparison.getReference()));
        }
        throw new IllegalArgumentException("Unexpected value: " + setExpression);
    }

    /** Maps each symbol in order, keeping duplicates; the result is an immutable list. */
    public List<Symbol> map(List<Symbol> symbols) {
        return symbols.stream().map(this::map).collect(ImmutableList.toImmutableList());
    }

    /**
     * Maps each symbol in order and drops symbols that map to an already-seen target, keeping the
     * first occurrence. The result is an immutable list.
     */
    public List<Symbol> mapAndDistinct(List<Symbol> symbols) {
        return symbols.stream().map(this::map).distinct().collect(ImmutableList.toImmutableList());
    }

    /** Rewrites {@code expression}, replacing every {@link SymbolReference} with a reference to its mapped symbol. */
    public Expression map(Expression expression) {
        return ExpressionTreeRewriter.rewriteWith(new ExpressionRewriter<Void>() {
            @Override
            public Expression rewriteSymbolReference(SymbolReference symbolReference, Void context, ExpressionTreeRewriter<Void> treeRewriter) {
                return map(Symbol.from(symbolReference)).toSymbolReference();
            }
        }, expression);
    }

    /** Maps {@code aggregationNode} onto {@code planNode}, keeping the node's own plan node id. */
    public AggregationNode map(AggregationNode aggregationNode, PlanNode planNode) {
        return map(aggregationNode, planNode, aggregationNode.getPlanNodeId());
    }

    /**
     * Builds a new {@link AggregationNode} with id {@code planNodeId} over {@code planNode}, with all
     * symbols mapped.
     *
     * <p>Aggregation outputs and aggregations are mapped. Grouping keys are mapped and
     * de-duplicated, while the grouping set count and global grouping sets are kept. The step is
     * kept, and the hash and group-id symbols are mapped when present. The argument passed as
     * {@code ImmutableList.of()} is always reset to empty.
     *
     * <p>NOTE(review): that empty-list argument is presumably the pre-grouped symbols; confirm.
     *
     * <p>Fails via {@link ImmutableMap.Builder#buildOrThrow()} if two aggregation outputs map to the
     * same symbol.
     */
    public AggregationNode map(AggregationNode aggregationNode, PlanNode planNode, PlanNodeId planNodeId) {
        ImmutableMap.Builder<Symbol, AggregationNode.Aggregation> aggregations = ImmutableMap.builder();
        for (Map.Entry<Symbol, AggregationNode.Aggregation> entry : aggregationNode.getAggregations().entrySet()) {
            aggregations.put(map(entry.getKey()), map(entry.getValue()));
        }
        return new AggregationNode(
                planNodeId,
                planNode,
                aggregations.buildOrThrow(),
                AggregationNode.groupingSets(
                        mapAndDistinct(aggregationNode.getGroupingKeys()),
                        aggregationNode.getGroupingSetCount(),
                        aggregationNode.getGlobalGroupingSets()),
                ImmutableList.of(),
                aggregationNode.getStep(),
                aggregationNode.getHashSymbol().map(this::map),
                aggregationNode.getGroupIdSymbol().map(this::map));
    }

    /**
     * Maps a single aggregation.
     *
     * <p>The resolved function and the distinct flag are kept. The argument expressions are
     * mapped, as are the optional filter, ordering scheme and mask.
     */
    public AggregationNode.Aggregation map(AggregationNode.Aggregation aggregation) {
        return new AggregationNode.Aggregation(
                aggregation.getResolvedFunction(),
                aggregation.getArguments().stream().map(this::map).collect(ImmutableList.toImmutableList()),
                aggregation.isDistinct(),
                aggregation.getFilter().map(this::map),
                aggregation.getOrderingScheme().map(this::map),
                aggregation.getMask().map(this::map));
    }

    /**
     * Rebuilds {@code limitNode} over {@code planNode}. The node id and count are kept; the
     * optional ties-resolving scheme is mapped.
     */
    public LimitNode map(LimitNode limitNode, PlanNode planNode) {
        return new LimitNode(
                limitNode.getPlanNodeId(),
                planNode,
                limitNode.getCount(),
                limitNode.getTiesResolvingScheme().map(this::map));
    }

    /**
     * Maps every symbol of {@code orderingScheme}.
     *
     * <p>When several symbols map to the same target, only the first occurrence is kept. It keeps
     * the sort order of the original symbol at that position; later duplicates are dropped.
     */
    public OrderingScheme map(OrderingScheme orderingScheme) {
        HashSet<Symbol> seen = new HashSet<>(orderingScheme.getOrderBy().size());
        // Distinct mapped symbol -> first original symbol that produced it, in encounter order.
        ImmutableMap.Builder<Symbol, Symbol> firstOriginalBuilder = ImmutableMap.builder();
        for (Symbol symbol : orderingScheme.getOrderBy()) {
            Symbol mapped = map(symbol);
            if (seen.add(mapped)) {
                firstOriginalBuilder.put(mapped, symbol);
            }
        }
        ImmutableMap<Symbol, Symbol> firstOriginal = firstOriginalBuilder.buildOrThrow();
        return new OrderingScheme(
                ImmutableList.copyOf(firstOriginal.keySet()),
                firstOriginal.entrySet().stream()
                        .collect(ImmutableMap.toImmutableMap(
                                Map.Entry::getKey,
                                entry -> orderingScheme.getOrdering(entry.getValue()))));
    }

    /** Maps {@code topKNode} onto {@code children}, keeping the node's own plan node id. */
    public TopKNode map(TopKNode topKNode, List<PlanNode> children) {
        return map(topKNode, children, topKNode.getPlanNodeId());
    }

    /**
     * Builds a new {@link TopKNode} with id {@code planNodeId} over {@code children}.
     *
     * <p>The ordering scheme and the output symbols are mapped. The count and the
     * children-data-in-order flag are kept.
     */
    public TopKNode map(TopKNode topKNode, List<PlanNode> children, PlanNodeId planNodeId) {
        return new TopKNode(
                planNodeId,
                children,
                map(topKNode.getOrderingScheme()),
                topKNode.getCount(),
                topKNode.getOutputSymbols().stream().map(this::map).collect(Collectors.toList()),
                topKNode.isChildrenDataInOrder());
    }

    /** Returns a new, empty {@link Builder}. */
    public static Builder builder() {
        return new Builder();
    }
}
