/*
 * Decompiled with CFR 0.152.
 */
package io.trino.cost;

import io.trino.cost.PlanNodeStatsAssertion;
import io.trino.cost.PlanNodeStatsEstimate;
import io.trino.cost.SemiJoinStatsCalculator;
import io.trino.cost.SymbolStatsEstimate;
import io.trino.spi.type.Type;
import io.trino.sql.planner.Symbol;
import io.trino.type.UnknownType;
import org.junit.jupiter.api.Test;

/**
 * Tests for {@link SemiJoinStatsCalculator}.
 *
 * <p>Each case runs {@code computeSemiJoin} or {@code computeAntiJoin} with the same fixture
 * {@link PlanNodeStatsEstimate} on both sides and then checks the resulting symbol statistics
 * and output row count.
 */
public class TestSemiJoinStatsCalculator {
    // Every fixture symbol shares one type; only the attached statistics differ.
    private static final Type SYMBOL_TYPE = UnknownType.UNKNOWN;

    private final SymbolStatsEstimate uStats = buildSymbolStats(8.0, 300.0, 0.0, 20.0, 0.1);
    private final SymbolStatsEstimate wStats = buildSymbolStats(8.0, 30.0, 0.0, 20.0, 0.1);
    private final SymbolStatsEstimate xStats = buildSymbolStats(4.0, 40.0, -10.0, 10.0, 0.25);
    private final SymbolStatsEstimate yStats = buildSymbolStats(4.0, 20.0, 0.0, 5.0, 0.5);
    private final SymbolStatsEstimate zStats = buildSymbolStats(4.0, 5.0, -100.0, 100.0, 0.1);
    private final SymbolStatsEstimate leftOpenStats =
            buildSymbolStats(4.0, 50.0, Double.NEGATIVE_INFINITY, 15.0, 0.1);
    private final SymbolStatsEstimate rightOpenStats =
            buildSymbolStats(4.0, 50.0, -15.0, Double.POSITIVE_INFINITY, 0.1);
    private final SymbolStatsEstimate unknownRangeStats =
            buildSymbolStats(4.0, 50.0, Double.NEGATIVE_INFINITY, Double.POSITIVE_INFINITY, 0.1);
    // Zero distinct values with NaN range and NaN nulls fraction.
    private final SymbolStatsEstimate emptyRangeStats =
            buildSymbolStats(4.0, 0.0, Double.NaN, Double.NaN, Double.NaN);
    // Built directly rather than via the helper: low/high values are deliberately never set here.
    private final SymbolStatsEstimate fractionalNdvStats = SymbolStatsEstimate.builder()
            .setAverageRowSize(Double.NaN)
            .setDistinctValuesCount(0.1)
            .setNullsFraction(0.0)
            .build();

    private final Symbol u = fixtureSymbol("u");
    private final Symbol w = fixtureSymbol("w");
    private final Symbol x = fixtureSymbol("x");
    private final Symbol y = fixtureSymbol("y");
    private final Symbol z = fixtureSymbol("z");
    private final Symbol leftOpen = fixtureSymbol("leftOpen");
    private final Symbol rightOpen = fixtureSymbol("rightOpen");
    private final Symbol unknownRange = fixtureSymbol("unknownRange");
    private final Symbol emptyRange = fixtureSymbol("emptyRange");
    private final Symbol unknown = fixtureSymbol("unknown");
    private final Symbol fractionalNdv = fixtureSymbol("fractionalNdv");

    // Declared after the symbols and their stats: instance initializers run in textual order.
    private final PlanNodeStatsEstimate inputStatistics = PlanNodeStatsEstimate.builder()
            .addSymbolStatistics(u, uStats)
            .addSymbolStatistics(w, wStats)
            .addSymbolStatistics(x, xStats)
            .addSymbolStatistics(y, yStats)
            .addSymbolStatistics(z, zStats)
            .addSymbolStatistics(leftOpen, leftOpenStats)
            .addSymbolStatistics(rightOpen, rightOpenStats)
            .addSymbolStatistics(unknownRange, unknownRangeStats)
            .addSymbolStatistics(emptyRange, emptyRangeStats)
            .addSymbolStatistics(unknown, SymbolStatsEstimate.unknown())
            .addSymbolStatistics(fractionalNdv, fractionalNdvStats)
            .setOutputRowCount(1000.0)
            .build();

