package io.prestosql.sql.planner.optimizations;

import com.google.common.base.Preconditions;
import com.google.common.collect.ComparisonChain;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import com.google.common.collect.Ordering;
import io.airlift.slice.Slice;
import io.prestosql.Session;
import io.prestosql.metadata.Metadata;
import io.prestosql.metadata.Signature;
import io.prestosql.spi.function.OperatorType;
import io.prestosql.spi.type.BooleanType;
import io.prestosql.spi.type.Type;
import io.prestosql.sql.planner.DesugarArrayConstructorRewriter;
import io.prestosql.sql.planner.DesugarLikeRewriter;
import io.prestosql.sql.planner.Symbol;
import io.prestosql.sql.planner.TypeAnalyzer;
import io.prestosql.sql.planner.TypeProvider;
import io.prestosql.sql.relational.CallExpression;
import io.prestosql.sql.relational.ConstantExpression;
import io.prestosql.sql.relational.InputReferenceExpression;
import io.prestosql.sql.relational.LambdaDefinitionExpression;
import io.prestosql.sql.relational.RowExpression;
import io.prestosql.sql.relational.RowExpressionVisitor;
import io.prestosql.sql.relational.SpecialForm;
import io.prestosql.sql.relational.SqlToRowExpressionTranslator;
import io.prestosql.sql.relational.VariableReferenceExpression;
import io.prestosql.sql.tree.Expression;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Stream;

/* loaded from: input_file:io/prestosql/sql/planner/optimizations/ExpressionEquivalence.class */
public class ExpressionEquivalence {
    // Ordering used by the canonicalization visitor to sort the arguments of =, <>, IS DISTINCT FROM
    // and of flattened AND/OR forms into a canonical order.
    private static final Ordering<RowExpression> ROW_EXPRESSION_ORDERING = Ordering.from(new RowExpressionComparator());
    // Used for LIKE / ARRAY-constructor desugaring and SQL-to-RowExpression translation; also handed to
    // the canonicalization visitor, which uses it to resolve flipped comparison operators.
    private final Metadata metadata;
    // Used for desugaring rewrites and to compute expression types before translation.
    private final TypeAnalyzer typeAnalyzer;
    // Stateless-per-call visitor (holds only the metadata) that rewrites RowExpressions into canonical form.
    private final CanonicalizationVisitor canonicalizationVisitor;

    /**
     * Rewrites a {@link RowExpression} tree bottom-up into a canonical form, so that expressions that
     * differ only in the ways listed below become {@code equals}-equal:
     * <ul>
     *   <li>the arguments of {@code =}, {@code <>} and {@code IS DISTINCT FROM} are sorted;</li>
     *   <li>{@code a > b} becomes {@code b < a} and {@code a >= b} becomes {@code b <= a};</li>
     *   <li>nested {@code AND}/{@code OR} forms of the same kind are flattened, duplicate arguments are
     *       dropped, and the remaining arguments are sorted.</li>
     * </ul>
     * Sorting uses {@code ROW_EXPRESSION_ORDERING}.
     */
    private static class CanonicalizationVisitor implements RowExpressionVisitor<RowExpression, Void> {
        private final Metadata metadata;

        public CanonicalizationVisitor(Metadata metadata) {
            this.metadata = Objects.requireNonNull(metadata, "metadata is null");
        }

        /**
         * Canonicalizes the call's arguments, then normalizes symmetric and greater-than style operators.
         *
         * @throws IllegalArgumentException if a {@code >}/{@code >=} call does not have exactly two arguments
         */
        @Override
        public RowExpression visitCall(CallExpression callExpression, Void context) {
            List<RowExpression> canonicalArguments = callExpression.getArguments().stream()
                    .map(argument -> argument.accept(this, context))
                    .collect(ImmutableList.toImmutableList());
            CallExpression canonicalCall = new CallExpression(callExpression.getResolvedFunction(), callExpression.getType(), canonicalArguments);
            String name = canonicalCall.getResolvedFunction().getSignature().getName();

            if (name.equals(Signature.mangleOperatorName(OperatorType.EQUAL))
                    || name.equals(Signature.mangleOperatorName(OperatorType.NOT_EQUAL))
                    || name.equals(Signature.mangleOperatorName(OperatorType.IS_DISTINCT_FROM))) {
                // Symmetric operators: argument order carries no meaning, so sort it.
                return new CallExpression(
                        canonicalCall.getResolvedFunction(),
                        canonicalCall.getType(),
                        ROW_EXPRESSION_ORDERING.sortedCopy(canonicalArguments));
            }

            boolean greaterThan = name.equals(Signature.mangleOperatorName(OperatorType.GREATER_THAN));
            if (!greaterThan && !name.equals(Signature.mangleOperatorName(OperatorType.GREATER_THAN_OR_EQUAL))) {
                return canonicalCall;
            }

            // a > b  ->  b < a   and   a >= b  ->  b <= a, re-resolving the operator for the swapped types.
            OperatorType flippedOperator = greaterThan ? OperatorType.LESS_THAN : OperatorType.LESS_THAN_OR_EQUAL;
            List<Type> swappedArgumentTypes = swapPair(canonicalCall.getResolvedFunction().getSignature().getArgumentTypes()).stream()
                    .map(metadata::getType)
                    .collect(ImmutableList.toImmutableList());
            return new CallExpression(
                    metadata.resolveOperator(flippedOperator, swappedArgumentTypes),
                    canonicalCall.getType(),
                    swapPair(canonicalCall.getArguments()));
        }

