package org.apache.spark.sql.rapids;

import ai.rapids.cudf.CudaMemInfo;
import ai.rapids.cudf.Rmm;
import com.nvidia.spark.RapidsShuffleManager;
import com.nvidia.spark.rapids.DeviceMemoryEventHandler;
import com.nvidia.spark.rapids.RapidsBufferCatalog;
import com.nvidia.spark.rapids.RapidsConf;
import com.nvidia.spark.rapids.RapidsDeviceMemoryStore;
import com.nvidia.spark.rapids.RapidsDiskStore;
import com.nvidia.spark.rapids.RapidsHostMemoryStore;
import com.nvidia.spark.rapids.ShuffleBufferCatalog;
import com.nvidia.spark.rapids.ShuffleReceivedBufferCatalog;
import org.apache.spark.SparkConf;
import org.apache.spark.SparkEnv$;
import org.apache.spark.internal.Logging;
import org.apache.spark.util.Utils$;
import org.slf4j.Logger;
import scala.Function0;
import scala.Predef$;

/* compiled from: GpuShuffleEnv.scala */
/* loaded from: input_file:org/apache/spark/sql/rapids/GpuShuffleEnv$.class */
/**
 * Java view of the Scala {@code object GpuShuffleEnv}: a process-wide singleton (exposed through
 * {@link #MODULE$}) that tracks whether the RAPIDS shuffle manager is configured and initialized,
 * and that owns the spillable buffer storage chain (device store -> host store -> disk store) plus
 * the shuffle buffer catalogs built on top of it.
 *
 * <p>NOTE(review): apart from the lazily computed {@link #isRapidsShuffleEnabled()} value, none of
 * the mutable state here is guarded by a lock; presumably {@link #initStorage} and
 * {@link #closeStorage} are driven from a single lifecycle thread — confirm with callers.
 */
public final class GpuShuffleEnv$ implements Logging {
    /** The singleton instance; assigned by the private constructor run from the static initializer. */
    public static GpuShuffleEnv$ MODULE$;
    /** Cached result of {@link #isRapidsShuffleEnabled()}; only meaningful once {@code bitmap$0} is true. */
    private boolean isRapidsShuffleEnabled;
    /** Canonical class name of {@code RapidsShuffleManager}, compared against spark.shuffle.manager. */
    private final String RAPIDS_SHUFFLE_CLASS;
    /** Set via {@link #setRapidsShuffleManagerInitialized(boolean, String)}. */
    private boolean isRapidsShuffleManagerInitialized;
    /** Buffer catalog shared by every store and shuffle catalog created in {@link #initStorage}. */
    private final RapidsBufferCatalog catalog;
    private ShuffleBufferCatalog shuffleCatalog;
    private ShuffleReceivedBufferCatalog shuffleReceivedBufferCatalog;
    private RapidsDeviceMemoryStore deviceStorage;
    private RapidsHostMemoryStore hostStorage;
    private RapidsDiskStore diskStorage;
    private DeviceMemoryEventHandler memoryEventHandler;
    /** Backing field for the Scala {@code Logging} trait's logger. */
    private transient Logger org$apache$spark$internal$Logging$$log_;
    /** Lazy-val guard: true once {@code isRapidsShuffleEnabled} has been computed. */
    private volatile boolean bitmap$0;

    static {
        new GpuShuffleEnv$();
    }

    // ---- Forwarders to the Scala Logging trait's static implementation methods ----

    public String logName() {
        return Logging.logName$(this);
    }

    public Logger log() {
        return Logging.log$(this);
    }

    public void logInfo(Function0<String> function0) {
        Logging.logInfo$(this, function0);
    }

    public void logDebug(Function0<String> function0) {
        Logging.logDebug$(this, function0);
    }

    public void logTrace(Function0<String> function0) {
        Logging.logTrace$(this, function0);
    }

    public void logWarning(Function0<String> function0) {
        Logging.logWarning$(this, function0);
    }

    public void logError(Function0<String> function0) {
        Logging.logError$(this, function0);
    }

    public void logInfo(Function0<String> function0, Throwable th) {
        Logging.logInfo$(this, function0, th);
    }

    public void logDebug(Function0<String> function0, Throwable th) {
        Logging.logDebug$(this, function0, th);
    }

    public void logTrace(Function0<String> function0, Throwable th) {
        Logging.logTrace$(this, function0, th);
    }

    public void logWarning(Function0<String> function0, Throwable th) {
        Logging.logWarning$(this, function0, th);
    }

    public void logError(Function0<String> function0, Throwable th) {
        Logging.logError$(this, function0, th);
    }

    public boolean isTraceEnabled() {
        return Logging.isTraceEnabled$(this);
    }

    public void initializeLogIfNecessary(boolean z) {
        Logging.initializeLogIfNecessary$(this, z);
    }

    public boolean initializeLogIfNecessary(boolean z, boolean z2) {
        return Logging.initializeLogIfNecessary$(this, z, z2);
    }

