package com.nvidia.spark.rapids;

import org.apache.spark.internal.Logging;
import org.apache.spark.sql.catalyst.plans.QueryPlan;
import org.apache.spark.sql.execution.SparkPlan;
import org.apache.spark.sql.execution.adaptive.CustomShuffleReaderExec;
import org.apache.spark.sql.execution.exchange.BroadcastExchangeExec;
import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec;
import org.apache.spark.sql.execution.joins.BroadcastNestedLoopJoinExec;
import org.apache.spark.sql.execution.joins.ShuffledHashJoinExec;
import org.apache.spark.sql.internal.SQLConf$;
import org.slf4j.Logger;
import scala.Function0;
import scala.MatchError;
import scala.Predef$;
import scala.Tuple2;
import scala.collection.IterableLike;
import scala.collection.Seq;
import scala.collection.Seq$;
import scala.collection.TraversableLike;
import scala.collection.TraversableOnce;
import scala.collection.mutable.ListBuffer;
import scala.math.Numeric$DoubleIsFractional$;
import scala.reflect.ScalaSignature;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;
import scala.runtime.DoubleRef;

/* compiled from: CostBasedOptimizer.scala */
@ScalaSignature(bytes = "\u0006\u0001\u0005-a\u0001B\u0004\t\u0001EAQa\n\u0001\u0005\u0002!BQA\u000b\u0001\u0005\u0002-BQ!\u0014\u0001\u0005\n9CQ!\u001c\u0001\u0005\n9DQ!\u001d\u0001\u0005\nIDQ!\u001e\u0001\u0005\nY\u0014!cQ8ti\n\u000b7/\u001a3PaRLW.\u001b>fe*\u0011\u0011BC\u0001\u0007e\u0006\u0004\u0018\u000eZ:\u000b\u0005-a\u0011!B:qCJ\\'BA\u0007\u000f\u0003\u0019qg/\u001b3jC*\tq\"A\u0002d_6\u001c\u0001a\u0005\u0003\u0001%aa\u0002CA\n\u0017\u001b\u0005!\"\"A\u000b\u0002\u000bM\u001c\u0017\r\\1\n\u0005]!\"AB!osJ+g\r\u0005\u0002\u001a55\t\u0001\"\u0003\u0002\u001c\u0011\tIq\n\u001d;j[&TXM\u001d\t\u0003;\u0015j\u0011A\b\u0006\u0003?\u0001\n\u0001\"\u001b8uKJt\u0017\r\u001c\u0006\u0003\u0017\u0005R!AI\u0012\u0002\r\u0005\u0004\u0018m\u00195f\u0015\u0005!\u0013aA8sO&\u0011aE\b\u0002\b\u0019><w-\u001b8h\u0003\u0019a\u0014N\\5u}Q\t\u0011\u0006\u0005\u0002\u001a\u0001\u0005Aq\u000e\u001d;j[&TX\rF\u0002-w\u0001\u00032!L\u001b9\u001d\tq3G\u0004\u00020e5\t\u0001G\u0003\u00022!\u00051AH]8pizJ\u0011!F\u0005\u0003iQ\tq\u0001]1dW\u0006<W-\u0003\u00027o\t\u00191+Z9\u000b\u0005Q\"\u0002CA\r:\u0013\tQ\u0004B\u0001\u0007PaRLW.\u001b>bi&|g\u000eC\u0003=\u0005\u0001\u0007Q(\u0001\u0003d_:4\u0007CA\r?\u0013\ty\u0004B\u0001\u0006SCBLGm]\"p]\u001aDQ!\u0011\u0002A\u0002\t\u000bA\u0001\u001d7b]B\u0019\u0011dQ#\n\u0005\u0011C!!D*qCJ\\\u0007\u000b\\1o\u001b\u0016$\u0018\r\u0005\u0002G\u00176\tqI\u0003\u0002I\u0013\u0006IQ\r_3dkRLwN\u001c\u0006\u0003\u0015\u0002\n1a]9m\u0013\tauIA\u0005Ta\u0006\u00148\u000e\u00157b]\u0006\u0019\"/Z2veNLg/\u001a7z\u001fB$\u0018.\\5{KR9q*\u0016,\\;zC\u0007\u0003B\nQ%JK!!\u0015\u000b\u0003\rQ+\b\u000f\\33!\t\u00192+\u0003\u0002U)\t1Ai\\;cY\u0016DQ\u0001P\u0002A\u0002uBQaV\u0002A\u0002a\u000bAb\u00199v\u0007>\u001cH/T8eK2\u0004\"!G-\n\u0005iC!!C\"pgRlu\u000eZ3m\u0011\u0015a6\u00011\u0001Y\u000319\u0007/^\"pgRlu\u000eZ3m\u0011\u0015\t5\u00011\u0001C\u0011\u0015y6\u00011\u0001a\u00035y\u0007\u000f^5nSj\fG/[8ogB\u0019\u0011M\u001a\u001d\u000e\u0003\tT!a\u00193\u0002
\u000f5,H/\u00192mK*\u0011Q\rF\u0001\u000bG>dG.Z2uS>t\u0017BA4c\u0005)a\u0015n\u001d;Ck\u001a4WM\u001d\u0005\u0006S\u000e\u0001\rA[\u0001\u000eM&t\u0017\r\\(qKJ\fGo\u001c:\u0011\u0005MY\u0017B\u00017\u0015\u0005\u001d\u0011un\u001c7fC:\f1\u0003\u001e:b]NLG/[8o)><\u0005/^\"pgR$2AU8q\u0011\u0015aD\u00011\u0001>\u0011\u0015\tE\u00011\u0001C\u0003M!(/\u00198tSRLwN\u001c+p\u0007B,8i\\:u)\r\u00116\u000f\u001e\u0005\u0006y\u0015\u0001\r!\u0010\u0005\u0006\u0003\u0016\u0001\rAQ\u0001\rSN,\u0005p\u00195b]\u001e,w\n\u001d\u000b\u0003U^DQ!\u0011\u0004A\u0002a\u0004$!\u001f?\u0011\u0007e\u0019%\u0010\u0005\u0002|y2\u0001A!C?x\u0003\u0003\u0005\tQ!\u0001\u007f\u0005\ryF%M\t\u0004\u007f\u0006\u0015\u0001cA\n\u0002\u0002%\u0019\u00111\u0001\u000b\u0003\u000f9{G\u000f[5oOB\u00191#a\u0002\n\u0007\u0005%ACA\u0002B]f\u0004")
/* loaded from: input_file:com/nvidia/spark/rapids/CostBasedOptimizer.class */
public class CostBasedOptimizer implements Optimizer, Logging {
    private transient Logger org$apache$spark$internal$Logging$$log_;

