package ai.djl.onnxruntime.engine;

import ai.djl.BaseModel;
import ai.djl.MalformedModelException;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.LinkOption;
import java.nio.file.Path;
import java.util.Map;

/* loaded from: input_file:ai/djl/onnxruntime/engine/OrtModel.class */
/**
 * {@code OrtModel} is the ONNX Runtime implementation of a DJL {@link BaseModel}.
 *
 * <p>Loading locates a single {@code .onnx} file, opens an {@link OrtSession} on it, and wraps
 * that session in an {@link OrtSymbolBlock}. Dynamic (user-supplied) blocks are rejected.
 */
public class OrtModel extends BaseModel {

    /** File extension of ONNX model files. */
    private static final String ONNX_EXTENSION = ".onnx";

    /** Environment used to create ONNX Runtime sessions; set once at construction. */
    private final OrtEnvironment env;

    /**
     * Constructs a new ONNX Runtime model.
     *
     * <p>The decompiler notes this constructor was originally package-private; it is kept public
     * so existing callers still compile.
     *
     * @param name the model name, passed to {@link BaseModel}
     * @param manager the {@link NDManager} for this model; it is renamed to {@code "ortModel"}
     * @param env the ONNX Runtime environment used to create sessions
     */
    public OrtModel(String name, NDManager manager, OrtEnvironment env) {
        super(name);
        this.manager = manager;
        this.manager.setName("ortModel");
        this.env = env;
        this.dataType = DataType.FLOAT32;
    }

    /**
     * Loads the ONNX model file and creates an ONNX Runtime session for it.
     *
     * <p>The file is looked up by {@code prefix} first and, failing that, by the name of the model
     * directory. When the manager's device is a GPU, the CUDA execution provider is registered for
     * that device id.
     *
     * @param modelPath the model location; if it resolves to a regular file, that file is used
     *     directly (see {@link #findModelFile(String)})
     * @param prefix the model file name, with or without the {@code .onnx} extension; when
     *     {@code null}, the current model name is used
     * @param options load options; not read by this implementation
     * @throws UnsupportedOperationException if a block has already been set on this model
     * @throws FileNotFoundException if no matching {@code .onnx} file is found
     * @throws MalformedModelException if ONNX Runtime cannot create a session from the file
     * @throws IOException if setting the model directory fails
     */
    public void load(Path modelPath, String prefix, Map<String, ?> options)
            throws IOException, MalformedModelException {
        setModelDir(modelPath);
        if (block != null) {
            throw new UnsupportedOperationException("ONNX Runtime does not support dynamic blocks");
        }
        Path modelFile = findModelFile(prefix);
        if (modelFile == null) {
            // Fall back to a file named after the model directory itself.
            modelFile = findModelFile(modelDir.toFile().getName());
            if (modelFile == null) {
                throw new FileNotFoundException(".onnx file not found in: " + modelPath);
            }
        }
        try {
            block = new OrtSymbolBlock(newSession(modelFile), manager);
        } catch (OrtException e) {
            throw new MalformedModelException("ONNX Model cannot be loaded", e);
        }
    }

    /**
     * Creates a session for {@code modelFile}, enabling CUDA when the manager targets a GPU.
     *
     * @param modelFile the {@code .onnx} file to open
     * @return the newly created session
     * @throws OrtException if ONNX Runtime fails to create the session
     */
    private OrtSession newSession(Path modelFile) throws OrtException {
        String file = modelFile.toString();
        if (!manager.getDevice().isGpu()) {
            return env.createSession(file);
        }
        // SessionOptions holds a native handle. It is only needed while the session is being
        // created, so release it afterwards instead of leaking it on every GPU load.
        try (OrtSession.SessionOptions sessionOptions = new OrtSession.SessionOptions()) {
            sessionOptions.addCUDA(manager.getDevice().getDeviceId());
            return env.createSession(file, sessionOptions);
        }
    }

    /**
     * Locates the ONNX model file.
     *
     * <p>If the model directory is itself a regular file, that file is returned. As a side effect,
     * the model directory becomes its parent and the model name becomes the file name without a
     * trailing {@code .onnx}.
     *
     * <p>Otherwise {@code prefix} (or the model name when {@code prefix} is {@code null}) is
     * resolved against the model directory. It is tried first as given, then with {@code .onnx}
     * appended unless it already ends with that extension.
     *
     * @param prefix the candidate file name; may be {@code null}
     * @return the model file, or {@code null} if no regular file matches
     */
    private Path findModelFile(String prefix) {
        if (Files.isRegularFile(modelDir)) {
            Path file = modelDir;
            modelDir = file.getParent();
            modelName = stripOnnxExtension(file.toFile().getName());
            return file;
        }
        String fileName = prefix == null ? modelName : prefix;
        Path candidate = modelDir.resolve(fileName);
        if (Files.isRegularFile(candidate)) {
            return candidate;
        }
        if (fileName.endsWith(ONNX_EXTENSION)) {
            // Already carried the extension; there is no other spelling to try.
            return null;
        }
        candidate = modelDir.resolve(fileName + ONNX_EXTENSION);
        return Files.isRegularFile(candidate) ? candidate : null;
    }

    /**
     * Removes a trailing {@code .onnx} from {@code name}, if present.
     *
     * @param name a file name
     * @return {@code name} without its {@code .onnx} suffix
     */
    private static String stripOnnxExtension(String name) {
        if (name.endsWith(ONNX_EXTENSION)) {
            return name.substring(0, name.length() - ONNX_EXTENSION.length());
        }
        return name;
    }
}
