/*
 * Decompiled with CFR 0.152.
 */
package io.trino.sql.testing;

import com.google.common.base.Joiner;
import com.google.common.collect.ImmutableList;
import io.trino.sql.SqlFormatter;
import io.trino.sql.parser.ParsingException;
import io.trino.sql.parser.SqlParser;
import io.trino.sql.tree.DefaultTraversalVisitor;
import io.trino.sql.tree.Node;
import io.trino.sql.tree.Statement;
import jakarta.annotation.Nullable;
import java.util.List;

/**
 * Test assertions for SQL syntax trees.
 *
 * <p>The entry point, {@link #assertFormattedSql(SqlParser, Node)}, checks that a tree survives a
 * format-then-parse round trip unchanged. The remaining members are private helpers that build
 * readable failure messages.
 */
public final class TreeAssertions {
    private TreeAssertions() {
    }

    /**
     * Asserts that formatting {@code expected} as SQL and parsing it back yields an equal tree.
     *
     * <p>Two properties are checked:
     * <ul>
     *   <li>formatting the re-parsed statement reproduces exactly the same SQL text, and</li>
     *   <li>the re-parsed statement is {@code equals} to {@code expected}.</li>
     * </ul>
     *
     * @param sqlParser parser used to re-parse the formatted SQL
     * @param expected the tree to format and compare against
     * @throws AssertionError if the formatted SQL fails to parse, formats differently after
     *     re-parsing, or parses into a tree that is not equal to {@code expected}
     */
    public static void assertFormattedSql(SqlParser sqlParser, Node expected) {
        String formatted = SqlFormatter.formatSql(expected);

        // Round trip: formatting the re-parsed statement must reproduce the same SQL text.
        Statement actual = parseFormatted(sqlParser, formatted, expected);
        assertEquals(SqlFormatter.formatSql(actual), formatted);

        // When the trees differ, compare their flattened node lists first so the failure message
        // points at the first differing node. The final check still catches any difference the
        // flattened comparison does not reveal.
        if (!actual.equals(expected)) {
            assertListEquals(linearizeTree(actual), linearizeTree(expected));
        }
        assertEquals(actual, expected);
    }

    /**
     * Parses {@code sql} as a statement. Any parse failure is converted into an
     * {@link AssertionError}.
     *
     * @param sqlParser parser to use
     * @param sql the formatted SQL text to parse
     * @param tree the original tree, included in the failure message for diagnosis
     * @return the parsed statement
     * @throws AssertionError wrapping the {@link ParsingException} if {@code sql} does not parse
     */
    private static Statement parseFormatted(SqlParser sqlParser, String sql, Node tree) {
        try {
            return sqlParser.createStatement(sql);
        }
        catch (ParsingException e) {
            String message = "failed to parse formatted SQL: %s\nerror: %s\ntree: %s".formatted(sql, e.getMessage(), tree);
            throw new AssertionError(message, e);
        }
    }

    /**
     * Flattens {@code tree} into a list of every node visited by a {@link DefaultTraversalVisitor}.
     *
     * <p>Each node is appended after {@code super.process} has run for it. That call presumably
     * visits the node's children first, so the list is expected to be in post-order (TODO confirm
     * against {@code DefaultTraversalVisitor}).
     *
     * @param tree root of the tree to flatten
     * @return an immutable list of the visited nodes
     */
    private static List<Node> linearizeTree(Node tree) {
        ImmutableList.Builder<Node> nodes = ImmutableList.builder();
        new DefaultTraversalVisitor<Void>() {
            @Override
            public Void process(Node node, @Nullable Void context) {
                super.process(node, context);
                nodes.add(node);
                return null;
            }
        }.process(tree, null);
        return nodes.build();
    }

    /**
     * Asserts that two lists are equal. On failure, the message reports either the size mismatch
     * or the first differing index, followed by both lists printed one element per line.
     *
     * @throws AssertionError if the lists differ in size or content
     */
    private static <T> void assertListEquals(List<T> actual, List<T> expected) {
        if (actual.size() != expected.size()) {
            throw new AssertionError("Lists not equal in size%n%s".formatted(formatLists(actual, expected)));
        }
        if (!actual.equals(expected)) {
            throw new AssertionError("Lists not equal at index %s%n%s".formatted(differingIndex(actual, expected), formatLists(actual, expected)));
        }
    }

    /**
     * Renders both lists with their sizes, one indented element per line, for failure messages.
     */
    private static <T> String formatLists(List<T> actual, List<T> expected) {
        Joiner joiner = Joiner.on("\n    ");
        return "Actual [%s]:%n    %s%nExpected [%s]:%n    %s%n".formatted(actual.size(), joiner.join(actual), expected.size(), joiner.join(expected));
    }

    /**
     * Returns the first index at which the two lists hold unequal elements. If one list is a
     * prefix of the other, returns the length of the shorter list.
     *
     * <p>The loop is bounded by the shorter list, so lists of different sizes never cause an
     * out-of-bounds read.
     */
    private static <T> int differingIndex(List<T> actual, List<T> expected) {
        int common = Math.min(actual.size(), expected.size());
        for (int i = 0; i < common; i++) {
            if (!actual.get(i).equals(expected.get(i))) {
                return i;
            }
        }
        return common;
    }

    /**
     * Asserts {@code actual.equals(expected)}, with both values in the failure message.
     *
     * @throws AssertionError if the values are not equal
     */
    private static <T> void assertEquals(T actual, T expected) {
        if (!actual.equals(expected)) {
            throw new AssertionError("expected [%s] but found [%s]".formatted(expected, actual));
        }
    }
}

