package io.trino.sql.ir.optimizer.rule;

import com.google.common.collect.ImmutableList;
import io.trino.Session;
import io.trino.metadata.GlobalFunctionCatalog;
import io.trino.metadata.Metadata;
import io.trino.operator.scalar.ArrayTransformFunction;
import io.trino.operator.scalar.JsonStringArrayExtractScalar;
import io.trino.operator.scalar.JsonStringToArrayCast;
import io.trino.spi.type.Type;
import io.trino.sql.PlannerContext;
import io.trino.sql.analyzer.TypeSignatureProvider;
import io.trino.sql.ir.Call;
import io.trino.sql.ir.Constant;
import io.trino.sql.ir.Expression;
import io.trino.sql.ir.Lambda;
import io.trino.sql.ir.Reference;
import io.trino.sql.ir.optimizer.IrOptimizerRule;
import io.trino.sql.planner.Symbol;

import java.util.List;
import java.util.Map;
import java.util.Optional;

/**
 * IR optimizer rule that fuses
 * {@code transform(json_string_to_array(x), e -> json_extract_scalar(e, path))}
 * into a single {@code json_string_array_extract_scalar(x, path)} call.
 *
 * <p>The rule fires only when the lambda body extracts from the lambda's own element
 * parameter and the JSON path is a {@link Constant}; otherwise the expression is left
 * untouched.
 */
public class SpecializeTransformWithJsonParse implements IrOptimizerRule {
    private final Metadata metadata;

    public SpecializeTransformWithJsonParse(PlannerContext plannerContext) {
        this.metadata = plannerContext.getMetadata();
    }

    /**
     * Attempts the rewrite described on the class.
     *
     * @param expression the candidate expression
     * @param session the current session (unused by this rule)
     * @param bindings symbol bindings (unused by this rule)
     * @return the fused call, or {@link Optional#empty()} if the pattern does not match
     */
    @Override // io.trino.sql.ir.optimizer.IrOptimizerRule
    public Optional<Expression> apply(Expression expression, Session session, Map<Symbol, Expression> bindings) {
        if (!(expression instanceof Call transformCall)
                || !transformCall.function().name().getFunctionName().equals(ArrayTransformFunction.ARRAY_TRANSFORM_NAME)) {
            return Optional.empty();
        }

        // The array being transformed must come straight from json_string_to_array(x).
        if (!(transformCall.arguments().getFirst() instanceof Call parseCall)
                || !parseCall.function().name().equals(GlobalFunctionCatalog.builtinFunctionName(JsonStringToArrayCast.JSON_STRING_TO_ARRAY_NAME))) {
            return Optional.empty();
        }

        // The mapping lambda must be a single-parameter lambda whose body is json_extract_scalar(...).
        if (!(transformCall.arguments().getLast() instanceof Lambda lambda)
                || lambda.arguments().size() != 1
                || !(lambda.body() instanceof Call extractCall)
                || !extractCall.function().name().equals(GlobalFunctionCatalog.builtinFunctionName("json_extract_scalar"))) {
            return Optional.empty();
        }

        // The fused function extracts from each array element, so the rewrite is only
        // semantics-preserving when json_extract_scalar reads the lambda's element parameter.
        // The path must also be a constant; a non-constant path (e.g. a cast of a column)
        // is not a Constant and previously caused a ClassCastException here.
        Symbol element = lambda.arguments().getFirst();
        if (!(extractCall.arguments().getFirst() instanceof Reference reference)
                || !reference.name().equals(element.name())
                || !(extractCall.arguments().getLast() instanceof Constant path)) {
            return Optional.empty();
        }

        Expression jsonArray = parseCall.arguments().getFirst();
        List<Type> argumentTypes = ImmutableList.of(jsonArray.type(), path.type());
        return Optional.of(new Call(
                this.metadata.resolveBuiltinFunction(
                        JsonStringArrayExtractScalar.JSON_STRING_ARRAY_EXTRACT_SCALAR_NAME,
                        TypeSignatureProvider.fromTypes(argumentTypes)),
                ImmutableList.of(jsonArray, path)));
    }
}