    public boolean initializeLogIfNecessary$default$2() {
        return Logging.initializeLogIfNecessary$default$2$(this);
    }

    public void initializeForcefully(boolean z, boolean z2) {
        Logging.initializeForcefully$(this, z, z2);
    }

    public Logger org$apache$spark$internal$Logging$$log_() {
        return this.org$apache$spark$internal$Logging$$log_;
    }

    public void org$apache$spark$internal$Logging$$log__$eq(Logger logger) {
        this.org$apache$spark$internal$Logging$$log_ = logger;
    }

    // ---- Scala-style private accessors / mutators ----

    private String RAPIDS_SHUFFLE_CLASS() {
        return this.RAPIDS_SHUFFLE_CLASS;
    }

    private boolean isRapidsShuffleManagerInitialized() {
        return this.isRapidsShuffleManagerInitialized;
    }

    private void isRapidsShuffleManagerInitialized_$eq(boolean z) {
        this.isRapidsShuffleManagerInitialized = z;
    }

    private RapidsBufferCatalog catalog() {
        return this.catalog;
    }

    private ShuffleBufferCatalog shuffleCatalog() {
        return this.shuffleCatalog;
    }

    private void shuffleCatalog_$eq(ShuffleBufferCatalog shuffleBufferCatalog) {
        this.shuffleCatalog = shuffleBufferCatalog;
    }

    private ShuffleReceivedBufferCatalog shuffleReceivedBufferCatalog() {
        return this.shuffleReceivedBufferCatalog;
    }

    private void shuffleReceivedBufferCatalog_$eq(ShuffleReceivedBufferCatalog shuffleReceivedBufferCatalog) {
        this.shuffleReceivedBufferCatalog = shuffleReceivedBufferCatalog;
    }

    private RapidsDeviceMemoryStore deviceStorage() {
        return this.deviceStorage;
    }

    private void deviceStorage_$eq(RapidsDeviceMemoryStore rapidsDeviceMemoryStore) {
        this.deviceStorage = rapidsDeviceMemoryStore;
    }

    private RapidsHostMemoryStore hostStorage() {
        return this.hostStorage;
    }

    private void hostStorage_$eq(RapidsHostMemoryStore rapidsHostMemoryStore) {
        this.hostStorage = rapidsHostMemoryStore;
    }

    private RapidsDiskStore diskStorage() {
        return this.diskStorage;
    }

    private void diskStorage_$eq(RapidsDiskStore rapidsDiskStore) {
        this.diskStorage = rapidsDiskStore;
    }

    private DeviceMemoryEventHandler memoryEventHandler() {
        return this.memoryEventHandler;
    }

    private void memoryEventHandler_$eq(DeviceMemoryEventHandler deviceMemoryEventHandler) {
        this.memoryEventHandler = deviceMemoryEventHandler;
    }

    /**
     * Null-safe comparison of a class name against the RAPIDS shuffle manager's canonical name
     * (mirrors Scala's {@code ==}, which treats two nulls as equal).
     */
    private boolean isRapidsShuffleClass(String className) {
        return java.util.Objects.equals(className, RAPIDS_SHUFFLE_CLASS());
    }

    /**
     * Returns whether {@code spark.shuffle.manager} is set to the RAPIDS shuffle manager class.
     *
     * @param sparkConf the Spark configuration to inspect
     * @return true only if the key is present and equals the RAPIDS shuffle manager's canonical name
     */
    public boolean isRapidsShuffleConfigured(SparkConf sparkConf) {
        return sparkConf.contains("spark.shuffle.manager")
                && isRapidsShuffleClass(sparkConf.get("spark.shuffle.manager"));
    }

    /**
     * Records whether the RAPIDS shuffle manager has been initialized.
     *
     * @param z   the new initialized flag
     * @param str class name of the caller's shuffle manager; asserted (via {@code Predef.assert})
     *            to equal the RAPIDS shuffle manager's canonical name
     */
    public void setRapidsShuffleManagerInitialized(boolean z, String str) {
        Predef$.MODULE$.assert(isRapidsShuffleClass(str));
        // NOTE(review): this message is logged even when z is false — confirm that is intended.
        logInfo(() -> "RapidsShuffleManager is initialized");
        isRapidsShuffleManagerInitialized_$eq(z);
    }

    /**
     * Slow path of the lazy value: computes and caches {@code isRapidsShuffleEnabled} exactly once
     * under this object's monitor. {@code bitmap$0} is volatile, so the unsynchronized fast-path
     * read in {@link #isRapidsShuffleEnabled()} observes the cached value after publication.
     */
    private boolean isRapidsShuffleEnabled$lzycompute() {
        synchronized (this) {
            if (!this.bitmap$0) {
                this.isRapidsShuffleEnabled = isRapidsShuffleManagerInitialized()
                        && !SparkEnv$.MODULE$.get().blockManager().externalShuffleServiceEnabled();
                this.bitmap$0 = true;
            }
        }
        return this.isRapidsShuffleEnabled;
    }

