package io.prestosql.cost;

import com.google.common.collect.ImmutableMap;
import io.prestosql.Session;
import io.prestosql.metadata.MetadataManager;
import io.prestosql.spi.type.DoubleType;
import io.prestosql.spi.type.Type;
import io.prestosql.spi.type.VarcharType;
import io.prestosql.sql.planner.Symbol;
import io.prestosql.sql.planner.TypeProvider;
import io.prestosql.sql.tree.Cast;
import io.prestosql.sql.tree.ComparisonExpression;
import io.prestosql.sql.tree.DoubleLiteral;
import io.prestosql.sql.tree.Expression;
import io.prestosql.sql.tree.LongLiteral;
import io.prestosql.sql.tree.StringLiteral;
import io.prestosql.sql.tree.SymbolReference;
import io.prestosql.testing.TestingSession;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Objects;
import java.util.function.Consumer;
import java.util.stream.Collectors;
import org.testng.annotations.BeforeClass;
import org.testng.annotations.Test;

/* loaded from: input_file:io/prestosql/cost/TestComparisonStatsCalculator.class */
public class TestComparisonStatsCalculator {
    private FilterStatsCalculator filterStatsCalculator;
    private Session session;
    private PlanNodeStatsEstimate standardInputStatistics;
    private TypeProvider types;
    private SymbolStatsEstimate uStats;
    private SymbolStatsEstimate wStats;
    private SymbolStatsEstimate xStats;
    private SymbolStatsEstimate yStats;
    private SymbolStatsEstimate zStats;
    private SymbolStatsEstimate leftOpenStats;
    private SymbolStatsEstimate rightOpenStats;
    private SymbolStatsEstimate unknownRangeStats;
    private SymbolStatsEstimate emptyRangeStats;
    private SymbolStatsEstimate varcharStats;