    public String logName() {
        return Logging.logName$(this);
    }

    public Logger log() {
        return Logging.log$(this);
    }

    public void logInfo(Function0<String> function0) {
        Logging.logInfo$(this, function0);
    }

    public void logDebug(Function0<String> function0) {
        Logging.logDebug$(this, function0);
    }

    public void logTrace(Function0<String> function0) {
        Logging.logTrace$(this, function0);
    }

    public void logWarning(Function0<String> function0) {
        Logging.logWarning$(this, function0);
    }

    public void logError(Function0<String> function0) {
        Logging.logError$(this, function0);
    }

    public void logInfo(Function0<String> function0, Throwable th) {
        Logging.logInfo$(this, function0, th);
    }

    public void logDebug(Function0<String> function0, Throwable th) {
        Logging.logDebug$(this, function0, th);
    }

    public void logTrace(Function0<String> function0, Throwable th) {
        Logging.logTrace$(this, function0, th);
    }

    public void logWarning(Function0<String> function0, Throwable th) {
        Logging.logWarning$(this, function0, th);
    }

    public void logError(Function0<String> function0, Throwable th) {
        Logging.logError$(this, function0, th);
    }

    public boolean isTraceEnabled() {
        return Logging.isTraceEnabled$(this);
    }

    public void initializeLogIfNecessary(boolean z) {
        Logging.initializeLogIfNecessary$(this, z);
    }

