package io.trino.cost;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.primitives.ImmutableLongArray;
import io.trino.execution.scheduler.faulttolerant.OutputStatsEstimator;
import io.trino.operator.RetryPolicy;
import io.trino.spi.type.BigintType;
import io.trino.spi.type.DoubleType;
import io.trino.spi.type.VarcharType;
import io.trino.sql.planner.Partitioning;
import io.trino.sql.planner.PartitioningScheme;
import io.trino.sql.planner.PlanFragment;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.SystemPartitioningHandle;
import io.trino.sql.planner.plan.ExchangeNode;
import io.trino.sql.planner.plan.PlanFragmentId;
import io.trino.sql.planner.plan.PlanNodeId;
import io.trino.sql.planner.plan.TableScanNode;
import io.trino.testing.TestingHandles;
import io.trino.testing.TestingMetadata;
import java.util.Optional;
import org.junit.jupiter.api.Test;

/* loaded from: input_file:io/trino/cost/TestRemoteSourceStatsRule.class */
/**
 * Verifies the stats estimated for a remote source node that reads fragment {@code "fragment"}.
 *
 * <p>The upstream fragment is a table scan whose planned stats report 10,000 rows. A runtime
 * output stats estimate is also supplied. The assertions expect 1,000,000 output rows, which is
 * the value carried by that runtime estimate. They also expect the per-column distinct counts,
 * null fractions and value ranges from the fragment's planned stats.
 */