    /**
     * Returns true when the RAPIDS shuffle manager was marked initialized and the Spark external
     * shuffle service is disabled.
     *
     * <p>NOTE(review): the value is computed on the first call and cached forever; a call made
     * before {@link #setRapidsShuffleManagerInitialized} permanently yields false.
     */
    public boolean isRapidsShuffleEnabled() {
        return this.bitmap$0 ? this.isRapidsShuffleEnabled : isRapidsShuffleEnabled$lzycompute();
    }

    /**
     * Builds the spill-store chain and shuffle catalogs, and installs the RMM event handler, but
     * only if the current {@code SparkEnv} configuration selects the RAPIDS shuffle manager.
     *
     * <p>The device store spills to the host store, which spills to the disk store. The handler's
     * start/stop thresholds are {@code cudaMemInfo.total} scaled by
     * {@code rapidsConf.rmmSpillAsyncStart()} / {@code rmmSpillAsyncStop()}.
     *
     * @param rapidsConf  plugin configuration (host spill size and spill thresholds)
     * @param cudaMemInfo device memory info whose {@code total} sizes the spill thresholds
     */
    public void initStorage(RapidsConf rapidsConf, CudaMemInfo cudaMemInfo) {
        SparkConf conf = SparkEnv$.MODULE$.get().conf();
        if (!isRapidsShuffleConfigured(conf)) {
            return;
        }
        // Guards against initializing twice without an intervening closeStorage().
        Predef$.MODULE$.assert(memoryEventHandler() == null);

        deviceStorage_$eq(new RapidsDeviceMemoryStore(catalog()));
        hostStorage_$eq(new RapidsHostMemoryStore(catalog(), rapidsConf.hostSpillStorageSize()));
        RapidsDiskBlockManager diskBlockManager = new RapidsDiskBlockManager(conf);
        diskStorage_$eq(new RapidsDiskStore(catalog(), diskBlockManager));
        deviceStorage().setSpillStore(hostStorage());
        hostStorage().setSpillStore(diskStorage());

        long rmmSpillAsyncStart = (long) (cudaMemInfo.total * rapidsConf.rmmSpillAsyncStart());
        long rmmSpillAsyncStop = (long) (cudaMemInfo.total * rapidsConf.rmmSpillAsyncStop());
        logInfo(() -> "Installing GPU memory handler to start spill at "
                + Utils$.MODULE$.bytesToString(rmmSpillAsyncStart)
                + " and stop at "
                + Utils$.MODULE$.bytesToString(rmmSpillAsyncStop));
        memoryEventHandler_$eq(new DeviceMemoryEventHandler(deviceStorage(), rmmSpillAsyncStart, rmmSpillAsyncStop));
        Rmm.setEventHandler(memoryEventHandler());

        shuffleCatalog_$eq(new ShuffleBufferCatalog(catalog(), diskBlockManager));
        shuffleReceivedBufferCatalog_$eq(new ShuffleReceivedBufferCatalog(catalog(), diskBlockManager));
    }

    /**
     * Closes and clears the device, host and disk stores, and drops this object's reference to the
     * memory event handler.
     *
     * <p>NOTE(review): the handler is not unregistered from {@code Rmm}, and the shuffle catalogs
     * returned by {@link #getCatalog()} / {@link #getReceivedCatalog()} are left in place; confirm
     * both are intentional (e.g. shutdown-ordering concerns).
     */
    public void closeStorage() {
        if (memoryEventHandler() != null) {
            memoryEventHandler_$eq(null);
        }
        if (deviceStorage() != null) {
            deviceStorage().close();
            deviceStorage_$eq(null);
        }
        if (hostStorage() != null) {
            hostStorage().close();
            hostStorage_$eq(null);
        }
        if (diskStorage() != null) {
            diskStorage().close();
            diskStorage_$eq(null);
        }
    }

    /** @return the shuffle buffer catalog, or null if {@link #initStorage} did not create one */
    public ShuffleBufferCatalog getCatalog() {
        return shuffleCatalog();
    }

    /** @return the received-buffer catalog, or null if {@link #initStorage} did not create one */
    public ShuffleReceivedBufferCatalog getReceivedCatalog() {
        return shuffleReceivedBufferCatalog();
    }

    /** @return the device memory store, or null if not initialized or already closed */
    public RapidsDeviceMemoryStore getDeviceStorage() {
        return deviceStorage();
    }

    private GpuShuffleEnv$() {
        MODULE$ = this;
        Logging.$init$(this);
        this.RAPIDS_SHUFFLE_CLASS = RapidsShuffleManager.class.getCanonicalName();
        this.isRapidsShuffleManagerInitialized = false;
        this.catalog = new RapidsBufferCatalog();
    }
}
