/*
 * Decompiled with CFR 0.152.
 */
package com.facebook.presto.expressions;

import com.facebook.presto.expressions.RowExpressionRewriter;
import com.facebook.presto.expressions.RowExpressionTreeRewriter;
import com.facebook.presto.spi.relation.CallExpression;
import com.facebook.presto.spi.relation.ConstantExpression;
import com.facebook.presto.spi.relation.InputReferenceExpression;
import com.facebook.presto.spi.relation.LambdaDefinitionExpression;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.relation.SpecialFormExpression;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Optional;

/**
 * Rewrites a {@link RowExpression} tree into a canonical form.
 *
 * <p>Each visited node ends up in one of two forms:
 * <ul>
 *   <li>the result of the node's own {@code canonicalize()} method, when none of its rewritten
 *       children differ (by identity) from the originals; or</li>
 *   <li>a freshly constructed node of the same kind built from the rewritten children, whose
 *       first constructor argument is {@code Optional.empty()} (presumably the source location;
 *       confirm against the spi constructors).</li>
 * </ul>
 *
 * <p>When {@code removeConstants} is set, every {@link ConstantExpression} is replaced by a
 * constant of the same type whose value is {@code null}.
 */
public class CanonicalRowExpressionRewriter
extends RowExpressionRewriter<Void> {
    /** When true, constant values are erased and only their types are kept. */
    private final boolean removeConstants;

    private CanonicalRowExpressionRewriter(boolean removeConstants) {
        this.removeConstants = removeConstants;
    }

    /**
     * Returns the canonical form of {@code expression}.
     *
     * @param expression the expression tree to canonicalize
     * @param removeConstants if true, every constant is replaced with a constant of the same
     *     type whose value is {@code null}
     * @return the rewritten expression tree
     */
    public static RowExpression canonicalizeRowExpression(RowExpression expression, boolean removeConstants) {
        return RowExpressionTreeRewriter.rewriteWith(new CanonicalRowExpressionRewriter(removeConstants), expression, null);
    }

    /** Input references have no children; delegate to their own {@code canonicalize()}. */
    @Override
    public RowExpression rewriteInputReference(InputReferenceExpression input, Void context, RowExpressionTreeRewriter<Void> treeRewriter) {
        return input.canonicalize();
    }

    /**
     * Rewrites the call's arguments. If any argument instance changed, a new call is built from
     * the rewritten arguments; otherwise the call's own {@code canonicalize()} result is returned.
     */
    @Override
    public RowExpression rewriteCall(CallExpression call, Void context, RowExpressionTreeRewriter<Void> treeRewriter) {
        List<RowExpression> arguments = rewrite(call.getArguments(), context, treeRewriter);
        if (!sameElements(call.getArguments(), arguments)) {
            return new CallExpression(Optional.empty(), call.getDisplayName(), call.getFunctionHandle(), call.getType(), arguments);
        }
        return call.canonicalize();
    }

    /**
     * Canonicalizes a constant or, when {@code removeConstants} is set, replaces it with a
     * {@code null}-valued constant of the same type so that only the type is retained.
     */
    @Override
    public RowExpression rewriteConstant(ConstantExpression literal, Void context, RowExpressionTreeRewriter<Void> treeRewriter) {
        if (!removeConstants) {
            return literal.canonicalize();
        }
        return new ConstantExpression(null, literal.getType());
    }

    /**
     * Rewrites the lambda body. If the body instance changed, a new lambda with the same
     * argument types and names is built around it; otherwise the lambda's own
     * {@code canonicalize()} result is returned.
     */
    @Override
    public RowExpression rewriteLambda(LambdaDefinitionExpression lambda, Void context, RowExpressionTreeRewriter<Void> treeRewriter) {
        RowExpression body = treeRewriter.rewrite(lambda.getBody(), context);
        if (body != lambda.getBody()) {
            return new LambdaDefinitionExpression(Optional.empty(), lambda.getArgumentTypes(), lambda.getArguments(), body);
        }
        return lambda.canonicalize();
    }

    /** Variable references have no children; delegate to their own {@code canonicalize()}. */
    @Override
    public RowExpression rewriteVariableReference(VariableReferenceExpression variable, Void context, RowExpressionTreeRewriter<Void> treeRewriter) {
        return variable.canonicalize();
    }

    /**
     * Rewrites the special form's arguments. If any argument instance changed, a new special
     * form with the same form and type is built from them; otherwise the special form's own
     * {@code canonicalize()} result is returned.
     */
    @Override
    public RowExpression rewriteSpecialForm(SpecialFormExpression specialForm, Void context, RowExpressionTreeRewriter<Void> treeRewriter) {
        List<RowExpression> arguments = rewrite(specialForm.getArguments(), context, treeRewriter);
        if (!sameElements(specialForm.getArguments(), arguments)) {
            return new SpecialFormExpression(Optional.empty(), specialForm.getForm(), specialForm.getType(), arguments);
        }
        return specialForm.canonicalize();
    }

    /**
     * Rewrites each expression in {@code items} with {@code treeRewriter}, preserving order.
     *
     * @return an unmodifiable list of the rewritten expressions
     */
    private List<RowExpression> rewrite(List<RowExpression> items, Void context, RowExpressionTreeRewriter<Void> treeRewriter) {
        List<RowExpression> rewrittenExpressions = new ArrayList<>(items.size());
        for (RowExpression expression : items) {
            rewrittenExpressions.add(treeRewriter.rewrite(expression, context));
        }
        return Collections.unmodifiableList(rewrittenExpressions);
    }

    /**
     * Returns true if both collections have the same size and hold the identical instances
     * (compared with {@code ==}, not {@code equals}) in the same iteration order.
     *
     * <p>Identity is used deliberately: the question is whether the rewriter handed back the very
     * same objects, not whether it produced equal ones.
     */
    private static <T> boolean sameElements(Collection<? extends T> a, Collection<? extends T> b) {
        if (a.size() != b.size()) {
            return false;
        }
        // Iterator<?> rather than Iterator<T>: a.iterator() yields Iterator<capture of ? extends T>,
        // which is not assignable to Iterator<T> (generics are invariant).
        Iterator<?> first = a.iterator();
        Iterator<?> second = b.iterator();
        while (first.hasNext() && second.hasNext()) {
            if (first.next() != second.next()) {
                return false;
            }
        }
        return true;
    }
}