    public boolean initializeLogIfNecessary(boolean z, boolean z2) {
        return Logging.initializeLogIfNecessary$(this, z, z2);
    }

    public boolean initializeLogIfNecessary$default$2() {
        return Logging.initializeLogIfNecessary$default$2$(this);
    }

    public void initializeForcefully(boolean z, boolean z2) {
        Logging.initializeForcefully$(this, z, z2);
    }

    public Logger org$apache$spark$internal$Logging$$log_() {
        return this.org$apache$spark$internal$Logging$$log_;
    }

    public void org$apache$spark$internal$Logging$$log__$eq(Logger logger) {
        this.org$apache$spark$internal$Logging$$log_ = logger;
    }

    /**
     * Runs the cost-based optimizer over a plan, comparing CPU and GPU cost models and
     * recording every decision to keep part of the plan on the CPU.
     *
     * @param rapidsConf    configuration used to build both cost models and the transition costs
     * @param sparkPlanMeta root of the plan metadata tree to optimize
     * @return the optimizations that were applied, in the order they were decided
     */
    @Override // com.nvidia.spark.rapids.Optimizer
    public Seq<Optimization> optimize(RapidsConf rapidsConf, SparkPlanMeta<SparkPlan> sparkPlanMeta) {
        CpuCostModel cpuCostModel = new CpuCostModel(rapidsConf);
        GpuCostModel gpuCostModel = new GpuCostModel(rapidsConf);
        ListBuffer<Optimization> optimizations = new ListBuffer<>();
        recursivelyOptimize(rapidsConf, cpuCostModel, gpuCostModel, sparkPlanMeta, optimizations, true);
        return optimizations;
    }