public class TestRemoteSourceStatsRule
        extends BaseStatsCalculatorTest
{
    private static final PlanFragmentId FRAGMENT_ID = new PlanFragmentId("fragment");
    private static final PlanNodeId TABLE_SCAN_ID = new PlanNodeId("plan_id");

    private static final Symbol COL_A = new Symbol(VarcharType.VARCHAR, "col_a");
    private static final Symbol COL_B = new Symbol(VarcharType.VARCHAR, "col_b");
    private static final Symbol COL_C = new Symbol(BigintType.BIGINT, "col_c");
    private static final Symbol COL_D = new Symbol(DoubleType.DOUBLE, "col_d");

    @Test
    public void testStatsRule()
    {
        assertRemoteSourceStats(0.1, 3325.333333);
    }

    @Test
    public void testStatsRuleWithNaNNullFraction()
    {
        assertRemoteSourceStats(Double.NaN, 2992.0);
    }

    /**
     * Builds a remote source over the four test columns and checks its estimated stats.
     *
     * @param nullsFraction null fraction placed on every column of the upstream fragment's stats;
     *     each output column is expected to report the same value
     * @param expectedVarcharRowSize expected average row size of the two varchar columns. The
     *     bigint and double columns are expected to report {@code NaN}.
     */
    private void assertRemoteSourceStats(double nullsFraction, double expectedVarcharRowSize)
    {
        RuntimeInfoProvider runtimeInfoProvider = createRuntimeInfoProvider(createStatsAndCosts(nullsFraction));

        tester()
                .assertStatsFor(planBuilder -> planBuilder.remoteSource(
                        ImmutableList.of(FRAGMENT_ID),
                        ImmutableList.of(
                                planBuilder.symbol("col_a", VarcharType.VARCHAR),
                                planBuilder.symbol("col_b", VarcharType.VARCHAR),
                                planBuilder.symbol("col_c", BigintType.BIGINT),
                                planBuilder.symbol("col_d", DoubleType.DOUBLE)),
                        Optional.empty(),
                        ExchangeNode.Type.REPARTITION,
                        RetryPolicy.TASK))
                .withRuntimeInfoProvider(runtimeInfoProvider)
                .check(stats -> stats
                        // Matches the runtime estimate's 1,000,000, not the planned 10,000 rows.
                        .outputRowsCount(1_000_000.0)
                        .symbolStats(COL_A, colA -> colA
                                .averageRowSize(expectedVarcharRowSize)
                                .distinctValuesCount(100.0)
                                .nullsFraction(nullsFraction)
                                .lowValueUnknown()
                                .highValueUnknown())
                        .symbolStats(COL_B, colB -> colB
                                .averageRowSize(expectedVarcharRowSize)
                                .distinctValuesCount(233.0)
                                .nullsFraction(nullsFraction)
                                .lowValueUnknown()
                                .highValueUnknown())
                        .symbolStats(COL_C, colC -> colC
                                .averageRowSize(Double.NaN)
                                .distinctValuesCount(98.0)
                                .nullsFraction(nullsFraction)
                                .highValue(100.0)
                                .lowValue(3.0))
                        .symbolStats(COL_D, colD -> colD
                                .averageRowSize(Double.NaN)
                                .distinctValuesCount(300.0)
                                .nullsFraction(nullsFraction)
                                .highValue(100.0)
                                .lowValue(3.0)));
    }

    /**
     * Creates a provider that exposes a single fragment and a runtime output estimate for it.
     * Both are keyed by that fragment's id.
     */
    private RuntimeInfoProvider createRuntimeInfoProvider(StatsAndCosts statsAndCosts)
    {
        PlanFragment fragment = createPlanFragment(statsAndCosts);
        PlanFragmentId fragmentId = fragment.getId();
        return new StaticRuntimeInfoProvider(
                ImmutableMap.of(fragmentId, createRuntimeOutputStatsEstimate()),
                ImmutableMap.of(fragmentId, fragment));
    }

    /** Returns the fixed runtime output estimate used by every test in this class. */
    private OutputStatsEstimator.OutputStatsEstimateResult createRuntimeOutputStatsEstimate()
    {
        // NOTE(review): presumably per-partition output data sizes (bytes) -- confirm against OutputStatsEstimateResult.
        ImmutableLongArray partitionSizes = ImmutableLongArray.of(1_000_000_000L, 2_000_000_000L, 3_000_000_000L);
        // 1,000,000 equals the outputRowsCount asserted above; presumably the runtime row count.
        return new OutputStatsEstimator.OutputStatsEstimateResult(partitionSizes, 1_000_000L, "FINISHED", true);
    }

    /**
     * Builds the upstream fragment. It is a source-distributed scan of the four test columns
     * with a single-distribution output partitioning on {@code col_c}. It carries the given
     * planned stats.
     */
    private PlanFragment createPlanFragment(StatsAndCosts statsAndCosts)
    {
        TableScanNode tableScan = TableScanNode.newInstance(
                TABLE_SCAN_ID,
                TestingHandles.TEST_TABLE_HANDLE,
                ImmutableList.of(COL_A, COL_B, COL_C, COL_D),
                ImmutableMap.of(
                        COL_A, new TestingMetadata.TestingColumnHandle("col_a", 0, VarcharType.VARCHAR),
                        COL_B, new TestingMetadata.TestingColumnHandle("col_b", 1, VarcharType.VARCHAR),
                        COL_C, new TestingMetadata.TestingColumnHandle("col_c", 2, BigintType.BIGINT),
                        COL_D, new TestingMetadata.TestingColumnHandle("col_d", 3, DoubleType.DOUBLE)),
                false,
                Optional.empty());

        PartitioningScheme outputPartitioning = new PartitioningScheme(
                Partitioning.create(SystemPartitioningHandle.SINGLE_DISTRIBUTION, ImmutableList.of()),
                ImmutableList.of(COL_C));

        return new PlanFragment(
                FRAGMENT_ID,
                tableScan,
                ImmutableSet.of(COL_A, COL_B, COL_C, COL_D),
                SystemPartitioningHandle.SOURCE_DISTRIBUTION,
                Optional.empty(),
                ImmutableList.of(TABLE_SCAN_ID),
                outputPartitioning,
                statsAndCosts,
                ImmutableList.of(),
                ImmutableMap.of(),
                Optional.empty());
    }

    /**
     * Planned stats for the table scan: 10,000 rows. Every column gets the given null fraction.
     * The two numeric columns also get the range [3, 100].
     */
    private StatsAndCosts createStatsAndCosts(double nullsFraction)
    {
        PlanNodeStatsEstimate scanStats = new PlanNodeStatsEstimate(
                10_000.0,
                ImmutableMap.of(
                        COL_A, SymbolStatsEstimate.builder()
                                .setNullsFraction(nullsFraction)
                                .setDistinctValuesCount(100.0)
                                .build(),
                        COL_B, SymbolStatsEstimate.builder()
                                .setNullsFraction(nullsFraction)
                                .setDistinctValuesCount(233.0)
                                .build(),
                        COL_C, SymbolStatsEstimate.builder()
                                .setNullsFraction(nullsFraction)
                                .setDistinctValuesCount(98.0)
                                .setHighValue(100.0)
                                .setLowValue(3.0)
                                .build(),
                        COL_D, SymbolStatsEstimate.builder()
                                .setNullsFraction(nullsFraction)
                                .setDistinctValuesCount(300.0)
                                .setHighValue(100.0)
                                .setLowValue(3.0)
                                .build()));
        return new StatsAndCosts(ImmutableMap.of(TABLE_SCAN_ID, scanStats), ImmutableMap.of());
    }
}