    /**
     * Builds the fixtures shared by every test in this class: a testing session, the
     * {@link FilterStatsCalculator} under test, per-symbol input statistics, the combined
     * input estimate (1000 output rows) and the symbol-to-type mapping.
     *
     * <p>All symbols are typed as {@code double} except {@code varchar}, which is
     * {@code varchar(10)}.
     */
    @BeforeClass
    public void setUp() {
        this.session = TestingSession.testSessionBuilder().build();

        MetadataManager metadata = MetadataManager.createTestMetadataManager();
        this.filterStatsCalculator = new FilterStatsCalculator(metadata, new ScalarStatsCalculator(metadata), new StatsNormalizer());

        // Closed-range numeric symbols.
        this.uStats = SymbolStatsEstimate.builder()
                .setAverageRowSize(8.0d)
                .setDistinctValuesCount(300.0d)
                .setLowValue(0.0d)
                .setHighValue(20.0d)
                .setNullsFraction(0.1d)
                .build();
        this.wStats = SymbolStatsEstimate.builder()
                .setAverageRowSize(8.0d)
                .setDistinctValuesCount(30.0d)
                .setLowValue(0.0d)
                .setHighValue(20.0d)
                .setNullsFraction(0.1d)
                .build();
        this.xStats = SymbolStatsEstimate.builder()
                .setAverageRowSize(4.0d)
                .setDistinctValuesCount(40.0d)
                .setLowValue(-10.0d)
                .setHighValue(10.0d)
                .setNullsFraction(0.25d)
                .build();
        this.yStats = SymbolStatsEstimate.builder()
                .setAverageRowSize(4.0d)
                .setDistinctValuesCount(20.0d)
                .setLowValue(0.0d)
                .setHighValue(5.0d)
                .setNullsFraction(0.5d)
                .build();
        this.zStats = SymbolStatsEstimate.builder()
                .setAverageRowSize(4.0d)
                .setDistinctValuesCount(5.0d)
                .setLowValue(-100.0d)
                .setHighValue(100.0d)
                .setNullsFraction(0.1d)
                .build();

        // Symbols whose range is unbounded on one or both sides.
        this.leftOpenStats = SymbolStatsEstimate.builder()
                .setAverageRowSize(4.0d)
                .setDistinctValuesCount(50.0d)
                .setLowValue(Double.NEGATIVE_INFINITY)
                .setHighValue(15.0d)
                .setNullsFraction(0.1d)
                .build();
        this.rightOpenStats = SymbolStatsEstimate.builder()
                .setAverageRowSize(4.0d)
                .setDistinctValuesCount(50.0d)
                .setLowValue(-15.0d)
                .setHighValue(Double.POSITIVE_INFINITY)
                .setNullsFraction(0.1d)
                .build();
        this.unknownRangeStats = SymbolStatsEstimate.builder()
                .setAverageRowSize(4.0d)
                .setDistinctValuesCount(50.0d)
                .setLowValue(Double.NEGATIVE_INFINITY)
                .setHighValue(Double.POSITIVE_INFINITY)
                .setNullsFraction(0.1d)
                .build();

        // Symbol with NaN bounds, zero distinct values and only nulls.
        this.emptyRangeStats = SymbolStatsEstimate.builder()
                .setAverageRowSize(0.0d)
                .setDistinctValuesCount(0.0d)
                .setLowValue(Double.NaN)
                .setHighValue(Double.NaN)
                .setNullsFraction(1.0d)
                .build();

        // Same numbers as unknownRangeStats, but registered below with a varchar type.
        this.varcharStats = SymbolStatsEstimate.builder()
                .setAverageRowSize(4.0d)
                .setDistinctValuesCount(50.0d)
                .setLowValue(Double.NEGATIVE_INFINITY)
                .setHighValue(Double.POSITIVE_INFINITY)
                .setNullsFraction(0.1d)
                .build();

        this.standardInputStatistics = PlanNodeStatsEstimate.builder()
                .addSymbolStatistics(new Symbol("u"), this.uStats)
                .addSymbolStatistics(new Symbol("w"), this.wStats)
                .addSymbolStatistics(new Symbol("x"), this.xStats)
                .addSymbolStatistics(new Symbol("y"), this.yStats)
                .addSymbolStatistics(new Symbol("z"), this.zStats)
                .addSymbolStatistics(new Symbol("leftOpen"), this.leftOpenStats)
                .addSymbolStatistics(new Symbol("rightOpen"), this.rightOpenStats)
                .addSymbolStatistics(new Symbol("unknownRange"), this.unknownRangeStats)
                .addSymbolStatistics(new Symbol("emptyRange"), this.emptyRangeStats)
                .addSymbolStatistics(new Symbol("varchar"), this.varcharStats)
                .setOutputRowCount(1000.0d)
                .build();

        // The explicit <Symbol, Type> witness is required: the receiver of a chained call gets
        // no target typing, so a bare ImmutableMap.builder() would build ImmutableMap<Object, Object>.
        this.types = TypeProvider.copyOf(ImmutableMap.<Symbol, Type>builder()
                .put(new Symbol("u"), DoubleType.DOUBLE)
                .put(new Symbol("w"), DoubleType.DOUBLE)
                .put(new Symbol("x"), DoubleType.DOUBLE)
                .put(new Symbol("y"), DoubleType.DOUBLE)
                .put(new Symbol("z"), DoubleType.DOUBLE)
                .put(new Symbol("leftOpen"), DoubleType.DOUBLE)
                .put(new Symbol("rightOpen"), DoubleType.DOUBLE)
                .put(new Symbol("unknownRange"), DoubleType.DOUBLE)
                .put(new Symbol("emptyRange"), DoubleType.DOUBLE)
                .put(new Symbol("varchar"), VarcharType.createVarcharType(10))
                .build());
    }

    /**
     * Creates an assertion callback that checks a symbol's low value, high value, distinct
     * values count and nulls fraction against the given estimate.
     *
     * <p>The average row size is not compared.
     *
     * @param symbolStatsEstimate the expected statistics
     * @return a consumer applying the four checks to a {@link SymbolStatsAssertion}
     */
    private Consumer<SymbolStatsAssertion> equalTo(SymbolStatsEstimate symbolStatsEstimate) {
        return actual -> actual
                .lowValue(symbolStatsEstimate.getLowValue())
                .highValue(symbolStatsEstimate.getHighValue())
                .distinctValuesCount(symbolStatsEstimate.getDistinctValuesCount())
                .nullsFraction(symbolStatsEstimate.getNullsFraction());
    }

    /**
     * Returns a copy of the estimate whose distinct values count is shifted by {@code d}.
     *
     * @param symbolStatsEstimate the estimate to adjust
     * @param d the amount added to the distinct values count (may be negative)
     * @return the result of mapping the distinct values count
     */
    private SymbolStatsEstimate updateNDV(SymbolStatsEstimate symbolStatsEstimate, double d) {
        return symbolStatsEstimate.mapDistinctValuesCount(currentNdv -> currentNdv + d);
    }