    /* JADX INFO: Access modifiers changed from: private */
    /**
     * Optimizes {@code plan} bottom-up (children first) and returns the cost of its subtree.
     *
     * <p>For each operator the CPU and GPU costs of the whole subtree are compared. Where the
     * GPU path (including CPU/GPU transition costs) is more expensive, the relevant section is
     * marked to stay on the CPU and an {@link Optimization} is appended to {@code optimizations}.
     *
     * @param conf          configuration used for transition costs and default row counts
     * @param cpuCostModel  model pricing an operator on the CPU
     * @param gpuCostModel  model pricing an operator on the GPU
     * @param plan          the operator to optimize (its subtree is optimized first)
     * @param optimizations sink for the optimizations decided during this pass
     * @param finalOperator true only for the root call; adds the cost of a final GPU-to-CPU transition
     * @return pair of (total CPU cost, total GPU cost) of the subtree, each boxed as {@code java.lang.Double}
     */
    public Tuple2<Object, Object> recursivelyOptimize(RapidsConf conf, CostModel cpuCostModel, CostModel gpuCostModel, SparkPlanMeta<SparkPlan> plan, ListBuffer<Optimization> optimizations, boolean finalOperator) {
        // Optimize every child first; each result is a (cpuCost, gpuCost) pair.
        Seq childCosts = (Seq) plan.childPlans().map(child -> {
            return this.recursivelyOptimize(conf, cpuCostModel, gpuCostModel, child, optimizations, false);
        }, Seq$.MODULE$.canBuildFrom());
        Tuple2 unzipped = childCosts.unzip(Predef$.MODULE$.$conforms());
        Seq childCpuCosts = (Seq) unzipped._1();
        Seq childGpuCosts = (Seq) unzipped._2();

        // Cost of this operator alone, then of the whole subtree rooted here.
        double operatorCpuCost = cpuCostModel.getCost(plan);
        double operatorGpuCost = gpuCostModel.getCost(plan);
        double totalCpuCost = operatorCpuCost + BoxesRunTime.unboxToDouble(childCpuCosts.sum(Numeric$DoubleIsFractional$.MODULE$));
        // Held in a DoubleRef because it is passed by reference to the per-child handler below.
        DoubleRef totalGpuCost = DoubleRef.create(operatorGpuCost + BoxesRunTime.unboxToDouble(childGpuCosts.sum(Numeric$DoubleIsFractional$.MODULE$)));

        plan.estimatedOutputRows_$eq(RowCountPlanVisitor$.MODULE$.visit(plan));

        // A transition exists when a child's replaceability differs from this operator's.
        int numTransitions = plan.childPlans().count(child -> {
            return BoxesRunTime.boxToBoolean($anonfun$recursivelyOptimize$2(plan, child));
        });

        if (numTransitions > 0) {
            if (plan.canThisBeReplaced()) {
                // This operator can run on GPU: price moving each CPU child's output onto the GPU.
                double toGpuTransitionCost = BoxesRunTime.unboxToDouble(((TraversableOnce) ((TraversableLike) plan.childPlans().filter(child -> {
                    return BoxesRunTime.boxToBoolean($anonfun$recursivelyOptimize$3(child));
                })).map(child -> {
                    return BoxesRunTime.boxToDouble(this.transitionToGpuCost(conf, child));
                }, Seq$.MODULE$.canBuildFrom())).sum(Numeric$DoubleIsFractional$.MODULE$));

                if (operatorGpuCost + toGpuTransitionCost > operatorCpuCost && !isExchangeOp(plan)) {
                    // The GPU speed-up does not pay for the transition: keep this operator on CPU.
                    optimizations.append(Predef$.MODULE$.wrapRefArray(new Optimization[]{new AvoidTransition(plan)}));
                    plan.costPreventsRunningOnGpu();
                    totalGpuCost.elem = totalCpuCost;
                } else {
                    totalGpuCost.elem += toGpuTransitionCost;
                }
            } else {
                // This operator stays on CPU: re-evaluate each GPU child now that the cost of
                // transitioning its output back to CPU is known.
                ((IterableLike) plan.childPlans().zip(childCosts, Seq$.MODULE$.canBuildFrom())).foreach(pair -> {
                    $anonfun$recursivelyOptimize$5(this, conf, optimizations, totalCpuCost, totalGpuCost, pair);
                    return BoxedUnit.UNIT;
                });
                // Children may have been moved to CPU above, so recompute the GPU-to-CPU cost.
                totalGpuCost.elem += BoxesRunTime.unboxToDouble(((TraversableOnce) ((TraversableLike) plan.childPlans().filter(child -> {
                    return BoxesRunTime.boxToBoolean(child.canThisBeReplaced());
                })).map(child -> {
                    return BoxesRunTime.boxToDouble(this.transitionToCpuCost(conf, child));
                }, Seq$.MODULE$.canBuildFrom())).sum(Numeric$DoubleIsFractional$.MODULE$));
            }
        }

        // For the root operator, charge moving its output back to the CPU.
        if (finalOperator && plan.canThisBeReplaced()) {
            totalGpuCost.elem += transitionToCpuCost(conf, plan);
        }

        // The GPU version of this whole section is more expensive: move the section to CPU.
        if (totalGpuCost.elem > totalCpuCost && plan.canThisBeReplaced() && !isExchangeOp(plan)) {
            optimizations.append(Predef$.MODULE$.wrapRefArray(new Optimization[]{new ReplaceSection(plan, totalCpuCost, totalGpuCost.elem)}));
            plan.recursiveCostPreventsRunningOnGpu();
            totalGpuCost.elem = totalCpuCost;
        }

        // Sections that are not on GPU (or are exchange-like) report equal CPU and GPU cost.
        if (!plan.canThisBeReplaced() || isExchangeOp(plan)) {
            totalGpuCost.elem = totalCpuCost;
        }
        return new Tuple2<Object, Object>(BoxesRunTime.boxToDouble(totalCpuCost), BoxesRunTime.boxToDouble(totalGpuCost.elem));
    }