    @Test
    public void testSemiJoin() {
        // w has fewer distinct values (30) than x (40): x is expected to take w's NDV,
        // keep its own range and drop its nulls.
        assertSemiJoin(x, w)
                .symbolStats(x, assertion -> assertion
                        .lowValue(xStats.getLowValue())
                        .highValue(xStats.getHighValue())
                        .nullsFraction(0.0)
                        .distinctValuesCount(wStats.getDistinctValuesCount()))
                .symbolStats(w, assertion -> assertion.isEqualTo(wStats))
                .symbolStats(z, assertion -> assertion.isEqualTo(zStats))
                .outputRowsCount(inputStatistics.getOutputRowCount()
                        * xStats.getValuesFraction()
                        * (wStats.getDistinctValuesCount() / xStats.getDistinctValuesCount()));

        // u has more distinct values (300) than x (40): x keeps its NDV; only non-null rows remain.
        assertSemiJoin(x, u)
                .symbolStats(x, assertion -> assertion
                        .lowValue(xStats.getLowValue())
                        .highValue(xStats.getHighValue())
                        .nullsFraction(0.0)
                        .distinctValuesCount(xStats.getDistinctValuesCount()))
                .symbolStats(u, assertion -> assertion.isEqualTo(uStats))
                .symbolStats(z, assertion -> assertion.isEqualTo(zStats))
                .outputRowsCount(inputStatistics.getOutputRowCount() * xStats.getValuesFraction());

        // Unknown stats on the source side: NDV, range and row count stay unknown.
        assertSemiJoin(unknown, u)
                .symbolStats(unknown, assertion -> assertion
                        .nullsFraction(0.0)
                        .distinctValuesCountUnknown()
                        .unknownRange())
                .symbolStats(u, assertion -> assertion.isEqualTo(uStats))
                .symbolStats(z, assertion -> assertion.isEqualTo(zStats))
                .outputRowsCountUnknown();

        // Unknown stats on the filtering side: x keeps its range, but NDV and row count are unknown.
        assertSemiJoin(x, unknown)
                .symbolStats(x, assertion -> assertion
                        .nullsFraction(0.0)
                        .lowValue(xStats.getLowValue())
                        .highValue(xStats.getHighValue())
                        .distinctValuesCountUnknown())
                .symbolStatsUnknown(unknown)
                .symbolStats(z, assertion -> assertion.isEqualTo(zStats))
                .outputRowsCountUnknown();

        // Empty statistics on both sides yield no rows.
        assertSemiJoin(emptyRange, emptyRange)
                .outputRowsCount(0.0);

        // NDV below one on both sides: every input row and the fractional NDV are kept.
        assertSemiJoin(fractionalNdv, fractionalNdv)
                .outputRowsCount(1000.0)
                .symbolStats(fractionalNdv, assertion -> assertion
                        .nullsFraction(0.0)
                        .distinctValuesCount(0.1));
    }