    /**
     * Computes the statistics a symbol is expected to have once the row count is {@code d}.
     *
     * <p>The estimate is returned unchanged when the distinct values count, {@code d} or the
     * nulls fraction is NaN, or when the distinct values count already fits into the non-null
     * rows ({@code d * (1 - nullsFraction)}). Otherwise the distinct values count becomes the
     * average of {@code min(ndv, d)} and the non-null rows, and the nulls fraction is halved.
     *
     * @param symbolStatsEstimate the estimate to cap
     * @param d the output row count used as the cap
     * @return the original or the capped estimate
     */
    private SymbolStatsEstimate capNDV(SymbolStatsEstimate symbolStatsEstimate, double d) {
        double ndv = symbolStatsEstimate.getDistinctValuesCount();
        double nulls = symbolStatsEstimate.getNullsFraction();

        if (Double.isNaN(ndv) || Double.isNaN(d) || Double.isNaN(nulls)) {
            return symbolStatsEstimate;
        }

        double nonNullRows = d * (1.0d - nulls);
        if (ndv <= nonNullRows) {
            return symbolStatsEstimate;
        }

        double cappedNdv = (Math.min(ndv, d) + nonNullRows) / 2.0d;
        double halvedNulls = nulls / 2.0d;
        return symbolStatsEstimate
                .mapDistinctValuesCount(ignored -> cappedNdv)
                .mapNullsFraction(ignored -> halvedNulls);
    }

    /**
     * Returns a copy of the estimate with its nulls fraction mapped to zero.
     *
     * @param symbolStatsEstimate the estimate to adjust
     * @return the result of mapping the nulls fraction to {@code 0.0}
     */
    private SymbolStatsEstimate zeroNullsFraction(SymbolStatsEstimate symbolStatsEstimate) {
        return symbolStatsEstimate.mapNullsFraction(ignored -> 0.0d);
    }

    /**
     * Runs the filter stats calculator for {@code expression} against the standard input
     * statistics and wraps the result in an assertion object.
     *
     * @param expression the filter predicate to estimate
     * @return an assertion over the estimated statistics
     */
    private PlanNodeStatsAssertion assertCalculate(Expression expression) {
        PlanNodeStatsEstimate filtered = this.filterStatsCalculator.filterStats(
                this.standardInputStatistics,
                expression,
                this.session,
                this.types);
        return PlanNodeStatsAssertion.assertThat(filtered);
    }

    /**
     * Sanity check: the shared input statistics must already be in normalized form,
     * i.e. normalizing them must not change anything.
     */
    @Test
    public void verifyTestInputConsistent() {
        StatsNormalizer normalizer = new StatsNormalizer();
        PlanNodeStatsEstimate input = this.standardInputStatistics;
        checkConsistent(normalizer, "standardInputStatistics", input, input.getSymbolsWithKnownStatistics(), this.types);
    }