        /**
         * Canonicalizes the form's arguments; {@code AND}/{@code OR} are additionally flattened,
         * de-duplicated, and sorted. A de-duplicated {@code AND}/{@code OR} with a single distinct
         * argument is replaced by that argument.
         */
        @Override
        public RowExpression visitSpecialForm(SpecialForm specialForm, Void context) {
            List<RowExpression> canonicalArguments = specialForm.getArguments().stream()
                    .map(argument -> argument.accept(this, context))
                    .collect(ImmutableList.toImmutableList());
            SpecialForm canonicalForm = new SpecialForm(specialForm.getForm(), specialForm.getType(), canonicalArguments);
            if (canonicalForm.getForm() != SpecialForm.Form.AND && canonicalForm.getForm() != SpecialForm.Form.OR) {
                return canonicalForm;
            }

            ImmutableSet<RowExpression> distinctArguments = ImmutableSet.copyOf(flattenNestedCallArgs(canonicalForm));
            if (distinctArguments.size() == 1) {
                return Iterables.getOnlyElement(distinctArguments);
            }
            return new SpecialForm(canonicalForm.getForm(), BooleanType.BOOLEAN, ROW_EXPRESSION_ORDERING.sortedCopy(distinctArguments));
        }

        /**
         * Returns the arguments of {@code specialForm}, recursively inlining the arguments of any
         * directly nested special form that has the same {@link SpecialForm.Form}.
         */
        public static List<RowExpression> flattenNestedCallArgs(SpecialForm specialForm) {
            SpecialForm.Form form = specialForm.getForm();
            ImmutableList.Builder<RowExpression> flattened = ImmutableList.builder();
            for (RowExpression argument : specialForm.getArguments()) {
                if (argument instanceof SpecialForm && ((SpecialForm) argument).getForm() == form) {
                    flattened.addAll(flattenNestedCallArgs((SpecialForm) argument));
                } else {
                    flattened.add(argument);
                }
            }
            return flattened.build();
        }

        /** Constants are already canonical. */
        @Override
        public RowExpression visitConstant(ConstantExpression constantExpression, Void context) {
            return constantExpression;
        }

        /** Input references are already canonical. */
        @Override
        public RowExpression visitInputReference(InputReferenceExpression inputReferenceExpression, Void context) {
            return inputReferenceExpression;
        }

        /** Keeps the lambda's signature and canonicalizes its body. */
        @Override
        public RowExpression visitLambda(LambdaDefinitionExpression lambdaDefinitionExpression, Void context) {
            return new LambdaDefinitionExpression(
                    lambdaDefinitionExpression.getArgumentTypes(),
                    lambdaDefinitionExpression.getArguments(),
                    lambdaDefinitionExpression.getBody().accept(this, context));
        }

        /** Variable references are already canonical. */
        @Override
        public RowExpression visitVariableReference(VariableReferenceExpression variableReferenceExpression, Void context) {
            return variableReferenceExpression;
        }
    }

    /**
     * Lexicographic comparator over lists.
     *
     * <p>Elements at the same position are compared pairwise with the element comparator, and the
     * first non-zero result is returned. If one list is a prefix of the other, the shorter list
     * sorts first.
     *
     * @param <T> element type
     */
    public static class ListComparator<T> implements Comparator<List<T>> {
        private final Comparator<T> elementComparator;

        public ListComparator(Comparator<T> comparator) {
            this.elementComparator = Objects.requireNonNull(comparator, "elementComparator is null");
        }

