package io.prestosql.spi.expression;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Iterables;
import io.prestosql.Session;
import io.prestosql.metadata.Metadata;
import io.prestosql.metadata.MetadataManager;
import io.prestosql.spi.type.BigintType;
import io.prestosql.spi.type.DoubleType;
import io.prestosql.spi.type.IntegerType;
import io.prestosql.spi.type.RowType;
import io.prestosql.spi.type.VarcharType;
import io.prestosql.sql.parser.SqlParser;
import io.prestosql.sql.planner.Symbol;
import io.prestosql.sql.planner.TypeAnalyzer;
import io.prestosql.sql.planner.TypeProvider;
import io.prestosql.sql.tree.ArithmeticBinaryExpression;
import io.prestosql.sql.tree.DereferenceExpression;
import io.prestosql.sql.tree.Expression;
import io.prestosql.sql.tree.FunctionCall;
import io.prestosql.sql.tree.Identifier;
import io.prestosql.sql.tree.NodeRef;
import io.prestosql.sql.tree.QualifiedName;
import io.prestosql.sql.tree.StringLiteral;
import io.prestosql.sql.tree.SymbolReference;
import io.prestosql.testing.TestingSession;
import java.util.List;
import java.util.Map;
import org.testng.Assert;
import org.testng.annotations.Test;

/**
 * Tests for {@link PartialTranslator#extractPartialTranslations}.
 *
 * <p>For expressions the test treats as fully translatable, the result is expected to be a single entry keyed by the
 * expression itself. For the composite expressions exercised here (an arithmetic {@code ADD} and a {@code concat}
 * call), the result is expected to contain one entry per operand/argument. Each entry must equal that
 * subexpression's standalone {@link ConnectorExpressionTranslator} translation.
 */
@Test
public class TestPartialTranslator {
    private static final Session TEST_SESSION = TestingSession.testSessionBuilder().build();
    private static final Metadata METADATA = MetadataManager.createTestMetadataManager();
    private static final TypeAnalyzer TYPE_ANALYZER = new TypeAnalyzer(new SqlParser(), METADATA);

    // Types of the symbols the expressions under test may reference. ImmutableMap.of is used in argument
    // position so its key/value types are inferred from TypeProvider.copyOf's parameter; an untyped
    // ImmutableMap.builder() chain would infer <Object, Object> instead.
    private static final TypeProvider TYPE_PROVIDER = TypeProvider.copyOf(ImmutableMap.of(
            new Symbol("double_symbol_1"), DoubleType.DOUBLE,
            new Symbol("double_symbol_2"), DoubleType.DOUBLE,
            new Symbol("bigint_symbol_1"), BigintType.BIGINT,
            new Symbol("row_symbol_1"), RowType.rowType(
                    RowType.field("int_symbol_1", IntegerType.INTEGER),
                    RowType.field("varchar_symbol_1", VarcharType.createVarcharType(5)))));

    @Test
    public void testPartialTranslator() {
        Expression rowSymbolReference = new SymbolReference("row_symbol_1");
        Expression intFieldDereference = new DereferenceExpression(rowSymbolReference, new Identifier("int_symbol_1"));
        Expression varcharFieldDereference = new DereferenceExpression(rowSymbolReference, new Identifier("varchar_symbol_1"));
        Expression stringLiteral = new StringLiteral("abcd");
        Expression doubleSymbolReference = new SymbolReference("double_symbol_1");

        // Symbol references, row-field dereferences and literals are expected to translate as a whole.
        assertFullTranslation(doubleSymbolReference);
        assertFullTranslation(intFieldDereference);
        assertFullTranslation(stringLiteral);

        // For these composite expressions only the operands/arguments are expected to be extracted.
        List<Expression> addOperands = ImmutableList.of(doubleSymbolReference, intFieldDereference);
        assertPartialTranslation(
                new ArithmeticBinaryExpression(ArithmeticBinaryExpression.Operator.ADD, doubleSymbolReference, intFieldDereference),
                addOperands);

        List<Expression> concatArguments = ImmutableList.of(stringLiteral, varcharFieldDereference);
        assertPartialTranslation(new FunctionCall(QualifiedName.of("concat"), concatArguments), concatArguments);
    }

    /**
     * Asserts that partially translating {@code expression} yields as many entries as there are expected
     * subexpressions, and that each expected subexpression maps to its standalone translation.
     *
     * @param expression the expression to partially translate
     * @param subexpressions the subexpressions expected to be the keys of the result
     */
    private void assertPartialTranslation(Expression expression, List<Expression> subexpressions) {
        Map<NodeRef<Expression>, ConnectorExpression> translation = partiallyTranslate(expression);
        // TestNG's assertEquals takes (actual, expected).
        Assert.assertEquals(translation.size(), subexpressions.size());
        for (Expression subexpression : subexpressions) {
            Assert.assertEquals(translation.get(NodeRef.of(subexpression)), translateFully(subexpression));
        }
    }

    /**
     * Asserts that partially translating {@code expression} yields exactly one entry, keyed by {@code expression}
     * itself and equal to its standalone translation.
     *
     * @param expression the expression expected to translate as a whole
     */
    private void assertFullTranslation(Expression expression) {
        Map<NodeRef<Expression>, ConnectorExpression> translation = partiallyTranslate(expression);
        // getOnlyElement fails if the result does not contain exactly one entry.
        Assert.assertEquals(Iterables.getOnlyElement(translation.keySet()), NodeRef.of(expression));
        Assert.assertEquals(Iterables.getOnlyElement(translation.values()), translateFully(expression));
    }

    /** Runs {@link PartialTranslator#extractPartialTranslations} with this test's session, analyzer and types. */
    private static Map<NodeRef<Expression>, ConnectorExpression> partiallyTranslate(Expression expression) {
        return PartialTranslator.extractPartialTranslations(expression, TEST_SESSION, TYPE_ANALYZER, TYPE_PROVIDER);
    }

    /** Translates {@code expression} on its own; the test fails if no translation is produced. */
    private static ConnectorExpression translateFully(Expression expression) {
        return ConnectorExpressionTranslator.translate(TEST_SESSION, expression, TYPE_ANALYZER, TYPE_PROVIDER).get();
    }
}