    /**
     * Verifies the estimates produced for {@code symbol = literal} predicates.
     */
    @Test
    public void symbolToLiteralEqualStats() {
        ComparisonExpression.Operator operator = ComparisonExpression.Operator.EQUAL;

        // Literal inside y's [0, 5] range.
        assertCalculate(new ComparisonExpression(operator, new SymbolReference("y"), new DoubleLiteral("2.5")))
                .outputRowsCount(25.0d)
                .symbolStats("y", stats -> stats
                        .averageRowSize(4.0d)
                        .distinctValuesCount(1.0d)
                        .lowValue(2.5d)
                        .highValue(2.5d)
                        .nullsFraction(0.0d));

        // Literal equal to x's high value.
        assertCalculate(new ComparisonExpression(operator, new SymbolReference("x"), new DoubleLiteral("10.0")))
                .outputRowsCount(18.75d)
                .symbolStats("x", stats -> stats
                        .averageRowSize(4.0d)
                        .distinctValuesCount(1.0d)
                        .lowValue(10.0d)
                        .highValue(10.0d)
                        .nullsFraction(0.0d));

        // Literal above y's high value.
        assertCalculate(new ComparisonExpression(operator, new SymbolReference("y"), new DoubleLiteral("10.0")))
                .outputRowsCount(0.0d)
                .symbolStats("y", stats -> stats
                        .averageRowSize(0.0d)
                        .distinctValuesCount(0.0d)
                        .emptyRange()
                        .nullsFraction(1.0d));

        // Symbol with an unbounded low value.
        assertCalculate(new ComparisonExpression(operator, new SymbolReference("leftOpen"), new DoubleLiteral("2.5")))
                .outputRowsCount(18.0d)
                .symbolStats("leftOpen", stats -> stats
                        .averageRowSize(4.0d)
                        .distinctValuesCount(1.0d)
                        .lowValue(2.5d)
                        .highValue(2.5d)
                        .nullsFraction(0.0d));

        // Symbol with an unbounded high value.
        assertCalculate(new ComparisonExpression(operator, new SymbolReference("rightOpen"), new DoubleLiteral("-2.5")))
                .outputRowsCount(18.0d)
                .symbolStats("rightOpen", stats -> stats
                        .averageRowSize(4.0d)
                        .distinctValuesCount(1.0d)
                        .lowValue(-2.5d)
                        .highValue(-2.5d)
                        .nullsFraction(0.0d));

        // Symbol unbounded on both sides.
        assertCalculate(new ComparisonExpression(operator, new SymbolReference("unknownRange"), new DoubleLiteral("0.0")))
                .outputRowsCount(18.0d)
                .symbolStats("unknownRange", stats -> stats
                        .averageRowSize(4.0d)
                        .distinctValuesCount(1.0d)
                        .lowValue(0.0d)
                        .highValue(0.0d)
                        .nullsFraction(0.0d));

        // Symbol with an empty range: statistics are expected to stay as they were.
        assertCalculate(new ComparisonExpression(operator, new SymbolReference("emptyRange"), new DoubleLiteral("0.0")))
                .outputRowsCount(0.0d)
                .symbolStats("emptyRange", equalTo(this.emptyRangeStats));

        // Varchar symbol compared with a string literal.
        assertCalculate(new ComparisonExpression(operator, new SymbolReference("varchar"), new StringLiteral("blah")))
                .outputRowsCount(18.0d)
                .symbolStats("varchar", stats -> stats
                        .averageRowSize(4.0d)
                        .distinctValuesCount(1.0d)
                        .lowValue(Double.NEGATIVE_INFINITY)
                        .highValue(Double.POSITIVE_INFINITY)
                        .nullsFraction(0.0d));
    }

    /**
     * Verifies the estimates produced for {@code symbol <> literal} predicates.
     */
    @Test
    public void symbolToLiteralNotEqualStats() {
        ComparisonExpression.Operator operator = ComparisonExpression.Operator.NOT_EQUAL;

        // Literal inside y's [0, 5] range.
        assertCalculate(new ComparisonExpression(operator, new SymbolReference("y"), new DoubleLiteral("2.5")))
                .outputRowsCount(475.0d)
                .symbolStats("y", stats -> stats
                        .averageRowSize(4.0d)
                        .distinctValuesCount(19.0d)
                        .lowValue(0.0d)
                        .highValue(5.0d)
                        .nullsFraction(0.0d));

        // Literal equal to x's high value.
        assertCalculate(new ComparisonExpression(operator, new SymbolReference("x"), new DoubleLiteral("10.0")))
                .outputRowsCount(731.25d)
                .symbolStats("x", stats -> stats
                        .averageRowSize(4.0d)
                        .distinctValuesCount(39.0d)
                        .lowValue(-10.0d)
                        .highValue(10.0d)
                        .nullsFraction(0.0d));

        // Literal above y's high value.
        assertCalculate(new ComparisonExpression(operator, new SymbolReference("y"), new DoubleLiteral("10.0")))
                .outputRowsCount(500.0d)
                .symbolStats("y", stats -> stats
                        .averageRowSize(4.0d)
                        .distinctValuesCount(19.0d)
                        .lowValue(0.0d)
                        .highValue(5.0d)
                        .nullsFraction(0.0d));

        // Symbol with an unbounded low value.
        assertCalculate(new ComparisonExpression(operator, new SymbolReference("leftOpen"), new DoubleLiteral("2.5")))
                .outputRowsCount(882.0d)
                .symbolStats("leftOpen", stats -> stats
                        .averageRowSize(4.0d)
                        .distinctValuesCount(49.0d)
                        .lowValueUnknown()
                        .highValue(15.0d)
                        .nullsFraction(0.0d));

        // Symbol with an unbounded high value.
        assertCalculate(new ComparisonExpression(operator, new SymbolReference("rightOpen"), new DoubleLiteral("-2.5")))
                .outputRowsCount(882.0d)
                .symbolStats("rightOpen", stats -> stats
                        .averageRowSize(4.0d)
                        .distinctValuesCount(49.0d)
                        .lowValue(-15.0d)
                        .highValueUnknown()
                        .nullsFraction(0.0d));

        // Symbol unbounded on both sides.
        assertCalculate(new ComparisonExpression(operator, new SymbolReference("unknownRange"), new DoubleLiteral("0.0")))
                .outputRowsCount(882.0d)
                .symbolStats("unknownRange", stats -> stats
                        .averageRowSize(4.0d)
                        .distinctValuesCount(49.0d)
                        .lowValueUnknown()
                        .highValueUnknown()
                        .nullsFraction(0.0d));

        // Symbol with an empty range: statistics are expected to stay as they were.
        assertCalculate(new ComparisonExpression(operator, new SymbolReference("emptyRange"), new DoubleLiteral("0.0")))
                .outputRowsCount(0.0d)
                .symbolStats("emptyRange", equalTo(this.emptyRangeStats));

        // Varchar symbol compared with a string literal.
        assertCalculate(new ComparisonExpression(operator, new SymbolReference("varchar"), new StringLiteral("blah")))
                .outputRowsCount(882.0d)
                .symbolStats("varchar", stats -> stats
                        .averageRowSize(4.0d)
                        .distinctValuesCount(49.0d)
                        .lowValueUnknown()
                        .highValueUnknown()
                        .nullsFraction(0.0d));
    }

