package ai.djl.mxnet.engine;

import ai.djl.Device;
import ai.djl.Model;
import ai.djl.engine.Engine;
import ai.djl.mxnet.jna.JnaUtils;
import ai.djl.mxnet.jna.LibUtils;
import ai.djl.ndarray.NDManager;
import ai.djl.training.GradientCollector;
import ai.djl.training.LocalParameterServer;
import ai.djl.training.ParameterServer;
import ai.djl.training.optimizer.Optimizer;
import ai.djl.util.RandomUtils;
import java.util.Iterator;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:ai/djl/mxnet/engine/MxEngine.class */
/**
 * The MXNet implementation of the DJL {@link Engine}.
 *
 * <p>Most queries (version, capabilities, seeding) are forwarded to the native library through
 * {@link JnaUtils}. Instances are created only through {@link #newInstance()}.
 */
public final class MxEngine extends Engine {

    private static final Logger logger = LoggerFactory.getLogger(MxEngine.class);

    /** The name reported by {@link #getEngineName()}. */
    public static final String ENGINE_NAME = "MXNet";

    /** Use {@link #newInstance()} instead; the constructor does no native initialization itself. */
    private MxEngine() {
    }

    /**
     * Initializes the native MXNet runtime and returns a new engine.
     *
     * <p>This method queries the op names, switches numpy mode to {@code GLOBAL_ON}, and registers
     * a JVM shutdown hook that calls {@code JnaUtils.waitAll}. Any {@link Throwable} raised along
     * the way is logged as a warning, and the method then returns {@code null}.
     *
     * <p>NOTE(review): the decompiler reports that this was originally package-private.
     *
     * @return the initialized engine, or {@code null} if initialization failed
     */
    public static Engine newInstance() {
        try {
            // Presumably the first native call; if the library cannot be loaded, it fails here and
            // is reported by the catch below.
            JnaUtils.getAllOpNames();
            JnaUtils.setNumpyMode(JnaUtils.NumpyMode.GLOBAL_ON);
            // NOTE(review): waitAll presumably blocks until pending native work drains — confirm.
            Thread waitAllOnExit = new Thread(JnaUtils::waitAll);
            Runtime.getRuntime().addShutdownHook(waitAllOnExit);
            return new MxEngine();
        } catch (Throwable failure) {
            logger.warn("Failed to load MXNet native library", failure);
            return null;
        }
    }

    /**
     * Returns the engine name.
     *
     * @return always {@value #ENGINE_NAME}
     */
    public String getEngineName() {
        return ENGINE_NAME;
    }

    /**
     * Decodes the native integer version into dotted form.
     *
     * <p>The integer is read as {@code major * 10000 + minor * 100 + patch}.
     *
     * @return the version as {@code "major.minor.patch"}
     */
    public String getVersion() {
        int encoded = JnaUtils.getVersion();
        int major = encoded / 10000;
        int minor = (encoded / 100) % 100;
        int patch = encoded % 100;
        return major + "." + minor + '.' + patch;
    }

    /**
     * Checks whether the native library reports the given feature.
     *
     * @param str the feature name to look for
     * @return {@code true} if the set returned by {@code JnaUtils.getFeatures()} contains it
     */
    public boolean hasCapability(String str) {
        return JnaUtils.getFeatures().contains(str);
    }

    /**
     * Creates an MXNet-backed model.
     *
     * @param str the model name
     * @param device the device the model is bound to
     * @return a new {@code MxModel}
     */
    public Model newModel(String str, Device device) {
        return new MxModel(str, device);
    }

    /**
     * Creates a sub-manager of the MXNet system manager.
     *
     * @return a new child manager
     */
    public NDManager newBaseManager() {
        // NOTE(review): "mo10" is a JADX rename; presumably the original name was newSubManager().
        return MxNDManager.getSystemManager().mo10newSubManager();
    }

    /**
     * Creates a sub-manager of the MXNet system manager on the given device.
     *
     * @param device the device for the new manager
     * @return a new child manager
     */
    public NDManager newBaseManager(Device device) {
        // NOTE(review): "mo9" is a JADX rename; presumably the original name was newSubManager(Device).
        return MxNDManager.getSystemManager().mo9newSubManager(device);
    }

    /**
     * Creates a gradient collector backed by MXNet.
     *
     * @return a new {@code MxGradientCollector}
     */
    public GradientCollector newGradientCollector() {
        return new MxGradientCollector();
    }

    /**
     * Creates a parameter server for the given optimizer.
     *
     * <p>If the system property {@code ai.djl.use_local_parameter_server} is {@code "true"}, this
     * returns a pure-Java {@link LocalParameterServer}. Otherwise it returns an
     * {@code MxParameterServer}.
     *
     * @param optimizer the optimizer that applies the updates
     * @return the selected parameter server
     */
    public ParameterServer newParameterServer(Optimizer optimizer) {
        if (Boolean.getBoolean("ai.djl.use_local_parameter_server")) {
            return new LocalParameterServer(optimizer);
        }
        return new MxParameterServer(optimizer);
    }

    /**
     * Seeds the native random generator and {@code RandomUtils.RANDOM} with the same value.
     *
     * @param i the seed
     */
    public void setRandomSeed(int i) {
        JnaUtils.randomSeed(i);
        RandomUtils.RANDOM.setSeed(i);
    }

    /**
     * Logs the base environment details, then the MXNet library name and feature list at INFO.
     */
    public void debugEnvironment() {
        super.debugEnvironment();
        logger.info("MXNet Library: {}", LibUtils.getLibName());
        logger.info("MXNet Features: {}", String.join(", ", JnaUtils.getFeatures()));
    }

    /**
     * Describes the engine's name, version and capabilities.
     *
     * <p>Each capability is printed on its own tab-indented line, followed by a comma.
     *
     * @return a multi-line description of the engine
     */
    public String toString() {
        StringBuilder text = new StringBuilder(200)
                .append("Name: ")
                .append(getEngineName())
                .append(", version: ")
                .append(getVersion())
                .append(", capabilities: [\n");
        for (String feature : JnaUtils.getFeatures()) {
            text.append("\t").append(feature).append(",\n");
        }
        return text.append(']').toString();
    }
}