    /**
     * Returns the visitor's row-count estimate for {@code plan}, or the configured default row
     * count when no estimate is available.
     *
     * <p>The default is explicitly widened to {@code double} so that the boxed fallback is a
     * {@code java.lang.Double}. {@code BoxesRunTime.unboxToDouble} casts to {@code Double}, and a
     * boxed {@code Long} would fail with a {@code ClassCastException}.
     */
    private double estimatedRowCount(RapidsConf rapidsConf, SparkPlanMeta<SparkPlan> sparkPlanMeta) {
        return BoxesRunTime.unboxToDouble(RowCountPlanVisitor$.MODULE$.visit(sparkPlanMeta).map(bigInt -> {
            return BoxesRunTime.boxToDouble(bigInt.toDouble());
        }).getOrElse(() -> {
            return BoxesRunTime.boxToDouble((double) rapidsConf.defaultRowCount());
        }));
    }

    /* JADX INFO: Access modifiers changed from: private */
    /**
     * Estimates the cost of moving {@code sparkPlanMeta}'s output from CPU to GPU.
     *
     * <p>The cost is the per-row {@code GpuRowToColumnarExec} operator cost (0 if not configured)
     * times the row count, plus reading the data at CPU memory speed and writing it at GPU
     * memory speed.
     */
    public double transitionToGpuCost(RapidsConf rapidsConf, SparkPlanMeta<SparkPlan> sparkPlanMeta) {
        double rowCount = estimatedRowCount(rapidsConf, sparkPlanMeta);
        long dataSize = GpuBatchUtils$.MODULE$.estimateGpuMemory(((QueryPlan) sparkPlanMeta.wrapped()).schema(), (long) rowCount);
        double perRowCost = BoxesRunTime.unboxToDouble(rapidsConf.getGpuOperatorCost("GpuRowToColumnarExec").getOrElse(() -> {
            return 0.0d;
        }));
        return (perRowCost * rowCount)
                + MemoryCostHelper$.MODULE$.calculateCost(dataSize, rapidsConf.cpuReadMemorySpeed())
                + MemoryCostHelper$.MODULE$.calculateCost(dataSize, rapidsConf.gpuWriteMemorySpeed());
    }

    /* JADX INFO: Access modifiers changed from: private */
    /**
     * Estimates the cost of moving {@code sparkPlanMeta}'s output from GPU back to CPU.
     *
     * <p>The cost is the per-row {@code GpuColumnarToRowExec} operator cost (0 if not configured)
     * times the row count, plus reading the data at GPU memory speed and writing it at CPU
     * memory speed.
     */
    public double transitionToCpuCost(RapidsConf rapidsConf, SparkPlanMeta<SparkPlan> sparkPlanMeta) {
        double rowCount = estimatedRowCount(rapidsConf, sparkPlanMeta);
        long dataSize = GpuBatchUtils$.MODULE$.estimateGpuMemory(((QueryPlan) sparkPlanMeta.wrapped()).schema(), (long) rowCount);
        double perRowCost = BoxesRunTime.unboxToDouble(rapidsConf.getGpuOperatorCost("GpuColumnarToRowExec").getOrElse(() -> {
            return 0.0d;
        }));
        return (perRowCost * rowCount)
                + MemoryCostHelper$.MODULE$.calculateCost(dataSize, rapidsConf.gpuReadMemorySpeed())
                + MemoryCostHelper$.MODULE$.calculateCost(dataSize, rapidsConf.cpuWriteMemorySpeed());
    }