    /**
     * Verifies the estimates produced for {@code symbol < literal} predicates.
     */
    @Test
    public void symbolToLiteralLessThanStats() {
        ComparisonExpression.Operator operator = ComparisonExpression.Operator.LESS_THAN;

        // Literal in the middle of y's [0, 5] range.
        assertCalculate(new ComparisonExpression(operator, new SymbolReference("y"), new DoubleLiteral("2.5")))
                .outputRowsCount(250.0d)
                .symbolStats("y", stats -> stats
                        .averageRowSize(4.0d)
                        .distinctValuesCount(10.0d)
                        .lowValue(0.0d)
                        .highValue(2.5d)
                        .nullsFraction(0.0d));

        // Literal equal to x's high value.
        assertCalculate(new ComparisonExpression(operator, new SymbolReference("x"), new DoubleLiteral("10.0")))
                .outputRowsCount(750.0d)
                .symbolStats("x", stats -> stats
                        .averageRowSize(4.0d)
                        .distinctValuesCount(40.0d)
                        .lowValue(-10.0d)
                        .highValue(10.0d)
                        .nullsFraction(0.0d));

        // Literal equal to x's low value.
        assertCalculate(new ComparisonExpression(operator, new SymbolReference("x"), new DoubleLiteral("-10.0")))
                .outputRowsCount(18.75d)
                .symbolStats("x", stats -> stats
                        .averageRowSize(4.0d)
                        .distinctValuesCount(1.0d)
                        .lowValue(-10.0d)
                        .highValue(-10.0d)
                        .nullsFraction(0.0d));

        // Literal below y's low value.
        assertCalculate(new ComparisonExpression(operator, new SymbolReference("y"), new DoubleLiteral("-10.0")))
                .outputRowsCount(0.0d)
                .symbolStats("y", stats -> stats
                        .averageRowSize(0.0d)
                        .distinctValuesCount(0.0d)
                        .emptyRange()
                        .nullsFraction(1.0d));

        // Symbol with an unbounded low value.
        assertCalculate(new ComparisonExpression(operator, new SymbolReference("leftOpen"), new DoubleLiteral("0.0")))
                .outputRowsCount(450.0d)
                .symbolStats("leftOpen", stats -> stats
                        .averageRowSize(4.0d)
                        .distinctValuesCount(25.0d)
                        .lowValueUnknown()
                        .highValue(0.0d)
                        .nullsFraction(0.0d));

        // Symbol with an unbounded high value.
        assertCalculate(new ComparisonExpression(operator, new SymbolReference("rightOpen"), new DoubleLiteral("0.0")))
                .outputRowsCount(225.0d)
                .symbolStats("rightOpen", stats -> stats
                        .averageRowSize(4.0d)
                        .distinctValuesCount(12.5d)
                        .lowValue(-15.0d)
                        .highValue(0.0d)
                        .nullsFraction(0.0d));

        // Symbol unbounded on both sides.
        assertCalculate(new ComparisonExpression(operator, new SymbolReference("unknownRange"), new DoubleLiteral("0.0")))
                .outputRowsCount(450.0d)
                .symbolStats("unknownRange", stats -> stats
                        .averageRowSize(4.0d)
                        .distinctValuesCount(25.0d)
                        .lowValueUnknown()
                        .highValue(0.0d)
                        .nullsFraction(0.0d));

        // Symbol with an empty range: statistics are expected to stay as they were.
        assertCalculate(new ComparisonExpression(operator, new SymbolReference("emptyRange"), new DoubleLiteral("0.0")))
                .outputRowsCount(0.0d)
                .symbolStats("emptyRange", equalTo(this.emptyRangeStats));
    }

