package io.prestosql.operator.aggregation;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import io.airlift.bytecode.DynamicClassLoader;
import io.airlift.slice.Slice;
import io.prestosql.metadata.BoundVariables;
import io.prestosql.metadata.FunctionKind;
import io.prestosql.metadata.Metadata;
import io.prestosql.metadata.SignatureBinder;
import io.prestosql.metadata.SqlAggregationFunction;
import io.prestosql.operator.aggregation.AggregationMetadata;
import io.prestosql.operator.aggregation.state.LongDecimalWithOverflowState;
import io.prestosql.operator.aggregation.state.LongDecimalWithOverflowStateFactory;
import io.prestosql.operator.aggregation.state.LongDecimalWithOverflowStateSerializer;
import io.prestosql.spi.block.Block;
import io.prestosql.spi.block.BlockBuilder;
import io.prestosql.spi.type.DecimalType;
import io.prestosql.spi.type.Type;
import io.prestosql.spi.type.TypeSignature;
import io.prestosql.spi.type.UnscaledDecimal128Arithmetic;
import io.prestosql.util.Reflection;
import java.lang.invoke.MethodHandle;
import java.util.List;

/**
 * {@code sum} aggregation over {@code decimal(p, s)} inputs, producing {@code decimal(38, s)}.
 *
 * <p>Every input value is added into a 128-bit unscaled decimal kept in a
 * {@link LongDecimalWithOverflowState}. The overflow amount reported by
 * {@link UnscaledDecimal128Arithmetic#addWithOverflow} is accumulated in the state's overflow
 * counter rather than thrown immediately; an error is raised only when the final value is
 * written and the counter is non-zero or the 128-bit value does not fit the result type.
 */
public class DecimalSumAggregation extends SqlAggregationFunction {
    private static final String NAME = "sum";

    public static final DecimalSumAggregation DECIMAL_SUM_AGGREGATION = new DecimalSumAggregation();

    private static final MethodHandle SHORT_DECIMAL_INPUT_FUNCTION = Reflection.methodHandle(
            DecimalSumAggregation.class, "inputShortDecimal",
            Type.class, LongDecimalWithOverflowState.class, Block.class, Integer.TYPE);
    private static final MethodHandle LONG_DECIMAL_INPUT_FUNCTION = Reflection.methodHandle(
            DecimalSumAggregation.class, "inputLongDecimal",
            Type.class, LongDecimalWithOverflowState.class, Block.class, Integer.TYPE);
    private static final MethodHandle LONG_DECIMAL_OUTPUT_FUNCTION = Reflection.methodHandle(
            DecimalSumAggregation.class, "outputLongDecimal",
            DecimalType.class, LongDecimalWithOverflowState.class, BlockBuilder.class);
    private static final MethodHandle COMBINE_FUNCTION = Reflection.methodHandle(
            DecimalSumAggregation.class, "combine",
            LongDecimalWithOverflowState.class, LongDecimalWithOverflowState.class);

    /**
     * Declares the signature {@code sum(decimal(p, s)) -> decimal(38, s)}: the input scale is
     * preserved and the precision is widened to 38.
     */
    public DecimalSumAggregation() {
        super(
                NAME,
                ImmutableList.of(),
                ImmutableList.of(),
                TypeSignature.parseTypeSignature("decimal(38,s)", ImmutableSet.of("s")),
                ImmutableList.of(TypeSignature.parseTypeSignature("decimal(p,s)", ImmutableSet.of("p", "s"))),
                FunctionKind.AGGREGATE);
    }

    @Override
    public String getDescription() {
        return "Calculates the sum over the input values";
    }

    /**
     * Binds the generic signature to concrete decimal types and builds the aggregation for them.
     *
     * @param boundVariables values bound for the {@code p} and {@code s} type variables
     * @param arity number of arguments (unused; the signature has exactly one)
     * @param metadata used to resolve the bound type signatures to {@link Type}s
     */
    @Override
    public InternalAggregationFunction specialize(BoundVariables boundVariables, int arity, Metadata metadata) {
        TypeSignature inputSignature = Iterables.getOnlyElement(
                SignatureBinder.applyBoundVariables(getSignature().getArgumentTypes(), boundVariables));
        TypeSignature outputSignature =
                SignatureBinder.applyBoundVariables(getSignature().getReturnType(), boundVariables);
        return generateAggregation(metadata.getType(inputSignature), metadata.getType(outputSignature));
    }