    /**
     * Returns true when adaptive query execution is enabled and the wrapped operator is a shuffle
     * reader, a shuffled/broadcast hash join, a broadcast exchange, or a broadcast nested-loop join.
     * Such operators are never moved off the GPU by this optimizer.
     */
    private boolean isExchangeOp(SparkPlanMeta<?> sparkPlanMeta) {
        if (!SQLConf$.MODULE$.get().adaptiveExecutionEnabled()) {
            return false;
        }
        Object wrapped = sparkPlanMeta.wrapped();
        return wrapped instanceof CustomShuffleReaderExec
                || wrapped instanceof ShuffledHashJoinExec
                || wrapped instanceof BroadcastHashJoinExec
                || wrapped instanceof BroadcastExchangeExec
                || wrapped instanceof BroadcastNestedLoopJoinExec;
    }

    /**
     * Compiler-generated predicate, kept for signature compatibility.
     * Returns true when {@code child}'s replaceability differs from {@code parent}'s,
     * meaning there is a CPU/GPU transition between them.
     */
    public static final /* synthetic */ boolean $anonfun$recursivelyOptimize$2(SparkPlanMeta parent, SparkPlanMeta child) {
        return child.canThisBeReplaced() != parent.canThisBeReplaced();
    }

    /**
     * Compiler-generated predicate, kept for signature compatibility.
     * Returns true when {@code plan} cannot be replaced (i.e. it stays on CPU).
     */
    public static final /* synthetic */ boolean $anonfun$recursivelyOptimize$3(SparkPlanMeta plan) {
        return !plan.canThisBeReplaced();
    }

    /**
     * Compiler-generated per-child handler, kept for signature compatibility.
     *
     * <p>For a (child, (childCpuCost, childGpuCost)) pair under a CPU parent, adds the cost of
     * moving the child's output back to CPU. If the child is replaceable, is not an exchange-like
     * operator, and its GPU total exceeds its CPU cost, the child's section is moved to CPU.
     *
     * @param costBasedOptimizer the optimizer, used for transition cost and exchange checks
     * @param rapidsConf         configuration for the transition cost
     * @param listBuffer         sink for the resulting {@link ReplaceSection}
     * @param d                  the parent's running total CPU cost
     * @param doubleRef          the parent's running total GPU cost
     * @param tuple2             pair of (child plan, (child CPU cost, child GPU cost))
     * @throws MatchError if {@code tuple2} or its cost pair is null
     */
    public static final /* synthetic */ void $anonfun$recursivelyOptimize$5(CostBasedOptimizer costBasedOptimizer, RapidsConf rapidsConf, ListBuffer listBuffer, double d, DoubleRef doubleRef, Tuple2 tuple2) {
        if (tuple2 == null) {
            throw new MatchError(tuple2);
        }
        SparkPlanMeta<SparkPlan> child = (SparkPlanMeta) tuple2._1();
        Tuple2 childCosts = (Tuple2) tuple2._2();
        if (childCosts == null) {
            throw new MatchError(childCosts);
        }
        double childCpuCost = BoxesRunTime.unboxToDouble(childCosts._1());
        double childGpuTotal = BoxesRunTime.unboxToDouble(childCosts._2()) + costBasedOptimizer.transitionToCpuCost(rapidsConf, child);
        if (child.canThisBeReplaced() && !costBasedOptimizer.isExchangeOp(child) && childGpuTotal > childCpuCost) {
            // NOTE(review): this records the parent's running totals (d, doubleRef.elem), not the
            // child's own costs (childCpuCost, childGpuTotal) that drove the decision — confirm intent.
            listBuffer.append(Predef$.MODULE$.wrapRefArray(new Optimization[]{new ReplaceSection(child, d, doubleRef.elem)}));
            child.recursiveCostPreventsRunningOnGpu();
        }
    }

    public CostBasedOptimizer() {
        Logging.$init$(this);
    }
}