        @Override
        public int compare(List<T> list, List<T> list2) {
            Iterator<T> leftElements = list.iterator();
            Iterator<T> rightElements = list2.iterator();
            while (leftElements.hasNext() && rightElements.hasNext()) {
                int elementOrder = elementComparator.compare(leftElements.next(), rightElements.next());
                if (elementOrder != 0) {
                    return elementOrder;
                }
            }
            // Common prefix is equal: fall back to length.
            return Integer.compare(list.size(), list2.size());
        }
    }

    /**
     * Orders {@link RowExpression}s so that the arguments of commutative operators can be sorted into
     * a canonical order.
     *
     * <p>Expressions are ordered first by concrete class, then by content specific to that class.
     * Argument lists are compared lexicographically with this comparator.
     */
    private static class RowExpressionComparator implements Comparator<RowExpression> {
        // Arbitrary order between expression classes; Ordering.arbitrary() returns 0 only for identical
        // objects and is fixed for the life of the JVM.
        private final Comparator<Object> classComparator = Ordering.arbitrary();
        // Lexicographic order over argument lists, recursing into this comparator.
        private final ListComparator<RowExpression> argumentComparator = new ListComparator<>(this);

        private RowExpressionComparator() {
        }

        /**
         * @throws IllegalArgumentException if both operands are of a RowExpression subclass this
         *         comparator does not know
         */
        @Override
        public int compare(RowExpression rowExpression, RowExpression rowExpression2) {
            int classOrder = classComparator.compare(rowExpression.getClass(), rowExpression2.getClass());
            if (classOrder != 0) {
                return classOrder;
            }
            // Both operands have the same concrete class from here on, so casting the right-hand side is safe.
            if (rowExpression instanceof CallExpression) {
                return compareCalls((CallExpression) rowExpression, (CallExpression) rowExpression2);
            }
            if (rowExpression instanceof SpecialForm) {
                return compareSpecialForms((SpecialForm) rowExpression, (SpecialForm) rowExpression2);
            }
            if (rowExpression instanceof ConstantExpression) {
                return compareConstants((ConstantExpression) rowExpression, (ConstantExpression) rowExpression2);
            }
            if (rowExpression instanceof InputReferenceExpression) {
                return Integer.compare(((InputReferenceExpression) rowExpression).getField(), ((InputReferenceExpression) rowExpression2).getField());
            }
            if (rowExpression instanceof LambdaDefinitionExpression) {
                return compareLambdas((LambdaDefinitionExpression) rowExpression, (LambdaDefinitionExpression) rowExpression2);
            }
            if (rowExpression instanceof VariableReferenceExpression) {
                return compareVariables((VariableReferenceExpression) rowExpression, (VariableReferenceExpression) rowExpression2);
            }
            throw new IllegalArgumentException("Unsupported RowExpression type " + rowExpression.getClass().getSimpleName());
        }

        /** Orders calls by the string form of the resolved function, then by arguments. */
        private int compareCalls(CallExpression left, CallExpression right) {
            return ComparisonChain.start()
                    .compare(left.getResolvedFunction().toString(), right.getResolvedFunction().toString())
                    .compare(left.getArguments(), right.getArguments(), argumentComparator)
                    .result();
        }

        /** Orders special forms by form kind, then by arguments. */
        private int compareSpecialForms(SpecialForm left, SpecialForm right) {
            return ComparisonChain.start()
                    .compare(left.getForm(), right.getForm())
                    .compare(left.getArguments(), right.getArguments(), argumentComparator)
                    .result();
        }

        /** Orders lambdas by argument types (string form), argument names, then body. */
        private int compareLambdas(LambdaDefinitionExpression left, LambdaDefinitionExpression right) {
            return ComparisonChain.start()
                    .compare(left.getArgumentTypes(), right.getArgumentTypes(), new ListComparator<>(Comparator.comparing(Object::toString)))
                    .compare(left.getArguments(), right.getArguments(), new ListComparator<>(Comparator.<String>naturalOrder()))
                    .compare(left.getBody(), right.getBody(), this)
                    .result();
        }

        /** Orders variable references by name, then by the string form of their type. */
        private static int compareVariables(VariableReferenceExpression left, VariableReferenceExpression right) {
            return ComparisonChain.start()
                    .compare(left.getName(), right.getName())
                    .compare(left.getType(), right.getType(), Comparator.comparing(Object::toString))
                    .result();
        }

