package ml.dmlc.xgboost4j.scala.spark.rapids;

import ai.rapids.cudf.Cuda;
import org.apache.spark.TaskContext;
import org.apache.spark.TaskContext$;
import org.apache.spark.internal.Logging;
import org.apache.spark.resource.ResourceInformation;
import org.slf4j.Logger;
import scala.Function0;
import scala.None$;
import scala.Option;
import scala.Predef$;
import scala.Some;
import scala.collection.immutable.Map;
import scala.collection.immutable.StringOps;
import scala.collection.mutable.ArrayBuffer;
import scala.collection.mutable.ArrayBuffer$;
import scala.collection.mutable.ArrayOps;
import scala.runtime.BoxesRunTime;
import scala.runtime.RichInt$;
import scala.util.control.NonFatal$;

/* compiled from: GpuDeviceManager.scala */
/* loaded from: input_file:ml/dmlc/xgboost4j/scala/spark/rapids/GpuDeviceManager$.class */
/**
 * Picks and binds the CUDA device used by the current Spark task.
 *
 * <p>This is the Java view of the Scala {@code object GpuDeviceManager}: the single instance is
 * exposed as {@link #MODULE$}. Device selection first looks at the {@code "gpu"} resource that
 * Spark attached to the running task; if the task has no such resource it falls back to probing
 * every visible device in order ("exclusive mode"), making up to {@link #EXCLUSIVE_MODE_ATTEMPTS}
 * passes over the devices.
 *
 * <p>No synchronization is performed in this class.
 */
public final class GpuDeviceManager$ implements Logging {
    /** Singleton instance (Scala {@code object} encoding); assigned by the private constructor. */
    public static GpuDeviceManager$ MODULE$;

    /** Spark resource name under which the task's assigned GPU addresses are looked up. */
    private static final String GPU_RESOURCE_NAME = "gpu";

    /** Number of full passes over all visible devices before exclusive-mode probing gives up. */
    private static final int EXCLUSIVE_MODE_ATTEMPTS = 2;

    /** Backing field for the lazily created logger of the {@link Logging} trait. */
    private transient Logger org$apache$spark$internal$Logging$$log_;

    static {
        new GpuDeviceManager$();
    }

    // ---------------------------------------------------------------------------------------
    // Forwarders to the default implementations of Spark's Logging trait.
    // ---------------------------------------------------------------------------------------

    public String logName() {
        return Logging.logName$(this);
    }

    public Logger log() {
        return Logging.log$(this);
    }

    public void logInfo(Function0<String> function0) {
        Logging.logInfo$(this, function0);
    }

    public void logDebug(Function0<String> function0) {
        Logging.logDebug$(this, function0);
    }

    public void logTrace(Function0<String> function0) {
        Logging.logTrace$(this, function0);
    }

    public void logWarning(Function0<String> function0) {
        Logging.logWarning$(this, function0);
    }

    public void logError(Function0<String> function0) {
        Logging.logError$(this, function0);
    }

    public void logInfo(Function0<String> function0, Throwable th) {
        Logging.logInfo$(this, function0, th);
    }

    public void logDebug(Function0<String> function0, Throwable th) {
        Logging.logDebug$(this, function0, th);
    }

    public void logTrace(Function0<String> function0, Throwable th) {
        Logging.logTrace$(this, function0, th);
    }

    public void logWarning(Function0<String> function0, Throwable th) {
        Logging.logWarning$(this, function0, th);
    }

    public void logError(Function0<String> function0, Throwable th) {
        Logging.logError$(this, function0, th);
    }

    public boolean isTraceEnabled() {
        return Logging.isTraceEnabled$(this);
    }

    public void initializeLogIfNecessary(boolean z) {
        Logging.initializeLogIfNecessary$(this, z);
    }

    public boolean initializeLogIfNecessary(boolean z, boolean z2) {
        return Logging.initializeLogIfNecessary$(this, z, z2);
    }

    public boolean initializeLogIfNecessary$default$2() {
        return Logging.initializeLogIfNecessary$default$2$(this);
    }

    public void initializeForcefully(boolean z, boolean z2) {
        Logging.initializeForcefully$(this, z, z2);
    }

    public Logger org$apache$spark$internal$Logging$$log_() {
        return this.org$apache$spark$internal$Logging$$log_;
    }

    public void org$apache$spark$internal$Logging$$log__$eq(Logger logger) {
        this.org$apache$spark$internal$Logging$$log_ = logger;
    }

    // ---------------------------------------------------------------------------------------
    // Device selection.
    // ---------------------------------------------------------------------------------------

    /**
     * Returns the CUDA device ordinal to use and, unless {@code z} is set, makes it current.
     *
     * @param z when {@code true}, no device is selected and {@code 0} is returned. The original
     *     parameter name was lost in decompilation; presumably this is a local-mode flag — confirm
     *     against the callers.
     * @return the device ordinal taken from the task's {@code "gpu"} resource if present,
     *     otherwise the first device that could be acquired by probing
     * @throws IllegalArgumentException if the task's {@code "gpu"} resource lists zero or more than
     *     one address, or an address that is not an integer
     * @throws IllegalStateException if probing could not acquire any device
     */
    public int getGpuId(boolean z) {
        if (z) {
            return 0;
        }
        Option<Object> assigned = initializeGpu(getResourcesFromTaskContext());
        if (assigned.isDefined()) {
            return BoxesRunTime.unboxToInt(assigned.get());
        }
        // Only probe when Spark did not hand us a GPU address.
        return findGpuAndAcquire();
    }

