package io.trino.operator.aggregation;

import com.google.common.collect.ImmutableList;
import io.airlift.slice.Slices;
import io.airlift.stats.TDigest;
import io.trino.block.BlockAssertions;
import io.trino.metadata.TestingFunctionResolution;
import io.trino.spi.Page;
import io.trino.spi.block.Block;
import io.trino.spi.type.DoubleType;
import io.trino.spi.type.SqlVarbinary;
import io.trino.spi.type.Type;
import io.trino.sql.analyzer.TypeSignatureProvider;
import io.trino.sql.tree.QualifiedName;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.function.BiFunction;
import java.util.stream.LongStream;
import org.junit.jupiter.api.Test;
import org.testng.Assert;

/* loaded from: input_file:io/trino/operator/aggregation/TestTDigestAggregationFunction.class */
/**
 * Tests for the {@code tdigest_agg} aggregation function in both its unweighted
 * ({@code tdigest_agg(value)}) and weighted ({@code tdigest_agg(value, weight)}) forms.
 *
 * <p>For each case the expected result is built locally as a {@link TDigest} from the
 * non-null input values and their weights, serialized, and compared against the
 * aggregation's output.
 */
public class TestTDigestAggregationFunction {
    private static final TestingFunctionResolution FUNCTION_RESOLUTION = new TestingFunctionResolution();

    private static final QualifiedName TDIGEST_AGG = QualifiedName.of("tdigest_agg");

    /** Quantiles probed by {@link #similarQuantiles} when comparing two digests approximately. */
    private static final double[] QUANTILES = {0.0001, 0.001, 0.01, 0.1, 0.5, 0.567, 0.89, 0.999};

    /**
     * Approximate equality for serialized digests (values are expected to be {@link SqlVarbinary}).
     *
     * <p>Two nulls are equal. If exactly one side is null, a {@link NullPointerException} naming
     * that side is thrown. Otherwise the digests must have identical min and max, and every probed
     * quantile must agree to within 1/1000 of the actual digest's value range.
     */
    private static final BiFunction<Object, Object, Boolean> TDIGEST_EQUALITY = (actualValue, expectedValue) -> {
        if (actualValue == null && expectedValue == null) {
            return true;
        }
        Objects.requireNonNull(actualValue, "actual value was null");
        Objects.requireNonNull(expectedValue, "expected value was null");

        TDigest actual = toDigest(actualValue);
        TDigest expected = toDigest(expectedValue);
        double tolerance = (actual.getMax() - actual.getMin()) / 1000.0;
        return actual.getMin() == expected.getMin()
                && actual.getMax() == expected.getMax()
                && similarQuantiles(actual, expected, tolerance);
    };

    @Test
    public void testTdigestAggregationFunction() {
        List<Double> weights = ImmutableList.of(1.5, 2.0, 1.1, 1.111, 3.5, 4.4, 4.4, 1.0, 9.9, 9.0);

        // Null values are expected to be skipped together with the weight at the same position.
        testAggregation(
                BlockAssertions.createDoublesBlock(1.0, null, 2.0, null, 3.0, null, 4.0, null, 5.0, null),
                BlockAssertions.createDoublesBlock(weights),
                ImmutableList.of(1.5, 1.1, 3.5, 4.4, 9.9),
                1.0, 2.0, 3.0, 4.0, 5.0);

        // Only nulls: no values contribute, so the expected result is null.
        testAggregation(
                BlockAssertions.createDoublesBlock(null, null, null, null, null),
                BlockAssertions.createRepeatedValuesBlock(1.0, 5),
                ImmutableList.of());

        // Negative values.
        testAggregation(
                BlockAssertions.createDoublesBlock(-1.0, -2.0, -3.0, -4.0, -5.0, -6.0, -7.0, -8.0, -9.0, -10.0),
                BlockAssertions.createDoublesBlock(weights),
                weights,
                -1.0, -2.0, -3.0, -4.0, -5.0, -6.0, -7.0, -8.0, -9.0, -10.0);

        // Positive values.
        testAggregation(
                BlockAssertions.createDoublesBlock(1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0),
                BlockAssertions.createDoublesBlock(weights),
                weights,
                1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0);

        // Empty input: expected result is null.
        testAggregation(
                BlockAssertions.createDoublesBlock(new Double[0]),
                BlockAssertions.createRepeatedValuesBlock(1.0, 0),
                ImmutableList.of());

        // Single value.
        testAggregation(
                BlockAssertions.createDoublesBlock(1.0),
                BlockAssertions.createRepeatedValuesBlock(1.1, 1),
                ImmutableList.of(1.1),
                1.0);

        // 2000 values (-1000..999) with weights decreasing linearly from 3.0 to 1.001.
        // Compared with TDIGEST_EQUALITY (approximate quantiles) rather than exact bytes; presumably
        // because digests built along different paths need not serialize identically.
        List<Double> descendingWeights = LongStream.range(-1000L, 1000L)
                .asDoubleStream()
                .map(value -> 2.0 - (value / 1000.0))
                .boxed()
                .collect(ImmutableList.toImmutableList());
        testAggregation(
                TDIGEST_EQUALITY,
                BlockAssertions.createDoubleSequenceBlock(-1000, 1000),
                BlockAssertions.createDoublesBlock(descendingWeights),
                descendingWeights,
                LongStream.range(-1000L, 1000L).asDoubleStream().toArray());
    }