    /**
     * Wires the input/combine/output method handles and the state descriptor into an
     * {@link InternalAggregationFunction}.
     *
     * @param inputType the concrete decimal input type; must be a {@link DecimalType}
     * @param outputType the concrete result type, bound into the output function
     * @throws IllegalArgumentException if {@code inputType} is not a decimal type
     */
    private static InternalAggregationFunction generateAggregation(Type inputType, Type outputType) {
        Preconditions.checkArgument(inputType instanceof DecimalType, "type must be Decimal");

        DynamicClassLoader classLoader = new DynamicClassLoader(DecimalSumAggregation.class.getClassLoader());
        List<Type> inputTypes = ImmutableList.of(inputType);
        LongDecimalWithOverflowStateSerializer stateSerializer = new LongDecimalWithOverflowStateSerializer();

        // Short decimals are read with getLong and widened to a 128-bit slice per row;
        // long decimals are read directly as slices.
        MethodHandle inputFunction = ((DecimalType) inputType).isShort()
                ? SHORT_DECIMAL_INPUT_FUNCTION
                : LONG_DECIMAL_INPUT_FUNCTION;

        List<TypeSignature> inputSignatures = inputTypes.stream()
                .map(Type::getTypeSignature)
                .collect(ImmutableList.toImmutableList());

        AggregationMetadata aggregationMetadata = new AggregationMetadata(
                AggregationUtils.generateAggregationName(NAME, outputType.getTypeSignature(), inputSignatures),
                createInputParameterMetadata(inputType),
                inputFunction.bindTo(inputType),
                COMBINE_FUNCTION,
                LONG_DECIMAL_OUTPUT_FUNCTION.bindTo(outputType),
                ImmutableList.of(new AggregationMetadata.AccumulatorStateDescriptor(
                        LongDecimalWithOverflowState.class,
                        stateSerializer,
                        new LongDecimalWithOverflowStateFactory())),
                outputType);

        // NOTE(review): the two booleans are presumably (decomposable, orderSensitive) — confirm
        // against the InternalAggregationFunction constructor.
        return new InternalAggregationFunction(
                NAME,
                inputTypes,
                ImmutableList.of(stateSerializer.getSerializedType()),
                outputType,
                true,
                false,
                AccumulatorCompiler.generateAccumulatorFactoryBinder(aggregationMetadata, classLoader));
    }

    /** Input function parameters: the state, the value block for {@code type}, and the row index. */
    private static List<AggregationMetadata.ParameterMetadata> createInputParameterMetadata(Type type) {
        return ImmutableList.of(
                new AggregationMetadata.ParameterMetadata(AggregationMetadata.ParameterMetadata.ParameterType.STATE),
                new AggregationMetadata.ParameterMetadata(AggregationMetadata.ParameterMetadata.ParameterType.BLOCK_INPUT_CHANNEL, type),
                new AggregationMetadata.ParameterMetadata(AggregationMetadata.ParameterMetadata.ParameterType.BLOCK_INDEX));
    }

    /** Adds the short decimal at {@code position} of {@code block} to {@code state}. */
    public static void inputShortDecimal(Type type, LongDecimalWithOverflowState state, Block block, int position) {
        accumulateValueInState(UnscaledDecimal128Arithmetic.unscaledDecimal(type.getLong(block, position)), state);
    }

    /** Adds the long decimal at {@code position} of {@code block} to {@code state}. */
    public static void inputLongDecimal(Type type, LongDecimalWithOverflowState state, Block block, int position) {
        accumulateValueInState(type.getSlice(block, position), state);
    }

    /**
     * Adds {@code value} to the state's running total in place, initialising the total to zero
     * first if the state is empty. Any overflow reported by the addition is added to the state's
     * overflow counter instead of being thrown.
     */
    private static void accumulateValueInState(Slice value, LongDecimalWithOverflowState state) {
        initializeIfNeeded(state);
        Slice sum = state.getLongDecimal();
        // The sum slice is passed as the result argument, so the addition overwrites it in place.
        long overflow = UnscaledDecimal128Arithmetic.addWithOverflow(sum, value, sum);
        state.setOverflow(state.getOverflow() + overflow);
    }

    /** Gives an empty state (null long decimal) a freshly allocated zero value. */
    private static void initializeIfNeeded(LongDecimalWithOverflowState state) {
        if (state.getLongDecimal() == null) {
            state.setLongDecimal(UnscaledDecimal128Arithmetic.unscaledDecimal());
        }
    }

    /**
     * Merges {@code otherState} into {@code state}.
     *
     * <p>Overflow counters are summed. A null long decimal in {@code otherState} means it saw no
     * input, so it contributes no value; without this check a non-empty {@code state} would pass
     * a null operand to {@code addWithOverflow}. When {@code state} is empty, the value is added
     * to a fresh zero slice rather than sharing {@code otherState}'s slice, because
     * {@link #accumulateValueInState} mutates the state's slice in place and would otherwise also
     * modify {@code otherState}.
     */
    public static void combine(LongDecimalWithOverflowState state, LongDecimalWithOverflowState otherState) {
        state.setOverflow(state.getOverflow() + otherState.getOverflow());
        Slice otherValue = otherState.getLongDecimal();
        if (otherValue != null) {
            accumulateValueInState(otherValue, state);
        }
    }

    /**
     * Writes the final sum, or null if no input was seen.
     *
     * @throws RuntimeException (raised by {@link UnscaledDecimal128Arithmetic}) if the overflow
     *         counter is non-zero or the 128-bit total overflows
     */
    public static void outputLongDecimal(DecimalType type, LongDecimalWithOverflowState state, BlockBuilder out) {
        Slice sum = state.getLongDecimal();
        if (sum == null) {
            out.appendNull();
            return;
        }
        if (state.getOverflow() != 0) {
            UnscaledDecimal128Arithmetic.throwOverflowException();
        }
        UnscaledDecimal128Arithmetic.throwIfOverflows(sum);
        type.writeSlice(out, sum);
    }
}
