/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.onnxruntime.engine;

import ai.djl.Device;
import ai.djl.Model;
import ai.djl.engine.Engine;
import ai.djl.ndarray.NDManager;
import ai.djl.nn.SymbolBlock;
import ai.djl.onnxruntime.engine.OrtModel;
import ai.djl.onnxruntime.engine.OrtNDManager;
import ai.djl.training.GradientCollector;
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;

/**
 * DJL {@link Engine} implementation registered under the name {@value #ENGINE_NAME}.
 *
 * <p>Symbol-block construction, gradient collection and random seeding are not supported and
 * throw {@link UnsupportedOperationException}. Array management ({@link #newBaseManager(Device)})
 * is delegated to an alternative engine when the default DJL engine has a rank lower than
 * {@link #RANK}; otherwise an {@link OrtNDManager} sub-manager is used.
 */
public final class OrtEngine extends Engine {

    public static final String ENGINE_NAME = "OnnxRuntime";
    static final int RANK = 10;

    /** System property that, when set to {@code "true"}, disables the alternative engine. */
    private static final String DISABLE_ALTERNATIVE_PROPERTY = "ai.djl.onnx.disable_alternative";

    // Resolved lazily by getAlternativeEngine(); remains null until a qualifying engine is found.
    private Engine alternativeEngine;
    private final OrtEnvironment env = OrtEnvironment.getEnvironment();

    private OrtEngine() {
    }

    /**
     * Creates a new engine instance.
     *
     * @return a new {@code OrtEngine}
     */
    static Engine newInstance() {
        return new OrtEngine();
    }

    /**
     * Returns the engine name.
     *
     * @return {@value #ENGINE_NAME}
     */
    public String getEngineName() {
        return ENGINE_NAME;
    }

    /**
     * Returns this engine's rank.
     *
     * @return {@link #RANK}
     */
    public int getRank() {
        return RANK;
    }

    /**
     * Returns the engine that array management is delegated to, resolving it on first use.
     *
     * <p>The default engine from {@link Engine#getInstance()} is cached as the alternative only
     * when its rank is lower than this engine's rank. Setting the system property
     * {@code ai.djl.onnx.disable_alternative} to {@code true} disables delegation. That property
     * is checked on every call, so it takes effect even after an engine has been cached.
     *
     * @return the alternative engine, or {@code null} if delegation is disabled or the default
     *     engine does not have a lower rank
     */
    private Engine getAlternativeEngine() {
        if (Boolean.getBoolean(DISABLE_ALTERNATIVE_PROPERTY)) {
            return null;
        }
        if (alternativeEngine == null) {
            Engine defaultEngine = Engine.getInstance();
            if (defaultEngine.getRank() < getRank()) {
                alternativeEngine = defaultEngine;
            }
        }
        return alternativeEngine;
    }

    /**
     * Returns the version string reported by this engine.
     *
     * @return {@code "1.8.0"}
     */
    public String getVersion() {
        return "1.8.0";
    }

    /**
     * Reports whether the named capability is supported.
     *
     * <p>{@code "MKL"} is always reported as supported. {@code "CUDA"} is probed by trying to
     * register the CUDA execution provider on a temporary {@link OrtSession.SessionOptions}. Every
     * other capability name is reported as unsupported.
     *
     * @param capability the capability name
     * @return {@code true} if the capability is supported
     */
    public boolean hasCapability(String capability) {
        if ("MKL".equals(capability)) {
            return true;
        }
        if ("CUDA".equals(capability)) {
            // SessionOptions is AutoCloseable; release it whether or not addCUDA() succeeds.
            try (OrtSession.SessionOptions sessionOptions = new OrtSession.SessionOptions()) {
                sessionOptions.addCUDA();
                return true;
            } catch (OrtException e) {
                // A failure to add the CUDA provider is treated as "CUDA not available".
                return false;
            }
        }
        return false;
    }

    /**
     * Creates a model whose manager comes from {@link #newBaseManager(Device)}.
     *
     * @param name the model name
     * @param device the device passed to {@link #newBaseManager(Device)}
     * @return a new {@link OrtModel} sharing this engine's {@link OrtEnvironment}
     */
    public Model newModel(String name, Device device) {
        return new OrtModel(name, newBaseManager(device), env);
    }

    /**
     * Not supported by this engine.
     *
     * @param manager ignored
     * @return never returns normally
     * @throws UnsupportedOperationException always
     */
    public SymbolBlock newSymbolBlock(NDManager manager) {
        throw new UnsupportedOperationException("ONNXRuntime does not support empty SymbolBlock");
    }

    /**
     * Creates a base manager without specifying a device.
     *
     * @return the result of {@code newBaseManager(null)}
     */
    public NDManager newBaseManager() {
        return newBaseManager(null);
    }

    /**
     * Creates a base manager for the given device.
     *
     * <p>If an alternative engine is available (see {@link #getAlternativeEngine()}), that
     * engine's manager is returned. Otherwise, a sub-manager of
     * {@link OrtNDManager#getSystemManager()} is returned.
     *
     * @param device the device, possibly {@code null}
     * @return a new manager
     */
    public NDManager newBaseManager(Device device) {
        Engine alternative = getAlternativeEngine();
        if (alternative != null) {
            return alternative.newBaseManager(device);
        }
        return OrtNDManager.getSystemManager().newSubManager(device);
    }

    /**
     * Not supported by this engine.
     *
     * @return never returns normally
     * @throws UnsupportedOperationException always
     */
    public GradientCollector newGradientCollector() {
        throw new UnsupportedOperationException("Not supported for ONNX Runtime");
    }

    /**
     * Not supported by this engine.
     *
     * @param seed ignored
     * @throws UnsupportedOperationException always
     */
    public void setRandomSeed(int seed) {
        throw new UnsupportedOperationException("Not supported for ONNX Runtime");
    }

    /**
     * Describes the engine: name, version, detected capabilities and alternative engine.
     *
     * <p>This method reads only the cached alternative engine and does not trigger resolution. It
     * therefore reports "No alternative engine found" until {@link #newBaseManager(Device)} has
     * resolved one.
     *
     * @return a human-readable description
     */
    public String toString() {
        StringBuilder sb = new StringBuilder(200);
        // Name and version are emitted once (the decompiled original duplicated them).
        sb.append(getEngineName())
                .append(':')
                .append(getVersion())
                .append(", capabilities: [\n\tMKL,\n");
        if (hasCapability("CUDA")) {
            sb.append("\tCUDA,\n");
        }
        sb.append("]\n");
        if (alternativeEngine != null) {
            sb.append("Alternative engine: ").append(alternativeEngine.getEngineName());
        } else {
            sb.append("No alternative engine found");
        }
        return sb.toString();
    }
}

