package org.apache.wayang.core.optimizer.costs;

import java.util.Arrays;
import java.util.function.ToDoubleBiFunction;
import java.util.function.ToLongBiFunction;
import java.util.stream.LongStream;
import org.apache.commons.lang3.Validate;
import org.apache.wayang.core.optimizer.cardinality.CardinalityEstimate;
import org.apache.wayang.core.plan.wayangplan.ExecutionOperator;

/* loaded from: input_file:org/apache/wayang/core/optimizer/costs/IntervalLoadEstimator.class */
/**
 * A {@link LoadEstimator} that produces a load interval by evaluating a lower-bound and an
 * upper-bound estimation function over combinations of input and output cardinality values.
 *
 * <p>A slot count of {@code -1} for {@code numInputs} or {@code numOutputs} disables the
 * corresponding arity check in {@link #calculate(EstimationContext)}.
 */
public class IntervalLoadEstimator extends LoadEstimator {

    /** Factor multiplied onto the joint cardinality probability of every produced estimate. */
    private final double correctnessProbablity;

    /** Minimum number of input estimates expected, or {@code -1} to skip the check. */
    private final int numInputs;

    /** Exact number of output estimates expected, or {@code -1} to skip the check. */
    private final int numOutputs;

    /** Function evaluated first for each (inputs, outputs) combination. */
    private final ToLongBiFunction<long[], long[]> lowerBoundEstimator;

    /** Function evaluated second for each (inputs, outputs) combination. */
    private final ToLongBiFunction<long[], long[]> upperBoundEstimator;

    /**
     * Creates an estimator, forwarding {@code null} as the {@link CardinalityEstimate}.
     *
     * @param i                 number of inputs, or {@code -1} to skip the input arity check
     * @param i2                number of outputs, or {@code -1} to skip the output arity check
     * @param d                 correctness probability applied to every estimate
     * @param toLongBiFunction  lower-bound estimation function over (inputs, outputs)
     * @param toLongBiFunction2 upper-bound estimation function over (inputs, outputs)
     */
    public IntervalLoadEstimator(int i, int i2, double d, ToLongBiFunction<long[], long[]> toLongBiFunction, ToLongBiFunction<long[], long[]> toLongBiFunction2) {
        this(i, i2, d, (CardinalityEstimate) null, toLongBiFunction, toLongBiFunction2);
    }

    /**
     * Creates an estimator.
     *
     * @param i                   number of inputs, or {@code -1} to skip the input arity check
     * @param i2                  number of outputs, or {@code -1} to skip the output arity check
     * @param d                   correctness probability applied to every estimate
     * @param cardinalityEstimate passed through to the {@link LoadEstimator} constructor; may be
     *                            {@code null} (the five-argument constructor passes {@code null})
     * @param toLongBiFunction    lower-bound estimation function over (inputs, outputs)
     * @param toLongBiFunction2   upper-bound estimation function over (inputs, outputs)
     */
    public IntervalLoadEstimator(int i, int i2, double d, CardinalityEstimate cardinalityEstimate, ToLongBiFunction<long[], long[]> toLongBiFunction, ToLongBiFunction<long[], long[]> toLongBiFunction2) {
        super(cardinalityEstimate);
        this.correctnessProbablity = d;
        this.numInputs = i;
        this.numOutputs = i2;
        this.lowerBoundEstimator = toLongBiFunction;
        this.upperBoundEstimator = toLongBiFunction2;
    }

    /**
     * Adapts a {@code double}-valued estimation function to a {@code long}-valued one by
     * rounding each result with {@link Math#round(double)}.
     *
     * @param toDoubleBiFunction the function to wrap
     * @return a function returning the rounded result of {@code toDoubleBiFunction}
     */
    public static ToLongBiFunction<long[], long[]> rounded(ToDoubleBiFunction<long[], long[]> toDoubleBiFunction) {
        return (inputs, outputs) -> Math.round(toDoubleBiFunction.applyAsDouble(inputs, outputs));
    }