    @Test
    public void testAntiJoin() {
        // x's 40 distinct values are removed from u's 300; rows scale by the remaining NDV share.
        assertAntiJoin(u, x)
                .symbolStats(u, assertion -> assertion
                        .lowValue(uStats.getLowValue())
                        .highValue(uStats.getHighValue())
                        .nullsFraction(0.0)
                        .distinctValuesCount(uStats.getDistinctValuesCount() - xStats.getDistinctValuesCount()))
                .symbolStats(x, assertion -> assertion.isEqualTo(xStats))
                .symbolStats(z, assertion -> assertion.isEqualTo(zStats))
                .outputRowsCount(inputStatistics.getOutputRowCount()
                        * uStats.getValuesFraction()
                        * (1.0 - xStats.getDistinctValuesCount() / uStats.getDistinctValuesCount()));

        // Filtering side has more distinct values than the source: the test expects exactly half
        // of x's NDV and of its non-null rows to survive.
        assertAntiJoin(x, u)
                .symbolStats(x, assertion -> assertion
                        .lowValue(xStats.getLowValue())
                        .highValue(xStats.getHighValue())
                        .nullsFraction(0.0)
                        .distinctValuesCount(xStats.getDistinctValuesCount() * 0.5))
                .symbolStats(u, assertion -> assertion.isEqualTo(uStats))
                .symbolStats(z, assertion -> assertion.isEqualTo(zStats))
                .outputRowsCount(inputStatistics.getOutputRowCount() * xStats.getValuesFraction() * 0.5);

        // Unknown stats on the source side: NDV, range and row count stay unknown.
        assertAntiJoin(unknown, u)
                .symbolStats(unknown, assertion -> assertion
                        .nullsFraction(0.0)
                        .distinctValuesCountUnknown()
                        .unknownRange())
                .symbolStats(u, assertion -> assertion.isEqualTo(uStats))
                .symbolStats(z, assertion -> assertion.isEqualTo(zStats))
                .outputRowsCountUnknown();

        // Unknown stats on the filtering side: x keeps its range, but NDV and row count are unknown.
        assertAntiJoin(x, unknown)
                .symbolStats(x, assertion -> assertion
                        .nullsFraction(0.0)
                        .lowValue(xStats.getLowValue())
                        .highValue(xStats.getHighValue())
                        .distinctValuesCountUnknown())
                .symbolStatsUnknown(unknown)
                .symbolStats(z, assertion -> assertion.isEqualTo(zStats))
                .outputRowsCountUnknown();

        // Empty statistics on both sides yield no rows.
        assertAntiJoin(emptyRange, emptyRange)
                .outputRowsCount(0.0);

        // NDV below one on both sides: the test expects half the rows (500) and half the NDV.
        assertAntiJoin(fractionalNdv, fractionalNdv)
                .outputRowsCount(500.0)
                .symbolStats(fractionalNdv, assertion -> assertion
                        .nullsFraction(0.0)
                        .distinctValuesCount(0.05));
    }

    /**
     * Runs {@code computeSemiJoin} with the fixture statistics on both sides and wraps the result
     * for fluent assertions.
     */
    private PlanNodeStatsAssertion assertSemiJoin(Symbol sourceSymbol, Symbol filteringSymbol) {
        PlanNodeStatsEstimate estimate = SemiJoinStatsCalculator.computeSemiJoin(
                inputStatistics, inputStatistics, sourceSymbol, filteringSymbol);
        return PlanNodeStatsAssertion.assertThat(estimate);
    }

    /**
     * Runs {@code computeAntiJoin} with the fixture statistics on both sides and wraps the result
     * for fluent assertions.
     */
    private PlanNodeStatsAssertion assertAntiJoin(Symbol sourceSymbol, Symbol filteringSymbol) {
        PlanNodeStatsEstimate estimate = SemiJoinStatsCalculator.computeAntiJoin(
                inputStatistics, inputStatistics, sourceSymbol, filteringSymbol);
        return PlanNodeStatsAssertion.assertThat(estimate);
    }

    /** Builds fixture symbol statistics with every field set explicitly. */
    private static SymbolStatsEstimate buildSymbolStats(
            double averageRowSize,
            double distinctValuesCount,
            double lowValue,
            double highValue,
            double nullsFraction) {
        return SymbolStatsEstimate.builder()
                .setAverageRowSize(averageRowSize)
                .setDistinctValuesCount(distinctValuesCount)
                .setLowValue(lowValue)
                .setHighValue(highValue)
                .setNullsFraction(nullsFraction)
                .build();
    }

    /** Creates a fixture symbol of {@link #SYMBOL_TYPE} with the given name. */
    private static Symbol fixtureSymbol(String name) {
        return new Symbol(SYMBOL_TYPE, name);
    }
}

