package ai.djl.onnxruntime.engine;

import ai.djl.Device;
import ai.djl.Model;
import ai.djl.engine.Engine;
import ai.djl.ndarray.NDManager;
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtLoggingLevel;
import ai.onnxruntime.OrtSession;

/* loaded from: input_file:ai/djl/onnxruntime/engine/OrtEngine.class */
/**
 * {@link Engine} implementation backed by ONNX Runtime.
 *
 * <p>An instance owns a single {@link OrtEnvironment}, created in the constructor, and hands it to
 * every {@code OrtModel} produced by {@link #newModel(String, Device)}.
 *
 * <p>The following system properties are read by this class:
 *
 * <ul>
 *   <li>{@code ai.djl.onnxruntime.num_interop_threads} - global inter-op thread count
 *   <li>{@code ai.djl.onnxruntime.num_threads} - global intra-op thread count
 *   <li>{@code ai.djl.onnx.disable_alternative} - when {@code true}, no alternative engine is
 *       looked up
 * </ul>
 */
public final class OrtEngine extends Engine {
    /** Name reported by {@link #getEngineName()}. */
    public static final String ENGINE_NAME = "OnnxRuntime";

    /** Rank reported by {@link #getRank()}. */
    static final int RANK = 10;

    /** Environment created once in the constructor and passed to new models. */
    private OrtEnvironment env;

    /** Lazily resolved alternative engine; stays {@code null} when none qualifies. */
    private Engine alternativeEngine;

    /** Set once the alternative-engine lookup has been performed. */
    private boolean initialized;

    /**
     * Creates the engine and its {@link OrtEnvironment}, applying the optional thread-count system
     * properties when they are set.
     *
     * @throws AssertionError if the environment cannot be configured
     */
    private OrtEngine() {
        OrtEnvironment.ThreadingOptions threadingOptions = new OrtEnvironment.ThreadingOptions();
        try {
            // Integer.getInteger returns null when the property is absent or not a valid integer,
            // in which case the corresponding option is left untouched.
            Integer interOpThreads = Integer.getInteger("ai.djl.onnxruntime.num_interop_threads");
            Integer intraOpThreads = Integer.getInteger("ai.djl.onnxruntime.num_threads");
            if (interOpThreads != null) {
                threadingOptions.setGlobalInterOpNumThreads(interOpThreads.intValue());
            }
            if (intraOpThreads != null) {
                threadingOptions.setGlobalIntraOpNumThreads(intraOpThreads.intValue());
            }
            // NOTE(review): threadingOptions is only closed on the failure path below; confirm
            // whether the environment takes ownership of it before closing it on success too.
            this.env = OrtEnvironment.getEnvironment(OrtLoggingLevel.ORT_LOGGING_LEVEL_WARNING, "ort-java", threadingOptions);
        } catch (OrtException e) {
            threadingOptions.close();
            throw new AssertionError("Failed to config OrtEnvironment", e);
        }
    }

    /**
     * Creates a new engine instance.
     *
     * @return a freshly constructed {@code OrtEngine}
     */
    /* JADX INFO: Access modifiers changed from: package-private */
    public static Engine newInstance() {
        return new OrtEngine();
    }

    /**
     * Returns the ONNX Runtime environment owned by this engine.
     *
     * @return the environment created in the constructor
     */
    /* JADX INFO: Access modifiers changed from: package-private */
    public OrtEnvironment getEnv() {
        return this.env;
    }

    /**
     * Returns the alternative engine, resolving it on first use.
     *
     * <p>The lookup is skipped while the system property {@code ai.djl.onnx.disable_alternative}
     * is {@code true}. Otherwise the engine returned by {@link Engine#getInstance()} is adopted
     * only if its rank is lower than this engine's rank, and the lookup is not repeated.
     *
     * @return the alternative engine, or {@code null} if none has been selected
     */
    public Engine getAlternativeEngine() {
        if (!this.initialized && !Boolean.getBoolean("ai.djl.onnx.disable_alternative")) {
            Engine engine = Engine.getInstance();
            if (engine.getRank() < getRank()) {
                this.alternativeEngine = engine;
            }
            this.initialized = true;
        }
        return this.alternativeEngine;
    }

    /**
     * Returns the engine name.
     *
     * @return {@link #ENGINE_NAME}
     */
    public String getEngineName() {
        return ENGINE_NAME;
    }

    /**
     * Returns the engine rank.
     *
     * @return {@link #RANK}
     */
    public int getRank() {
        return RANK;
    }

    /**
     * Returns the version string reported by this engine.
     *
     * @return the hard-coded version {@code "1.16.0"}
     */
    public String getVersion() {
        return "1.16.0";
    }

    /**
     * Reports whether the given capability is available.
     *
     * <p>{@code "MKL"} is always reported as available. {@code "CUDA"} is probed by attempting to
     * add the CUDA provider to a throwaway {@link OrtSession.SessionOptions}. Any other name
     * yields {@code false}.
     *
     * @param str the capability name
     * @return {@code true} if the capability is available
     */
    public boolean hasCapability(String str) {
        if ("MKL".equals(str)) {
            return true;
        }
        if (!"CUDA".equals(str)) {
            return false;
        }
        // try-with-resources guarantees the probe options are released even when addCUDA() throws.
        try (OrtSession.SessionOptions sessionOptions = new OrtSession.SessionOptions()) {
            sessionOptions.addCUDA();
            return true;
        } catch (OrtException ignored) {
            // A failed probe is the expected outcome when CUDA is unavailable.
            return false;
        }
    }

    /**
     * Creates a new model bound to this engine's environment.
     *
     * @param str the model name
     * @param device the device used to create the model's manager
     * @return a new {@code OrtModel}
     */
    public Model newModel(String str, Device device) {
        return new OrtModel(str, newBaseManager(device), this.env);
    }

    /**
     * Creates a base manager without specifying a device.
     *
     * @return the result of {@link #newBaseManager(Device)} called with {@code null}
     */
    public NDManager newBaseManager() {
        return newBaseManager(null);
    }

    /**
     * Creates a base manager as a sub-manager of the system manager.
     *
     * @param device the device passed to the sub-manager; may be {@code null}
     * @return a new sub-manager of {@code OrtNDManager.getSystemManager()}
     */
    public NDManager newBaseManager(Device device) {
        // NOTE(review): "mo2newSubManager" looks like a decompiler-generated name; confirm it
        // matches the declaration in OrtNDManager.
        return OrtNDManager.getSystemManager().mo2newSubManager(device);
    }

    /**
     * Describes the engine as {@code name:version} followed by its capability list.
     *
     * @return a human-readable, multi-line description
     */
    public String toString() {
        StringBuilder sb = new StringBuilder(200);
        // The name:version prefix is emitted once; MKL is always listed.
        sb.append(getEngineName()).append(':').append(getVersion()).append(", capabilities: [\n\tMKL,\n");
        if (hasCapability("CUDA")) {
            sb.append("\t").append("CUDA").append(",\n");
        }
        // Close the capability list opened above.
        sb.append(']');
        return sb.toString();
    }
}