    /**
     * Builds an estimator whose bounds are linear in the summed input and output values:
     * the lower-bound function yields {@code j * sum} and the upper-bound function
     * {@code j2 * sum}.
     *
     * @param t                   operator supplying the slot counts; when {@code null}, both
     *                            counts are {@code -1}
     * @param j                   factor used by the lower-bound function
     * @param j2                  factor used by the upper-bound function
     * @param d                   correctness probability applied to every estimate
     * @param cardinalityEstimate passed through to the {@link LoadEstimator} constructor
     * @param <T>                 the operator type
     * @return the configured estimator
     */
    public static <T extends ExecutionOperator> LoadEstimator createIOLinearEstimator(T t, long j, long j2, double d, CardinalityEstimate cardinalityEstimate) {
        int inputSlots = -1;
        int outputSlots = -1;
        if (t != null) {
            inputSlots = t.getNumInputs();
            outputSlots = t.getNumOutputs();
        }
        final ToLongBiFunction<long[], long[]> lowerBound = (inputs, outputs) -> j * sumOfAll(inputs, outputs);
        final ToLongBiFunction<long[], long[]> upperBound = (inputs, outputs) -> j2 * sumOfAll(inputs, outputs);
        return new IntervalLoadEstimator(inputSlots, outputSlots, d, cardinalityEstimate, lowerBound, upperBound);
    }

    /** Sums every element of both arrays. */
    private static long sumOfAll(long[] first, long[] second) {
        return LongStream.concat(Arrays.stream(first), Arrays.stream(second)).sum();
    }

    /**
     * Evaluates both estimation functions for every pairing of an input combination with an
     * output combination and returns the smallest and largest value seen.
     *
     * <p>Each function result is clamped to be non-negative before it is considered. If no
     * pairing is evaluated, both bounds of the returned estimate are {@code -1}.
     *
     * @param estimationContext provides the input and output cardinality estimates
     * @return the load interval, with probability equal to the joint cardinality probability
     *         multiplied by this estimator's correctness probability
     */
    @Override
    public LoadEstimate calculate(EstimationContext estimationContext) {
        final CardinalityEstimate[] inputCardinalities = estimationContext.getInputCardinalities();
        final CardinalityEstimate[] outputCardinalities = estimationContext.getOutputCardinalities();

        // Inputs may exceed the declared count; outputs must match it exactly.
        final boolean inputArityOk = this.numInputs == -1 || inputCardinalities.length >= this.numInputs;
        Validate.isTrue(inputArityOk, "Received %d input estimates, require %d.", inputCardinalities.length, this.numInputs);
        final boolean outputArityOk = this.numOutputs == -1 || outputCardinalities.length == this.numOutputs;
        Validate.isTrue(outputArityOk, "Received %d output estimates, require %d.", outputCardinalities.length, this.numOutputs);

        final long[][] inputCombinations = enumerateCombinations(inputCardinalities);
        final long[][] outputCombinations = enumerateCombinations(outputCardinalities);

        // -1 marks "nothing recorded yet"; it cannot collide with a real value because every
        // estimate is clamped to >= 0 below.
        long lowest = -1L;
        long highest = -1L;
        for (final long[] inputs : inputCombinations) {
            for (final long[] outputs : outputCombinations) {
                final long fromLower = Math.max(this.lowerBoundEstimator.applyAsLong(inputs, outputs), 0L);
                final long fromUpper = Math.max(this.upperBoundEstimator.applyAsLong(inputs, outputs), 0L);

                // Both functions contribute to both bounds, so fold the pair's extremes in.
                final long pairMin = Math.min(fromLower, fromUpper);
                final long pairMax = Math.max(fromLower, fromUpper);
                lowest = (lowest == -1L) ? pairMin : Math.min(lowest, pairMin);
                highest = (highest == -1L) ? pairMax : Math.max(highest, pairMax);
            }
        }

        final double probability = calculateJointProbability(inputCardinalities, outputCardinalities) * this.correctnessProbablity;
        return new LoadEstimate(lowest, highest, probability);
    }
}
