/*
 * Decompiled with CFR 0.152.
 */
package io.trino.sql.ir.optimizer.rule;

import com.google.common.collect.ImmutableList;
import io.trino.Session;
import io.trino.metadata.Metadata;
import io.trino.metadata.ResolvedFunction;
import io.trino.spi.function.OperatorType;
import io.trino.spi.type.Type;
import io.trino.sql.InterpretedFunctionInvoker;
import io.trino.sql.PlannerContext;
import io.trino.sql.ir.Booleans;
import io.trino.sql.ir.Constant;
import io.trino.sql.ir.Expression;
import io.trino.sql.ir.Switch;
import io.trino.sql.ir.WhenClause;
import io.trino.sql.ir.optimizer.IrOptimizerRule;
import io.trino.sql.planner.DeterminismEvaluator;
import io.trino.sql.planner.Symbol;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;

/**
 * IR optimizer rule that simplifies a {@link Switch} (simple {@code CASE operand WHEN ... END})
 * by removing {@code WHEN} clauses that can never be selected.
 *
 * <p>Only switches whose operand is deterministic are rewritten. Clauses are visited in order:
 * <ul>
 *   <li>a clause whose operand repeats the operand of an earlier kept, deterministic clause is
 *       dropped, because the earlier clause is always tried first;</li>
 *   <li>when both the switch operand and the clause operand are {@link Constant}s, they are
 *       compared with the operand type's {@code EQUAL} operator: a clause that does not
 *       evaluate to {@code true} is dropped, and a clause that does becomes the new default,
 *       discarding every later clause;</li>
 *   <li>a clause operand structurally equal to the switch operand becomes the new default,
 *       discarding every later clause.</li>
 * </ul>
 * If no clauses remain, the (possibly replaced) default expression replaces the whole switch.
 */
public class RemoveRedundantSwitchClauses
implements IrOptimizerRule {
    private final Metadata metadata;
    private final InterpretedFunctionInvoker functionInvoker;

    /**
     * @param context planner context supplying the metadata used to resolve the {@code EQUAL}
     *                operator and the function manager used to evaluate it on constants
     */
    public RemoveRedundantSwitchClauses(PlannerContext context) {
        this.metadata = context.getMetadata();
        this.functionInvoker = new InterpretedFunctionInvoker(context.getFunctionManager());
    }

    /**
     * Attempts to drop unreachable clauses from a {@link Switch}.
     *
     * @param expression the expression to rewrite; anything other than a {@link Switch} is ignored
     * @param session    session used when evaluating the {@code EQUAL} operator on constants
     * @param bindings   unused by this rule
     * @return the simplified expression, or {@link Optional#empty()} if nothing changed
     */
    @Override
    public Optional<Expression> apply(Expression expression, Session session, Map<Symbol, Expression> bindings) {
        if (!(expression instanceof Switch switchExpression)) {
            return Optional.empty();
        }
        Expression operand = switchExpression.operand();
        List<WhenClause> whenClauses = switchExpression.whenClauses();
        Expression defaultValue = switchExpression.defaultValue();

        // Structural equality of a non-deterministic operand says nothing about its runtime
        // values, so no clause can be proven redundant.
        if (!DeterminismEvaluator.isDeterministic(operand)) {
            return Optional.empty();
        }

        ResolvedFunction equals = this.metadata.resolveOperator(
                OperatorType.EQUAL,
                ImmutableList.of(operand.type(), operand.type()));

        List<WhenClause> newClauses = new ArrayList<>();
        Expression newDefault = defaultValue;
        // Operands of kept, deterministic clauses; a later duplicate can never be reached.
        Set<Expression> seen = new HashSet<>();
        boolean changed = false;

        for (WhenClause whenClause : whenClauses) {
            Expression candidate = whenClause.getOperand();

            if (seen.contains(candidate)) {
                changed = true;
                continue;
            }

            // Constant-vs-constant comparisons are folded with SQL EQUAL semantics before the
            // structural check below, so that NULL never matches NULL (the operator yields
            // null, i.e. unknown) and values that are SQL-equal but not structurally equal
            // still match. The invoker returns a java.lang.Boolean (or null), hence
            // Boolean.TRUE rather than an IR constant.
            if (operand instanceof Constant constantOperand && candidate instanceof Constant constantCandidate) {
                changed = true;
                Object matches = this.functionInvoker.invoke(
                        equals,
                        session.toConnectorSession(),
                        constantOperand.value(),
                        constantCandidate.value());
                if (Boolean.TRUE.equals(matches)) {
                    // This clause always wins; later clauses and the old default are unreachable.
                    newDefault = whenClause.getResult();
                    break;
                }
                continue;
            }

            // NOTE(review): this assumes a structurally identical operand always compares equal
            // at runtime; that does not hold if the operand evaluates to NULL (NULL = NULL is
            // not true). Confirm this matches the intended Switch semantics.
            if (operand.equals(candidate)) {
                changed = true;
                newDefault = whenClause.getResult();
                break;
            }

            newClauses.add(whenClause);
            if (DeterminismEvaluator.isDeterministic(candidate)) {
                seen.add(candidate);
            }
        }

        if (!changed) {
            return Optional.empty();
        }
        if (newClauses.isEmpty()) {
            return Optional.of(newDefault);
        }
        return Optional.of(new Switch(operand, newClauses, newDefault));
    }
}

