package io.prestosql.operator.scalar;

import com.google.common.collect.ImmutableList;
import com.google.common.primitives.Primitives;
import io.airlift.bytecode.Access;
import io.airlift.bytecode.BytecodeBlock;
import io.airlift.bytecode.ClassDefinition;
import io.airlift.bytecode.MethodDefinition;
import io.airlift.bytecode.Parameter;
import io.airlift.bytecode.ParameterizedType;
import io.airlift.bytecode.Scope;
import io.airlift.bytecode.Variable;
import io.airlift.bytecode.control.ForLoop;
import io.airlift.bytecode.control.IfStatement;
import io.airlift.bytecode.expression.BytecodeExpression;
import io.airlift.bytecode.expression.BytecodeExpressions;
import io.airlift.bytecode.instruction.VariableInstruction;
import io.prestosql.metadata.FunctionArgumentDefinition;
import io.prestosql.metadata.FunctionBinding;
import io.prestosql.metadata.FunctionKind;
import io.prestosql.metadata.FunctionMetadata;
import io.prestosql.metadata.Signature;
import io.prestosql.metadata.SqlScalarFunction;
import io.prestosql.spi.PageBuilder;
import io.prestosql.spi.block.Block;
import io.prestosql.spi.block.BlockBuilder;
import io.prestosql.spi.function.InvocationConvention;
import io.prestosql.spi.type.ArrayType;
import io.prestosql.spi.type.Type;
import io.prestosql.spi.type.TypeSignature;
import io.prestosql.spi.type.TypeSignatureParameter;
import io.prestosql.sql.gen.CallSiteBinder;
import io.prestosql.sql.gen.SqlTypeBytecodeExpression;
import io.prestosql.sql.gen.lambda.UnaryFunctionInterface;
import io.prestosql.type.FunctionType;
import io.prestosql.type.UnknownType;
import io.prestosql.util.CompilerUtils;
import io.prestosql.util.Reflection;
import java.util.List;
import java.util.Optional;

/* loaded from: input_file:io/prestosql/operator/scalar/ArrayTransformFunction.class */
/**
 * Scalar function {@code transform(array(T), function(T, U)) -> array(U)}: applies a lambda to each
 * element of an array and returns the array of results.
 *
 * <p>For every concrete {@code (T, U)} binding, {@link #specialize} generates a class at runtime that
 * contains two static methods: {@code createPageBuilder()} and
 * {@code transform(PageBuilder, Block, UnaryFunctionInterface)}.
 */
public final class ArrayTransformFunction extends SqlScalarFunction {
    public static final ArrayTransformFunction ARRAY_TRANSFORM_FUNCTION = new ArrayTransformFunction();

    private ArrayTransformFunction() {
        super(new FunctionMetadata(
                new Signature(
                        "transform",
                        ImmutableList.of(Signature.typeVariable("T"), Signature.typeVariable("U")),
                        ImmutableList.of(),
                        TypeSignature.arrayType(new TypeSignature("U")),
                        ImmutableList.of(
                                TypeSignature.arrayType(new TypeSignature("T")),
                                TypeSignature.functionType(new TypeSignature("T"), new TypeSignature("U"))),
                        false),
                false,
                ImmutableList.of(
                        new FunctionArgumentDefinition(false),
                        new FunctionArgumentDefinition(false)),
                false,
                false,
                "Apply lambda to each element of the array",
                FunctionKind.SCALAR));
    }

    /**
     * Binds the function to concrete element types by generating a specialized transform class.
     *
     * @param functionBinding supplies the concrete types bound to {@code T} (input element) and
     *                        {@code U} (output element)
     * @return an implementation whose main handle is the generated static {@code transform} method.
     *         Its optional last handle is the generated {@code createPageBuilder} method, which
     *         presumably supplies the {@code PageBuilder} passed as {@code transform}'s first
     *         argument (TODO confirm against {@code ScalarFunctionImplementation}).
     */
    @Override
    protected ScalarFunctionImplementation specialize(FunctionBinding functionBinding) {
        Type inputType = functionBinding.getTypeVariable("T");
        Type outputType = functionBinding.getTypeVariable("U");
        Class<?> generatedClass = generateTransform(inputType, outputType);
        return new ScalarFunctionImplementation(
                InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL,
                ImmutableList.of(
                        InvocationConvention.InvocationArgumentConvention.NEVER_NULL,
                        InvocationConvention.InvocationArgumentConvention.FUNCTION),
                ImmutableList.of(Optional.empty(), Optional.of(UnaryFunctionInterface.class)),
                Reflection.methodHandle(generatedClass, "transform", PageBuilder.class, Block.class, UnaryFunctionInterface.class),
                Optional.of(Reflection.methodHandle(generatedClass, "createPageBuilder")));
    }

    /**
     * Generates and loads a class specialized for the given element types.
     *
     * @param inputType  SQL type of the input array elements ({@code T})
     * @param outputType SQL type of the output array elements ({@code U})
     * @return the generated class, exposing static {@code createPageBuilder} and {@code transform}
     */
    private static Class<?> generateTransform(Type inputType, Type outputType) {
        CallSiteBinder binder = new CallSiteBinder();
        // Elements pass through the lambda as Object, so the generated locals use the boxed types.
        Class<?> inputJavaType = Primitives.wrap(inputType.getJavaType());
        Class<?> outputJavaType = Primitives.wrap(outputType.getJavaType());

        ClassDefinition definition = new ClassDefinition(
                Access.a(Access.PUBLIC, Access.FINAL),
                CompilerUtils.makeClassName("ArrayTransform"),
                ParameterizedType.type(Object.class));
        definition.declareDefaultConstructor(Access.a(Access.PRIVATE));

        defineCreatePageBuilder(definition, binder, outputType);
        defineTransform(definition, binder, inputType, outputType, inputJavaType, outputJavaType);

        return CompilerUtils.defineClass(definition, Object.class, binder.getBindings(), ArrayTransformFunction.class.getClassLoader());
    }

