package io.prestosql.type;

import com.google.common.base.Throwables;
import io.prestosql.operator.scalar.AbstractTestFunctions;
import io.prestosql.spi.PrestoException;
import io.prestosql.spi.StandardErrorCode;
import io.prestosql.spi.block.Block;
import io.prestosql.spi.function.BlockIndex;
import io.prestosql.spi.function.BlockPosition;
import io.prestosql.spi.function.Convention;
import io.prestosql.spi.function.FunctionDependency;
import io.prestosql.spi.function.InvocationConvention;
import io.prestosql.spi.function.ScalarFunction;
import io.prestosql.spi.function.SqlType;
import io.prestosql.spi.type.IntegerType;
import java.lang.invoke.MethodHandle;
import org.testng.annotations.BeforeClass;
import org.testng.annotations.Test;

/* loaded from: input_file:io/prestosql/type/TestConventionDependencies.class */
/**
 * Exercises scalar functions that obtain another scalar function ({@code add}) as a
 * {@link MethodHandle} through {@link FunctionDependency}, each asking for a specific
 * {@link Convention}.
 */
public class TestConventionDependencies extends AbstractTestFunctions {

    /**
     * Registers the scalar functions declared as nested classes below.
     */
    @BeforeClass
    public void setUp() {
        registerParametricScalar(RegularConvention.class);
        registerParametricScalar(BlockPositionConvention.class);
        registerParametricScalar(Add.class);
    }

    /**
     * Checks the results of {@code regular_convention} and {@code block_position_convention}
     * for a handful of integer inputs.
     */
    @Test
    public void testConventionDependencies() {
        // Two plain integer arguments.
        assertFunction("regular_convention(1, 1)", IntegerType.INTEGER, 2);
        assertFunction("regular_convention(50, 10)", IntegerType.INTEGER, 60);
        assertFunction("regular_convention(1, 0)", IntegerType.INTEGER, 1);

        // A single array argument whose elements are summed.
        assertFunction("block_position_convention(ARRAY [1, 2, 3])", IntegerType.INTEGER, 6);
        assertFunction("block_position_convention(ARRAY [25, 0, 5])", IntegerType.INTEGER, 30);
        assertFunction("block_position_convention(ARRAY [56, 275, 36])", IntegerType.INTEGER, 367);
    }

    /**
     * Scalar function {@code add} with two implementations: one taking both operands as
     * {@code long} values, one taking the second operand as a block plus an index.
     */
    @ScalarFunction("add")
    public static class Add {
        /**
         * Adds the two operands after narrowing each to {@code int}.
         *
         * @throws ArithmeticException if the {@code int} sum overflows
         */
        @SqlType("integer")
        public static long add(@SqlType("integer") long j, @SqlType("integer") long j2) {
            int left = (int) j;
            int right = (int) j2;
            return Math.addExact(left, right);
        }

        /**
         * Adds {@code j} to the integer read from {@code block} at index {@code i}, both
         * narrowed to {@code int}.
         *
         * @throws ArithmeticException if the {@code int} sum overflows
         */
        @SqlType("integer")
        public static long addBlockPosition(
                @SqlType("integer") long j,
                @SqlType(value = "integer", nativeContainerType = long.class) @BlockPosition Block block,
                @BlockIndex int i) {
            int left = (int) j;
            int right = (int) IntegerType.INTEGER.getLong(block, i);
            return Math.addExact(left, right);
        }
    }

    /**
     * Scalar function {@code block_position_convention}: folds every position of an array
     * block through an {@code add} handle requested with a {@code BLOCK_POSITION} second argument.
     */
    @ScalarFunction("block_position_convention")
    public static class BlockPositionConvention {
        /**
         * Starts from zero and, for each position of {@code block}, replaces the running total
         * with the result of invoking {@code methodHandle} on (total, block, position).
         *
         * @return the final running total
         * @throws PrestoException if the handle fails with anything other than an {@link Error}
         */
        @SqlType("integer")
        public static long testBlockPositionConvention(
                @FunctionDependency(
                        name = "add",
                        returnType = "integer",
                        argumentTypes = {"integer", "integer"},
                        convention = @Convention(
                                arguments = {
                                    InvocationConvention.InvocationArgumentConvention.NEVER_NULL,
                                    InvocationConvention.InvocationArgumentConvention.BLOCK_POSITION
                                },
                                result = InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL))
                        MethodHandle methodHandle,
                @SqlType("array(int)") Block block) {
            long sum = 0;
            for (int position = 0; position < block.getPositionCount(); position++) {
                try {
                    // invokeExact is signature-polymorphic: the (long) cast and the static
                    // argument types (long, Block, int) make up the call-site type.
                    sum = (long) methodHandle.invokeExact(sum, block, position);
                } catch (Throwable failure) {
                    throw asPrestoException(failure);
                }
            }
            return sum;
        }
    }

    /**
     * Scalar function {@code regular_convention}: forwards its two integer arguments to an
     * {@code add} handle requested with {@code NEVER_NULL} for both arguments.
     */
    @ScalarFunction("regular_convention")
    public static class RegularConvention {
        /**
         * Invokes {@code methodHandle} with {@code j} and {@code j2} and returns its result.
         *
         * @throws PrestoException if the handle fails with anything other than an {@link Error}
         */
        @SqlType("integer")
        public static long testRegularConvention(
                @FunctionDependency(
                        name = "add",
                        returnType = "integer",
                        argumentTypes = {"integer", "integer"},
                        convention = @Convention(
                                arguments = {
                                    InvocationConvention.InvocationArgumentConvention.NEVER_NULL,
                                    InvocationConvention.InvocationArgumentConvention.NEVER_NULL
                                },
                                result = InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL))
                        MethodHandle methodHandle,
                @SqlType("integer") long j,
                @SqlType("integer") long j2) {
            try {
                // The (long) cast fixes the call-site return type required by invokeExact.
                long result = (long) methodHandle.invokeExact(j, j2);
                return result;
            } catch (Throwable failure) {
                throw asPrestoException(failure);
            }
        }
    }

    /**
     * Shared failure translation for the method-handle invocations above: an {@link Error} or a
     * {@link PrestoException} is rethrown as-is; anything else is wrapped.
     *
     * @param failure the throwable caught around {@code invokeExact}
     * @return a new {@link PrestoException} with {@code GENERIC_INTERNAL_ERROR} and
     *         {@code failure} as its cause, for the caller to throw
     */
    private static PrestoException asPrestoException(Throwable failure) {
        Throwables.throwIfInstanceOf(failure, Error.class);
        Throwables.throwIfInstanceOf(failure, PrestoException.class);
        return new PrestoException(StandardErrorCode.GENERIC_INTERNAL_ERROR, failure);
    }
}
