package io.trino.sql.ir.optimizer.rule;

import io.trino.Session;
import io.trino.sql.ir.Bind;
import io.trino.sql.ir.Constant;
import io.trino.sql.ir.Expression;
import io.trino.sql.ir.ExpressionRewriter;
import io.trino.sql.ir.ExpressionTreeRewriter;
import io.trino.sql.ir.Lambda;
import io.trino.sql.ir.Reference;
import io.trino.sql.ir.optimizer.IrOptimizerRule;
import io.trino.sql.planner.Symbol;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.stream.Stream;

/**
 * Partially evaluates a {@link Bind} by inlining the bound values that are {@link Constant}s
 * into the body of the bound {@link Lambda}.
 * <p>
 * For every lambda argument whose bound value is a {@code Constant}, the argument is dropped
 * from the lambda's argument list and each {@link Reference} to it in the body is replaced by
 * that constant. Arguments bound to non-constant values are kept together with their values,
 * and trailing arguments that the {@code Bind} does not bind are kept unchanged. If no
 * non-constant values remain, the result is the rewritten {@code Lambda} itself instead of a
 * {@code Bind}.
 */
public class EvaluateBind implements IrOptimizerRule {
    /**
     * Applies the rule to {@code expression}.
     *
     * @param expression the candidate expression; only {@link Bind} nodes are rewritten
     * @param session    unused by this rule
     * @param bindings   unused by this rule
     * @return the simplified expression, or {@link Optional#empty()} when {@code expression} is
     *         not a {@code Bind}, none of its bound values is a {@code Constant}, or inlining the
     *         constants leaves the lambda body unchanged
     */
    @Override
    public Optional<Expression> apply(Expression expression, Session session, Map<Symbol, Expression> bindings) {
        if (!(expression instanceof Bind bind)) {
            return Optional.empty();
        }

        List<Expression> values = bind.values();
        if (values.stream().noneMatch(Constant.class::isInstance)) {
            return Optional.empty();
        }

        // Values bind the leading lambda arguments positionally: values[i] is bound to arguments[i].
        List<Symbol> arguments = bind.function().arguments();
        Map<String, Constant> constants = new HashMap<>();
        List<Symbol> remainingArguments = new ArrayList<>();
        List<Expression> remainingValues = new ArrayList<>();
        for (int i = 0; i < values.size(); i++) {
            Symbol argument = arguments.get(i);
            Expression value = values.get(i);
            if (value instanceof Constant constant) {
                constants.put(argument.name(), constant);
            } else {
                remainingArguments.add(argument);
                remainingValues.add(value);
            }
        }
        // Arguments beyond the bound prefix are not bound by this Bind; they stay as lambda parameters.
        remainingArguments.addAll(arguments.subList(values.size(), arguments.size()));

        // NOTE(review): if the body never references any constant-bound argument, the body is
        // unchanged and the rule does not fire, so such unused constant bindings are not dropped.
        Optional<Expression> newBody = substituteBindings(bind.function().body(), constants);
        if (newBody.isEmpty()) {
            return Optional.empty();
        }

        Lambda lambda = new Lambda(remainingArguments, newBody.orElseThrow());
        if (remainingValues.isEmpty()) {
            return Optional.of(lambda);
        }
        return Optional.of(new Bind(remainingValues, lambda));
    }

    /**
     * Replaces every {@link Reference} in {@code expression} whose name is a key of
     * {@code constants} with the mapped {@link Constant}.
     *
     * @return the rewritten expression, or {@link Optional#empty()} if the rewrite produced the
     *         very same instance (i.e. nothing was substituted)
     */
    private static Optional<Expression> substituteBindings(Expression expression, Map<String, Constant> constants) {
        Expression rewritten = new ExpressionTreeRewriter<Void>(new ExpressionRewriter<Void>() {
            @Override
            public Expression rewriteReference(Reference reference, Void context, ExpressionTreeRewriter<Void> treeRewriter) {
                Constant constant = constants.get(reference.name());
                return constant == null ? reference : constant;
            }
        }).rewrite(expression, null);

        // Identity comparison is intentional: an unchanged tree comes back as the same instance.
        return rewritten != expression ? Optional.of(rewritten) : Optional.empty();
    }
}