    // Declares `public static PageBuilder createPageBuilder()`. The builder's types come from
    // new ArrayType(outputType).getTypeParameters(), presumably just the output element type.
    private static void defineCreatePageBuilder(ClassDefinition definition, CallSiteBinder binder, Type outputType) {
        BytecodeExpression channelTypes = SqlTypeBytecodeExpression.constantType(binder, new ArrayType(outputType))
                .invoke("getTypeParameters", List.class);
        definition.declareMethod(Access.a(Access.PUBLIC, Access.STATIC), "createPageBuilder", ParameterizedType.type(PageBuilder.class))
                .getBody()
                .append(BytecodeExpressions.newInstance(PageBuilder.class, channelTypes).ret());
    }

    // Declares `public static Block transform(PageBuilder, Block, UnaryFunctionInterface)`, which applies the lambda to every element of the block.
    private static void defineTransform(
            ClassDefinition definition,
            CallSiteBinder binder,
            Type inputType,
            Type outputType,
            Class<?> inputJavaType,
            Class<?> outputJavaType) {
        Parameter pageBuilder = Parameter.arg("pageBuilder", PageBuilder.class);
        Parameter block = Parameter.arg("block", Block.class);
        Parameter function = Parameter.arg(FunctionType.NAME, UnaryFunctionInterface.class);
        MethodDefinition method = definition.declareMethod(
                Access.a(Access.PUBLIC, Access.STATIC),
                "transform",
                ParameterizedType.type(Block.class),
                ImmutableList.of(pageBuilder, block, function));

        BytecodeBlock body = method.getBody();
        Scope scope = method.getScope();
        Variable positionCount = scope.declareVariable(int.class, "positionCount");
        Variable position = scope.declareVariable(int.class, "position");
        Variable blockBuilder = scope.declareVariable(BlockBuilder.class, "blockBuilder");
        Variable inputElement = scope.declareVariable(inputJavaType, "inputElement");
        Variable outputElement = scope.declareVariable(outputJavaType, "outputElement");

        body.append(positionCount.set(block.invoke("getPositionCount", int.class)));

        // The page builder is passed in (not created per call); reset it once it reports full.
        body.append(new IfStatement()
                .condition(pageBuilder.invoke("isFull", boolean.class))
                .ifTrue(pageBuilder.invoke("reset", void.class)));

        body.append(blockBuilder.set(pageBuilder.invoke("getBlockBuilder", BlockBuilder.class, BytecodeExpressions.constantInt(0))));

        // for (position = 0; position < positionCount; position++): load, apply lambda, write.
        body.append(new ForLoop()
                .initialize(position.set(BytecodeExpressions.constantInt(0)))
                .condition(BytecodeExpressions.lessThan(position, positionCount))
                .update(VariableInstruction.incrementVariable(position, (byte) 1))
                .body(new BytecodeBlock()
                        .append(loadInputElement(binder, inputType, inputJavaType, block, position, inputElement))
                        .append(outputElement.set(function.invoke("apply", Object.class, inputElement.cast(Object.class)).cast(outputJavaType)))
                        .append(writeOutputElement(binder, outputType, outputJavaType, blockBuilder, outputElement))));

        body.append(pageBuilder.invoke("declarePositions", void.class, positionCount));

        // The block builder may already hold rows from earlier calls; return only the last
        // positionCount rows, i.e. the ones appended by this call.
        BytecodeExpression regionStart = BytecodeExpressions.subtract(blockBuilder.invoke("getPositionCount", int.class), positionCount);
        body.append(blockBuilder.invoke("getRegion", Block.class, regionStart, positionCount).ret());
    }

    // Builds the loop step that stores element `position` of `block` into `inputElement`, or null if that element is null.
    private static BytecodeBlock loadInputElement(
            CallSiteBinder binder,
            Type inputType,
            Class<?> inputJavaType,
            Variable block,
            Variable position,
            Variable inputElement) {
        BytecodeBlock load = new BytecodeBlock();
        if (inputType.equals(UnknownType.UNKNOWN)) {
            // For UNKNOWN the element is always treated as null and the block is never read.
            return load.append(inputElement.set(BytecodeExpressions.constantNull(inputJavaType)));
        }
        return load.append(new IfStatement()
                .condition(block.invoke("isNull", boolean.class, position))
                .ifTrue(inputElement.set(BytecodeExpressions.constantNull(inputJavaType)))
                .ifFalse(inputElement.set(SqlTypeBytecodeExpression.constantType(binder, inputType).getValue(block, position).cast(inputJavaType))));
    }

    // Builds the loop step that appends `outputElement` to `blockBuilder`, or appends a null entry if it is null.
    private static BytecodeBlock writeOutputElement(
            CallSiteBinder binder,
            Type outputType,
            Class<?> outputJavaType,
            Variable blockBuilder,
            Variable outputElement) {
        BytecodeBlock write = new BytecodeBlock();
        if (outputType.equals(UnknownType.UNKNOWN)) {
            // For UNKNOWN every result is written as null, whatever the lambda returned.
            return write.append(blockBuilder.invoke("appendNull", BlockBuilder.class).pop());
        }
        return write.append(new IfStatement()
                .condition(BytecodeExpressions.equal(outputElement, BytecodeExpressions.constantNull(outputJavaType)))
                .ifTrue(blockBuilder.invoke("appendNull", BlockBuilder.class).pop())
                .ifFalse(SqlTypeBytecodeExpression.constantType(binder, outputType).writeValue(blockBuilder, outputElement.cast(outputType.getJavaType()))));
    }
}