    /**
     * Verifies the estimates produced for {@code symbol > literal} predicates.
     */
    @Test
    public void symbolToLiteralGreaterThanStats() {
        ComparisonExpression.Operator operator = ComparisonExpression.Operator.GREATER_THAN;

        // Literal in the middle of y's [0, 5] range.
        assertCalculate(new ComparisonExpression(operator, new SymbolReference("y"), new DoubleLiteral("2.5")))
                .outputRowsCount(250.0d)
                .symbolStats("y", stats -> stats
                        .averageRowSize(4.0d)
                        .distinctValuesCount(10.0d)
                        .lowValue(2.5d)
                        .highValue(5.0d)
                        .nullsFraction(0.0d));

        // Literal equal to x's low value.
        assertCalculate(new ComparisonExpression(operator, new SymbolReference("x"), new DoubleLiteral("-10.0")))
                .outputRowsCount(750.0d)
                .symbolStats("x", stats -> stats
                        .averageRowSize(4.0d)
                        .distinctValuesCount(40.0d)
                        .lowValue(-10.0d)
                        .highValue(10.0d)
                        .nullsFraction(0.0d));

        // Literal equal to x's high value.
        assertCalculate(new ComparisonExpression(operator, new SymbolReference("x"), new DoubleLiteral("10.0")))
                .outputRowsCount(18.75d)
                .symbolStats("x", stats -> stats
                        .averageRowSize(4.0d)
                        .distinctValuesCount(1.0d)
                        .lowValue(10.0d)
                        .highValue(10.0d)
                        .nullsFraction(0.0d));

        // Literal above y's high value.
        assertCalculate(new ComparisonExpression(operator, new SymbolReference("y"), new DoubleLiteral("10.0")))
                .outputRowsCount(0.0d)
                .symbolStats("y", stats -> stats
                        .averageRowSize(0.0d)
                        .distinctValuesCount(0.0d)
                        .emptyRange()
                        .nullsFraction(1.0d));

        // Symbol with an unbounded low value.
        assertCalculate(new ComparisonExpression(operator, new SymbolReference("leftOpen"), new DoubleLiteral("0.0")))
                .outputRowsCount(225.0d)
                .symbolStats("leftOpen", stats -> stats
                        .averageRowSize(4.0d)
                        .distinctValuesCount(12.5d)
                        .lowValue(0.0d)
                        .highValue(15.0d)
                        .nullsFraction(0.0d));

        // Symbol with an unbounded high value.
        assertCalculate(new ComparisonExpression(operator, new SymbolReference("rightOpen"), new DoubleLiteral("0.0")))
                .outputRowsCount(450.0d)
                .symbolStats("rightOpen", stats -> stats
                        .averageRowSize(4.0d)
                        .distinctValuesCount(25.0d)
                        .lowValue(0.0d)
                        .highValueUnknown()
                        .nullsFraction(0.0d));

        // Symbol unbounded on both sides.
        assertCalculate(new ComparisonExpression(operator, new SymbolReference("unknownRange"), new DoubleLiteral("0.0")))
                .outputRowsCount(450.0d)
                .symbolStats("unknownRange", stats -> stats
                        .averageRowSize(4.0d)
                        .distinctValuesCount(25.0d)
                        .lowValue(0.0d)
                        .highValueUnknown()
                        .nullsFraction(0.0d));

        // Symbol with an empty range: statistics are expected to stay as they were.
        assertCalculate(new ComparisonExpression(operator, new SymbolReference("emptyRange"), new DoubleLiteral("0.0")))
                .outputRowsCount(0.0d)
                .symbolStats("emptyRange", equalTo(this.emptyRangeStats));
    }