    /**
     * Asserts {@code tdigest_agg} results using the default (exact) comparison of
     * {@code AggregationTestUtils.assertAggregation}.
     *
     * <p>First checks the unweighted form over {@code values}, expecting each value with weight 1.0.
     * Then checks the weighted form over {@code values} and {@code weights}.
     *
     * @param values input value block (may contain nulls)
     * @param weights input weight block, position-aligned with {@code values}
     * @param expectedWeights weights of the non-null values, in input order
     * @param expectedValues the non-null values, in input order
     */
    private void testAggregation(Block values, Block weights, List<Double> expectedWeights, double... expectedValues) {
        AggregationTestUtils.assertAggregation(
                FUNCTION_RESOLUTION,
                TDIGEST_AGG,
                TypeSignatureProvider.fromTypes(DoubleType.DOUBLE),
                getExpectedValue(Collections.nCopies(expectedValues.length, 1.0), expectedValues),
                new Page(values));
        AggregationTestUtils.assertAggregation(
                FUNCTION_RESOLUTION,
                TDIGEST_AGG,
                TypeSignatureProvider.fromTypes(DoubleType.DOUBLE, DoubleType.DOUBLE),
                getExpectedValue(expectedWeights, expectedValues),
                new Page(values, weights));
    }

    /**
     * Same as {@link #testAggregation(Block, Block, List, double...)}, but compares results with
     * the supplied {@code equality} function instead of the default comparison.
     *
     * @param equality comparison applied to (actual, expected) results
     * @param values input value block
     * @param weights input weight block, position-aligned with {@code values}
     * @param expectedWeights weights of the non-null values, in input order
     * @param expectedValues the non-null values, in input order
     */
    private void testAggregation(
            BiFunction<Object, Object, Boolean> equality,
            Block values,
            Block weights,
            List<Double> expectedWeights,
            double... expectedValues) {
        AggregationTestUtils.assertAggregation(
                FUNCTION_RESOLUTION,
                TDIGEST_AGG,
                TypeSignatureProvider.fromTypes(DoubleType.DOUBLE),
                equality,
                "Test multiple values",
                new Page(values),
                getExpectedValue(Collections.nCopies(expectedValues.length, 1.0), expectedValues));
        AggregationTestUtils.assertAggregation(
                FUNCTION_RESOLUTION,
                TDIGEST_AGG,
                TypeSignatureProvider.fromTypes(DoubleType.DOUBLE, DoubleType.DOUBLE),
                equality,
                "Test multiple values",
                new Page(values, weights),
                getExpectedValue(expectedWeights, expectedValues));
    }

    /**
     * Builds the expected aggregation result: a {@link TDigest} containing each value with its
     * weight, serialized as {@link SqlVarbinary}.
     *
     * @param weights weight for each entry of {@code values}; must have the same length
     * @param values the values to add
     * @return the serialized digest, or {@code null} when {@code values} is empty
     */
    private Object getExpectedValue(List<Double> weights, double... values) {
        Assert.assertEquals(weights.size(), values.length, "mismatched weights and values");
        if (values.length == 0) {
            return null;
        }
        TDigest digest = new TDigest();
        for (int i = 0; i < values.length; i++) {
            digest.add(values[i], weights.get(i));
        }
        return new SqlVarbinary(digest.serialize().getBytes());
    }

    /** Deserializes a {@link SqlVarbinary} aggregation result into a {@link TDigest}. */
    private static TDigest toDigest(Object serialized) {
        return TDigest.deserialize(Slices.wrappedBuffer(((SqlVarbinary) serialized).getBytes()));
    }

    /**
     * Returns whether {@code actual} and {@code expected} agree to within {@code tolerance}
     * (inclusive) at every quantile in {@link #QUANTILES}.
     */
    private static boolean similarQuantiles(TDigest actual, TDigest expected, double tolerance) {
        for (double quantile : QUANTILES) {
            if (Math.abs(actual.valueAt(quantile) - expected.valueAt(quantile)) > tolerance) {
                return false;
            }
        }
        return true;
    }
}