        /**
         * Orders constants by type signature (string form), then by value.
         *
         * <p>{@code null} sorts before any non-null value. Values are compared natively when the type's
         * Java representation is boolean, integral, floating point, or {@link Slice}.
         */
        private static int compareConstants(ConstantExpression left, ConstantExpression right) {
            int typeOrder = left.getType().getTypeSignature().toString().compareTo(right.getType().getTypeSignature().toString());
            if (typeOrder != 0) {
                return typeOrder;
            }

            Object leftValue = left.getValue();
            Object rightValue = right.getValue();
            if (leftValue == null) {
                return rightValue == null ? 0 : -1;
            }
            if (rightValue == null) {
                return 1;
            }

            Class<?> javaType = left.getType().getJavaType();
            if (javaType == boolean.class) {
                return ((Boolean) leftValue).compareTo((Boolean) rightValue);
            }
            if (javaType == byte.class || javaType == short.class || javaType == int.class || javaType == long.class) {
                return Long.compare(((Number) leftValue).longValue(), ((Number) rightValue).longValue());
            }
            if (javaType == float.class || javaType == double.class) {
                return Double.compare(((Number) leftValue).doubleValue(), ((Number) rightValue).doubleValue());
            }
            if (javaType == Slice.class) {
                return ((Slice) leftValue).compareTo((Slice) rightValue);
            }

            // No natural order is known for this representation. Treat the values as a tie instead of
            // returning a constant -1: a constant -1 breaks the Comparator contract
            // (compare(x, x) != 0, sgn(compare(a, b)) != -sgn(compare(b, a))) and can make sorting throw.
            // Ordering.sortedCopy is stable, so tied constants keep their input order.
            return 0;
        }
    }

    /**
     * Creates an equivalence checker.
     *
     * @param metadata used for desugaring, translation and operator resolution; must not be null
     * @param typeAnalyzer used for desugaring and type analysis; must not be null
     * @throws NullPointerException if either argument is null
     */
    public ExpressionEquivalence(Metadata metadata, TypeAnalyzer typeAnalyzer) {
        Metadata checkedMetadata = Objects.requireNonNull(metadata, "metadata is null");
        this.metadata = checkedMetadata;
        this.typeAnalyzer = Objects.requireNonNull(typeAnalyzer, "typeAnalyzer is null");
        this.canonicalizationVisitor = new CanonicalizationVisitor(checkedMetadata);
    }

    /**
     * Checks whether two expressions have the same canonical form.
     *
     * <p>Every symbol in {@code typeProvider.allTypes()} is assigned an input channel in iteration
     * order. Both expressions are translated to {@link RowExpression}s against that same mapping,
     * canonicalized, and compared with {@code equals}.
     *
     * @return {@code true} if the canonicalized forms are equal
     */
    public boolean areExpressionsEquivalent(Session session, Expression expression, Expression expression2, TypeProvider typeProvider) {
        Map<Symbol, Integer> symbolChannels = new HashMap<>();
        int nextChannel = 0;
        for (Map.Entry<Symbol, Type> entry : typeProvider.allTypes().entrySet()) {
            symbolChannels.put(entry.getKey(), nextChannel++);
        }

        RowExpression canonicalLeft = toRowExpression(session, expression, symbolChannels, typeProvider)
                .accept(canonicalizationVisitor, null);
        RowExpression canonicalRight = toRowExpression(session, expression2, symbolChannels, typeProvider)
                .accept(canonicalizationVisitor, null);
        return canonicalLeft.equals(canonicalRight);
    }

    /**
     * Desugars LIKE, then array constructors, and translates the result into a {@link RowExpression}.
     * Symbols are resolved to input channels through {@code map}.
     */
    private RowExpression toRowExpression(Session session, Expression expression, Map<Symbol, Integer> map, TypeProvider typeProvider) {
        Expression likeDesugared = DesugarLikeRewriter.rewrite(expression, session, metadata, typeAnalyzer, typeProvider);
        Expression fullyDesugared = DesugarArrayConstructorRewriter.rewrite(likeDesugared, session, metadata, typeAnalyzer, typeProvider);
        return SqlToRowExpressionTranslator.translate(
                fullyDesugared,
                typeAnalyzer.getTypes(session, typeProvider, fullyDesugared),
                map,
                metadata,
                session,
                false);
    }

    /**
     * Returns an immutable two-element list with the elements of {@code list} in reverse order.
     *
     * @throws NullPointerException if {@code list} is null
     * @throws IllegalArgumentException if {@code list} does not have exactly two elements
     */
    private static <T> List<T> swapPair(List<T> list) {
        List<T> pair = Objects.requireNonNull(list, "pair is null");
        Preconditions.checkArgument(pair.size() == 2, "Expected pair to have two elements");
        T second = pair.get(1);
        T first = pair.get(0);
        return ImmutableList.of(second, first);
    }
}