    /**
     * Verifies the estimates produced for {@code symbol = symbol} predicates, including the
     * statistics of {@code z}, a symbol that does not take part in the comparison.
     */
    @Test
    public void symbolToSymbolEqualStats() {
        ComparisonExpression.Operator operator = ComparisonExpression.Operator.EQUAL;

        // u = w: both symbols share the [0, 20] range.
        double uwRows = 2.7d;
        assertCalculate(new ComparisonExpression(operator, new SymbolReference("u"), new SymbolReference("w")))
                .outputRowsCount(uwRows)
                .symbolStats("u", equalTo(capNDV(zeroNullsFraction(this.uStats), uwRows)))
                .symbolStats("w", equalTo(capNDV(zeroNullsFraction(this.wStats), uwRows)))
                .symbolStats("z", equalTo(capNDV(this.zStats, uwRows)));

        // x = y: y's [0, 5] range lies inside x's [-10, 10] range.
        double xyRows = 9.375d;
        assertCalculate(new ComparisonExpression(operator, new SymbolReference("x"), new SymbolReference("y")))
                .outputRowsCount(xyRows)
                .symbolStats("x", stats -> stats
                        .averageRowSize(4.0d)
                        .lowValue(0.0d)
                        .highValue(5.0d)
                        .distinctValuesCount(9.375d)
                        .nullsFraction(0.0d))
                .symbolStats("y", stats -> stats
                        .averageRowSize(4.0d)
                        .lowValue(0.0d)
                        .highValue(5.0d)
                        .distinctValuesCount(9.375d)
                        .nullsFraction(0.0d))
                .symbolStats("z", equalTo(capNDV(this.zStats, xyRows)));

        // x = w: the [-10, 10] and [0, 20] ranges partially overlap.
        double xwRows = 16.875d;
        assertCalculate(new ComparisonExpression(operator, new SymbolReference("x"), new SymbolReference("w")))
                .outputRowsCount(xwRows)
                .symbolStats("x", stats -> stats
                        .averageRowSize(6.0d)
                        .lowValue(0.0d)
                        .highValue(10.0d)
                        .distinctValuesCount(16.875d)
                        .nullsFraction(0.0d))
                .symbolStats("w", stats -> stats
                        .averageRowSize(6.0d)
                        .lowValue(0.0d)
                        .highValue(10.0d)
                        .distinctValuesCount(16.875d)
                        .nullsFraction(0.0d))
                .symbolStats("z", equalTo(capNDV(this.zStats, xwRows)));

        // x = u: same ranges as x = w, but u has 300 distinct values.
        double xuRows = 2.25d;
        assertCalculate(new ComparisonExpression(operator, new SymbolReference("x"), new SymbolReference("u")))
                .outputRowsCount(xuRows)
                .symbolStats("x", stats -> stats
                        .averageRowSize(6.0d)
                        .lowValue(0.0d)
                        .highValue(10.0d)
                        .distinctValuesCount(2.25d)
                        .nullsFraction(0.0d))
                .symbolStats("u", stats -> stats
                        .averageRowSize(6.0d)
                        .lowValue(0.0d)
                        .highValue(10.0d)
                        .distinctValuesCount(2.25d)
                        .nullsFraction(0.0d))
                .symbolStats("z", equalTo(capNDV(this.zStats, xuRows)));
    }

    /**
     * Verifies the estimates produced for {@code symbol <> symbol} predicates. Compared
     * symbols are expected to lose their nulls and have their distinct values count capped
     * by the output row count; {@code z}, which is not compared, is only capped.
     */
    @Test
    public void symbolToSymbolNotEqual() {
        ComparisonExpression.Operator operator = ComparisonExpression.Operator.NOT_EQUAL;

        // u <> w
        double uwRows = 807.3d;
        assertCalculate(new ComparisonExpression(operator, new SymbolReference("u"), new SymbolReference("w")))
                .outputRowsCount(uwRows)
                .symbolStats("u", equalTo(capNDV(zeroNullsFraction(this.uStats), uwRows)))
                .symbolStats("w", equalTo(capNDV(zeroNullsFraction(this.wStats), uwRows)))
                .symbolStats("z", equalTo(capNDV(this.zStats, uwRows)));

        // x <> y
        double xyRows = 365.625d;
        assertCalculate(new ComparisonExpression(operator, new SymbolReference("x"), new SymbolReference("y")))
                .outputRowsCount(xyRows)
                .symbolStats("x", equalTo(capNDV(zeroNullsFraction(this.xStats), xyRows)))
                .symbolStats("y", equalTo(capNDV(zeroNullsFraction(this.yStats), xyRows)))
                .symbolStats("z", equalTo(capNDV(this.zStats, xyRows)));

        // x <> w
        double xwRows = 658.125d;
        assertCalculate(new ComparisonExpression(operator, new SymbolReference("x"), new SymbolReference("w")))
                .outputRowsCount(xwRows)
                .symbolStats("x", equalTo(capNDV(zeroNullsFraction(this.xStats), xwRows)))
                .symbolStats("w", equalTo(capNDV(zeroNullsFraction(this.wStats), xwRows)))
                .symbolStats("z", equalTo(capNDV(this.zStats, xwRows)));

        // x <> u
        double xuRows = 672.75d;
        assertCalculate(new ComparisonExpression(operator, new SymbolReference("x"), new SymbolReference("u")))
                .outputRowsCount(xuRows)
                .symbolStats("x", equalTo(capNDV(zeroNullsFraction(this.xStats), xuRows)))
                .symbolStats("u", equalTo(capNDV(zeroNullsFraction(this.uStats), xuRows)))
                .symbolStats("z", equalTo(capNDV(this.zStats, xuRows)));
    }