    /**
     * Binds the GPU named in {@code map}, if any.
     *
     * @return the bound device ordinal (boxed), or an empty option when {@code map} has no
     *     {@code "gpu"} entry
     */
    private Option<Object> initializeGpu(Map<String, ResourceInformation> map) {
        Option<Object> address = getGPUAddrFromResources(map);
        if (address.isEmpty()) {
            return address;
        }
        int device = setGpuDeviceAndAcquire(BoxesRunTime.unboxToInt(address.get()));
        return new Some<>(BoxesRunTime.boxToInteger(device));
    }

    /** Returns the running task's resources, or an empty map when called outside a task. */
    private Map<String, ResourceInformation> getResourcesFromTaskContext() {
        TaskContext taskContext = TaskContext$.MODULE$.get();
        if (taskContext == null) {
            return Predef$.MODULE$.Map().empty();
        }
        return taskContext.resources();
    }

    /**
     * Extracts the single GPU address from the {@code "gpu"} resource entry.
     *
     * @return the address parsed as an {@code int} (boxed), or an empty option when {@code map}
     *     has no {@code "gpu"} entry
     * @throws IllegalArgumentException if the entry lists no address, more than one address, or an
     *     address that does not parse as an integer ({@link NumberFormatException} is a subclass)
     */
    private Option<Object> getGPUAddrFromResources(Map<String, ResourceInformation> map) {
        if (!map.contains(GPU_RESOURCE_NAME)) {
            return Option.empty();
        }
        String[] addresses = ((ResourceInformation) map.apply(GPU_RESOURCE_NAME)).addresses();
        if (addresses.length == 0) {
            // Previously this surfaced as an opaque NoSuchElementException from head().
            throw new IllegalArgumentException(
                    "Spark GPU Plugin found no gpu address assigned to this executor");
        }
        if (addresses.length > 1) {
            throw new IllegalArgumentException("Spark GPU Plugin only supports 1 gpu per executor");
        }
        return new Some<>(BoxesRunTime.boxToInteger(Integer.parseInt(addresses[0])));
    }

    /**
     * Probes devices {@code 0 .. Cuda.getDeviceCount() - 1} in order and returns the first one that
     * can be made current. The scan is repeated up to {@link #EXCLUSIVE_MODE_ATTEMPTS} times, with
     * no delay between passes.
     *
     * @throws IllegalStateException if no device could be acquired in any pass
     */
    /* JADX INFO: Access modifiers changed from: private */
    public int findGpuAndAcquire() {
        logInfo(() -> "fall back to exclusive mode");
        int deviceCount = Cuda.getDeviceCount();
        for (int attempt = 0; attempt < EXCLUSIVE_MODE_ATTEMPTS; attempt++) {
            for (int device = 0; device < deviceCount; device++) {
                if (tryToSetGpuDeviceAndAcquire(device)) {
                    return device;
                }
            }
        }
        throw new IllegalStateException("Could not find a single GPU to use");
    }

    /**
     * Best-effort wrapper around {@link #setGpuDeviceAndAcquire(int)}.
     *
     * @return {@code true} if device {@code i} was acquired; {@code false} if the attempt failed
     *     with a non-fatal throwable. Fatal throwables (as classified by Scala's {@code NonFatal})
     *     are rethrown.
     */
    /* JADX INFO: Access modifiers changed from: private */
    public boolean tryToSetGpuDeviceAndAcquire(int i) {
        try {
            setGpuDeviceAndAcquire(i);
            return true;
        } catch (Throwable th) {
            if (NonFatal$.MODULE$.unapply(th).isEmpty()) {
                throw th;
            }
            // Keep the failure best-effort, but leave a trace so probing failures are diagnosable.
            logDebug(() -> "Failed to acquire GPU device " + i, th);
            return false;
        }
    }

    /**
     * Makes device {@code i} current via {@code Cuda.setDevice} and then calls
     * {@code Cuda.freeZero()}.
     *
     * <p>NOTE(review): {@code freeZero} presumably forces CUDA context creation on the device (the
     * {@code cudaFree(0)} idiom), which is what would fail on a device already held by another
     * process in exclusive mode — confirm against cudf's {@code Cuda} documentation.
     *
     * @return {@code i}
     */
    /* JADX INFO: Access modifiers changed from: private */
    public int setGpuDeviceAndAcquire(int i) {
        logDebug(() -> "Initializing GPU device ID to " + i);
        Cuda.setDevice(i);
        Cuda.freeZero();
        return i;
    }

    private GpuDeviceManager$() {
        MODULE$ = this;
        Logging.$init$(this);
    }
}