    /**
     * Verifies the estimates produced for {@code symbol <> CAST(... AS bigint)} predicates,
     * where the cast operand is either another symbol or a literal.
     */
    @Test
    public void symbolToCastExpressionNotEqual() {
        ComparisonExpression.Operator operator = ComparisonExpression.Operator.NOT_EQUAL;

        // u <> CAST(w AS bigint): w sits behind the cast and is expected to keep its nulls fraction.
        double symbolCastRows = 807.3d;
        assertCalculate(new ComparisonExpression(operator, new SymbolReference("u"), new Cast(new SymbolReference("w"), "bigint")))
                .outputRowsCount(symbolCastRows)
                .symbolStats("u", equalTo(capNDV(zeroNullsFraction(this.uStats), symbolCastRows)))
                .symbolStats("w", equalTo(capNDV(this.wStats, symbolCastRows)))
                .symbolStats("z", equalTo(capNDV(this.zStats, symbolCastRows)));

        // u <> CAST(10 AS bigint): u is expected to lose one distinct value.
        double literalCastRows = 897.0d;
        assertCalculate(new ComparisonExpression(operator, new SymbolReference("u"), new Cast(new LongLiteral("10"), "bigint")))
                .outputRowsCount(literalCastRows)
                .symbolStats("u", equalTo(capNDV(updateNDV(zeroNullsFraction(this.uStats), -1.0d), literalCastRows)))
                .symbolStats("z", equalTo(capNDV(this.zStats, literalCastRows)));
    }

    /**
     * Asserts that {@code planNodeStatsEstimate} is already normalized, i.e. that running it
     * through {@code statsNormalizer} yields an equal estimate.
     *
     * @param statsNormalizer the normalizer used to compute the reference estimate
     * @param str a label for the source of the statistics, used in the failure message
     * @param planNodeStatsEstimate the statistics to verify
     * @param collection the output symbols passed to the normalizer
     * @param typeProvider the symbol types passed to the normalizer
     * @throws IllegalStateException if normalization changes the estimate; the message lists
     *         every difference found, or the whole estimate when no specific difference is found
     */
    private static void checkConsistent(StatsNormalizer statsNormalizer, String str, PlanNodeStatsEstimate planNodeStatsEstimate, Collection<Symbol> collection, TypeProvider typeProvider) {
        PlanNodeStatsEstimate normalized = statsNormalizer.normalize(planNodeStatsEstimate, collection, typeProvider);
        if (Objects.equals(planNodeStatsEstimate, normalized)) {
            return;
        }

        // Typed list (previously a raw ArrayList) so the stream below is a Stream<String>.
        List<String> problems = new ArrayList<>();

        // Double.compare rather than ==, so that NaN equals NaN and 0.0 differs from -0.0.
        if (Double.compare(planNodeStatsEstimate.getOutputRowCount(), normalized.getOutputRowCount()) != 0) {
            problems.add(String.format("Output row count is %s, should be normalized to %s", planNodeStatsEstimate.getOutputRowCount(), normalized.getOutputRowCount()));
        }

        for (Symbol symbol : planNodeStatsEstimate.getSymbolsWithKnownStatistics()) {
            if (!Objects.equals(planNodeStatsEstimate.getSymbolStatistics(symbol), normalized.getSymbolStatistics(symbol))) {
                problems.add(String.format("Symbol stats for '%s' are \n\t\t\t\t\t%s, should be normalized to \n\t\t\t\t\t%s", symbol, planNodeStatsEstimate.getSymbolStatistics(symbol), normalized.getSymbolStatistics(symbol)));
            }
        }

        // The estimates differ but none of the checks above explains why: report the whole estimate.
        if (problems.isEmpty()) {
            problems.add(planNodeStatsEstimate.toString());
        }

        String details = problems.stream().collect(Collectors.joining("\n\t\t\t", "\n\t\t\t", ""));
        throw new IllegalStateException(String.format("Rule %s returned inconsistent stats: %s", str, details));
    }
}
