package ai.djl.pytorch.jni;

import ai.djl.util.Platform;
import ai.djl.util.Utils;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.net.URL;
import java.net.URLDecoder;
import java.nio.file.FileVisitOption;
import java.nio.file.Files;
import java.nio.file.LinkOption;
import java.nio.file.Path;
import java.nio.file.StandardCopyOption;
import java.nio.file.attribute.FileAttribute;
import java.util.Arrays;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Properties;
import java.util.Set;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Stream;
import java.util.zip.GZIPInputStream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Locates, downloads or extracts, and loads the PyTorch native libraries ({@code libtorch} and
 * the libraries in its folder) together with the DJL PyTorch JNI library ({@code djl_torch}).
 *
 * <p>NOTE(review): this class was recovered with the JADX decompiler. {@link
 * #copyNativeLibraryFromClasspath(Platform)} could not be decompiled and is still a stub that
 * always throws {@link UnsupportedOperationException}.
 */
public final class LibUtils {
    private static final Logger logger = LoggerFactory.getLogger(LibUtils.class);

    /** Platform-specific file name of the main libtorch library. */
    private static final String NATIVE_LIB_NAME = System.mapLibraryName("torch");

    /** Platform-specific file name of the DJL PyTorch JNI library. */
    private static final String JNI_LIB_NAME = System.mapLibraryName("djl_torch");

    /**
     * Accepts {@code X.Y.Z}, optionally followed by a lowercase {@code -suffix}, {@code -SNAPSHOT}
     * and/or a numeric {@code -N}. Group 1 captures {@code X.Y.Z[-suffix]}.
     */
    private static final Pattern VERSION_PATTERN =
            Pattern.compile("(\\d+\\.\\d+\\.\\d+(-[a-z]+)?)(-SNAPSHOT)?(-\\d+)?");

    /**
     * Libraries that {@link #loadLibTorch} loads explicitly, in exactly this order, after every
     * other library found in the libtorch folder.
     */
    private static final List<String> ORDERED_LIBS =
            Arrays.asList(
                    System.mapLibraryName("fbgemm"),
                    System.mapLibraryName("caffe2_nvrtc"),
                    System.mapLibraryName("torch_cpu"),
                    System.mapLibraryName("c10_cuda"),
                    System.mapLibraryName("torch_cuda_cpp"),
                    System.mapLibraryName("torch_cuda_cu"),
                    System.mapLibraryName("torch_cuda"),
                    System.mapLibraryName("torch"));

    /** cuDNN 8 Windows DLLs, loaded in this order when {@code cudnn64_8.dll} is present. */
    private static final List<String> CUDNN8_DLLS =
            Arrays.asList(
                    "cudnn64_8.dll",
                    "cudnn_ops_infer64_8.dll",
                    "cudnn_ops_train64_8.dll",
                    "cudnn_cnn_infer64_8.dll",
                    "cudnn_cnn_train64_8.dll",
                    "cudnn_adv_infer64_8.dll",
                    "cudnn_adv_train64_8.dll");

    /**
     * The libtorch selected by {@link #loadLibrary()}. It is {@code null} before that call, and it
     * also stays {@code null} when loadLibrary takes the Android shortcut.
     */
    private static LibTorch libTorch;

    /**
     * Describes one libtorch installation. It holds the folder plus the version, API version,
     * flavor and classifier that are used to build cache paths and download URLs.
     *
     * <p>NOTE(review): JADX reports this class was originally {@code private}. It is kept
     * {@code public} so the decompiled signature does not change.
     */
    public static final class LibTorch {
        /** Folder that contains the libtorch libraries. */
        Path dir;
        /** PyTorch version string, possibly with -SNAPSHOT or a build suffix. */
        String version;
        /** Value of {@code Platform#getApiVersion()}; used as the JNI file-name prefix. */
        String apiVersion;
        /** Build flavor, e.g. {@code cpu-precxx11}. */
        String flavor;
        /** Value of {@code Platform#getClassifier()}. */
        String classifier;

        /**
         * Describes a libtorch found in a user-specified folder.
         *
         * <p>The version comes from the {@code PYTORCH_VERSION} environment variable, then the
         * system property of the same name, then the detected platform. The flavor comes from
         * {@code PYTORCH_FLAVOR} (environment variable, then system property), and defaults to
         * {@code cpu-precxx11}.
         *
         * @param path folder containing the libtorch libraries
         */
        LibTorch(Path path) {
            Platform platform = Platform.detectPlatform("pytorch");
            this.dir = path;
            this.apiVersion = platform.getApiVersion();
            this.classifier = platform.getClassifier();
            String configuredVersion = getEnvOrSystemProperty("PYTORCH_VERSION");
            this.version = configuredVersion != null ? configuredVersion : platform.getVersion();
            String configuredFlavor = getEnvOrSystemProperty("PYTORCH_FLAVOR");
            this.flavor = configuredFlavor != null ? configuredFlavor : "cpu-precxx11";
        }

        /**
         * Describes a libtorch whose version, API version and classifier come from
         * {@code platform}.
         *
         * @param path folder containing the libtorch libraries
         * @param platform platform the libraries were resolved for
         * @param str build flavor of the libraries in {@code path}
         */
        LibTorch(Path path, Platform platform, String str) {
            this.dir = path;
            this.version = platform.getVersion();
            this.apiVersion = platform.getApiVersion();
            this.classifier = platform.getClassifier();
            this.flavor = str;
        }
    }

    private LibUtils() {
    }

    /**
     * Loads libtorch and the DJL PyTorch JNI library into the current JVM.
     *
     * <p>On Android (detected through {@code java.vendor.url}) only {@code djl_torch} is loaded,
     * via {@link System#loadLibrary}. In that case {@link #getVersion()} and
     * {@link #getLibtorchPath()} remain unusable.
     */
    public static synchronized void loadLibrary() {
        if ("http://www.android.com/".equals(System.getProperty("java.vendor.url"))) {
            System.loadLibrary("djl_torch");
            return;
        }
        libTorch = getLibTorch();
        loadLibTorch(libTorch);
        loadNativeLibrary(findJniLibrary(libTorch).toAbsolutePath().toString());
    }

    /** Prefers a user-supplied libtorch and otherwise falls back to a downloaded or bundled one. */
    private static LibTorch getLibTorch() {
        LibTorch override = findOverrideLibrary();
        return override != null ? override : findNativeLibrary();
    }

    /**
     * Returns the loaded libtorch version with any {@code -SNAPSHOT} or numeric build suffix
     * removed.
     *
     * @return the release version, or the raw version when it does not match the expected format
     * @throws NullPointerException if {@link #loadLibrary()} has not selected a libtorch yet
     */
    public static String getVersion() {
        return releaseVersion(libTorch.version);
    }

    /**
     * Returns the folder of the loaded libtorch.
     *
     * @return the folder path as a string
     * @throws NullPointerException if {@link #loadLibrary()} has not selected a libtorch yet
     */
    public static String getLibtorchPath() {
        return libTorch.dir.toString();
    }

    /** Strips the optional {@code -SNAPSHOT}/{@code -N} suffixes matched by VERSION_PATTERN. */
    private static String releaseVersion(String version) {
        Matcher matcher = VERSION_PATTERN.matcher(version);
        return matcher.matches() ? matcher.group(1) : version;
    }

    /**
     * Loads every library under {@code lib.dir} in three stages:
     *
     * <ol>
     *   <li>auxiliary libraries, which are all libraries not handled by the later stages;
     *   <li>cuDNN DLLs, when present;
     *   <li>{@link #ORDERED_LIBS}, in order.
     * </ol>
     *
     * @param lib the libtorch to load
     * @throws IllegalArgumentException if the folder cannot be walked
     */
    private static void loadLibTorch(LibTorch lib) {
        Path libDir = lib.dir.toAbsolutePath();
        // Use the argument's version rather than the static field so this method does not depend
        // on the caller having assigned libTorch first.
        if ("1.8.1".equals(releaseVersion(lib.version))
                && System.getProperty("os.name").startsWith("Mac")) {
            // NOTE(review): 1.8.1 on macOS skips explicit loading entirely; reason not visible here.
            return;
        }
        Set<String> orderedLibs = new HashSet<>(ORDERED_LIBS);
        // Files.walk holds open directory handles, so the stream must be closed.
        try (Stream<Path> files = Files.walk(libDir)) {
            files.filter(file -> isAuxiliaryLibrary(file, orderedLibs))
                    .map(Path::toString)
                    .forEach(LibUtils::loadNativeLibrary);
        } catch (IOException e) {
            throw new IllegalArgumentException("Folder not exist! " + libDir, e);
        }
        loadCudnn(libDir);
        for (String name : ORDERED_LIBS) {
            Path file = libDir.resolve(name);
            if (Files.exists(file)) {
                loadNativeLibrary(file.toString());
            }
        }
    }

    /**
     * Reports whether {@code file} should be loaded in the first, unordered stage. That excludes
     * non-regular files, the JNI library, ordered libraries, and {@code torch_*},
     * {@code caffe2_*} and {@code cudnn*} libraries.
     */
    private static boolean isAuxiliaryLibrary(Path file, Set<String> orderedLibs) {
        String name = file.getFileName().toString();
        return !orderedLibs.contains(name)
                && Files.isRegularFile(file)
                && !name.endsWith(JNI_LIB_NAME)
                && !name.contains("torch_")
                && !name.contains("caffe2_")
                && !name.startsWith("cudnn");
    }

    /** Loads cuDNN 8 DLLs if {@code cudnn64_8.dll} exists, otherwise {@code cudnn64_7.dll}. */
    private static void loadCudnn(Path libDir) {
        if (Files.exists(libDir.resolve("cudnn64_8.dll"))) {
            for (String dll : CUDNN8_DLLS) {
                loadNativeLibrary(libDir.resolve(dll).toString());
            }
        } else if (Files.exists(libDir.resolve("cudnn64_7.dll"))) {
            loadNativeLibrary(libDir.resolve("cudnn64_7.dll").toString());
        }
    }

    /**
     * Looks for a user-supplied libtorch, first in {@code PYTORCH_LIBRARY_PATH} and then in
     * {@code java.library.path}.
     *
     * @return the libtorch found, or {@code null}
     */
    private static LibTorch findOverrideLibrary() {
        String libraryPath = System.getenv("PYTORCH_LIBRARY_PATH");
        if (libraryPath != null) {
            LibTorch found = findLibraryInPath(libraryPath);
            if (found != null) {
                return found;
            }
        }
        String javaLibraryPath = System.getProperty("java.library.path");
        return javaLibraryPath == null ? null : findLibraryInPath(javaLibraryPath);
    }

    /**
     * Scans a {@link File#pathSeparator}-separated list of entries. Each entry may be either the
     * libtorch file itself or a folder that contains it.
     *
     * @param str the path list
     * @return the first match, or {@code null}
     */
    private static LibTorch findLibraryInPath(String str) {
        for (String entry : str.split(File.pathSeparator)) {
            File file = new File(entry);
            if (!file.exists()) {
                continue;
            }
            if (file.isFile() && NATIVE_LIB_NAME.equals(file.getName())) {
                // Resolve against the working directory first. A bare relative name such as
                // "libtorch.so" has no parent and would otherwise cause an NPE.
                return new LibTorch(file.getAbsoluteFile().getParentFile().toPath().toAbsolutePath());
            }
            File candidate = new File(entry, NATIVE_LIB_NAME);
            if (candidate.exists() && candidate.isFile()) {
                return new LibTorch(file.toPath().toAbsolutePath());
            }
        }
        return null;
    }

    /**
     * Returns the path of the JNI library matching {@code lib}.
     *
     * <p>If the library is already in the engine cache dir, that path is returned. Otherwise it
     * is downloaded, unless the bundled {@code /jnilib/pytorch.properties} says the
     * classpath copy matches, in which case it is extracted from the classpath.
     *
     * @param lib the libtorch the JNI library must match
     * @return the path of the JNI library in the cache dir
     */
    private static Path findJniLibrary(LibTorch lib) {
        String classifier = lib.classifier;
        String version = lib.version;
        String apiVersion = lib.apiVersion;
        String flavor = lib.flavor;
        Path dir =
                Utils.getEngineCacheDir("pytorch")
                        .resolve(version + '-' + flavor + '-' + classifier);
        Path jniPath = dir.resolve(apiVersion + '-' + JNI_LIB_NAME);
        if (Files.exists(jniPath)) {
            return jniPath;
        }

        Matcher matcher = VERSION_PATTERN.matcher(version);
        if (!matcher.matches()) {
            throw new IllegalArgumentException("Unexpected version: " + version);
        }
        String release = matcher.group(1);

        String bundledJniVersion = readBundledJniVersion();
        if (bundledJniVersion == null) {
            downloadJniLib(dir, jniPath, apiVersion, release, classifier, flavor);
            return jniPath;
        }
        if (!bundledJniVersion.startsWith(release + '-' + apiVersion)) {
            logger.warn("Found mismatch PyTorch jni: {}", bundledJniVersion);
            downloadJniLib(dir, jniPath, apiVersion, release, classifier, flavor);
            return jniPath;
        }
        extractJniLib(dir, jniPath, classifier, flavor);
        return jniPath;
    }

    /**
     * Reads {@code jni_version} from {@code /jnilib/pytorch.properties} on the classpath.
     *
     * @return the version, or {@code null} when the properties file is not on the classpath
     * @throws AssertionError if the file is present without {@code jni_version} or cannot be read
     */
    private static String readBundledJniVersion() {
        try (InputStream is = LibUtils.class.getResourceAsStream("/jnilib/pytorch.properties")) {
            if (is == null) {
                return null;
            }
            Properties properties = new Properties();
            properties.load(is);
            String jniVersion = properties.getProperty("jni_version");
            if (jniVersion == null) {
                throw new AssertionError("No PyTorch jni version found.");
            }
            return jniVersion;
        } catch (IOException e) {
            throw new AssertionError("Failed to read PyTorch jni properties file.", e);
        }
    }

    /**
     * Copies the classpath JNI library {@code /jnilib/<classifier>/<flavor>/<jni>} to
     * {@code target} through a temp file in {@code dir}.
     */
    private static void extractJniLib(Path dir, Path target, String classifier, String flavor) {
        String libPath = "/jnilib/" + classifier + '/' + flavor + '/' + JNI_LIB_NAME;
        logger.info("Extracting {} to cache ...", libPath);
        Path tmp = null;
        try (InputStream stream = LibUtils.class.getResourceAsStream(libPath)) {
            if (stream == null) {
                throw new AssertionError("PyTorch jni not found: " + libPath);
            }
            // The cache dir may not exist yet, e.g. when libtorch came from an override path.
            Files.createDirectories(dir);
            tmp = Files.createTempFile(dir, "jni", "tmp");
            Files.copy(stream, tmp, StandardCopyOption.REPLACE_EXISTING);
            Utils.moveQuietly(tmp, target);
        } catch (IOException e) {
            throw new IllegalStateException("Cannot copy jni files", e);
        } finally {
            // Removes the temp file if the copy failed or the move left it behind.
            if (tmp != null) {
                Utils.deleteQuietly(tmp);
            }
        }
    }

    /**
     * Selects the bundled or downloaded libtorch.
     *
     * <p>When {@code PYTORCH_VERSION} (environment variable, then system property) is unset,
     * empty, or a prefix of the detected version, the detected platform is used. If
     * {@code Platform#isPlaceholder()} is true the libraries are downloaded; otherwise they are
     * extracted from the classpath. Any other requested version is always downloaded.
     */
    private static LibTorch findNativeLibrary() {
        Platform platform = Platform.detectPlatform("pytorch");
        String overrideVersion = getEnvOrSystemProperty("PYTORCH_VERSION");
        boolean useDetected =
                overrideVersion == null
                        || overrideVersion.isEmpty()
                        || platform.getVersion().startsWith(overrideVersion);
        if (useDetected) {
            return platform.isPlaceholder()
                    ? downloadPyTorch(platform)
                    : copyNativeLibraryFromClasspath(platform);
        }
        logger.warn("Override PyTorch version: {}.", overrideVersion);
        return downloadPyTorch(Platform.detectPlatform("pytorch", overrideVersion));
    }

    /**
     * NOTE(review): JADX failed to decompile this method (about 444 bytecode instructions), so it
     * always throws. Callers reach it from {@link #findNativeLibrary()} when the detected platform
     * is not a placeholder. The original body must be restored from the upstream source before
     * that path can work.
     */
    private static LibTorch copyNativeLibraryFromClasspath(Platform r7) {
        throw new UnsupportedOperationException("Method not decompiled: ai.djl.pytorch.jni.LibUtils.copyNativeLibraryFromClasspath(ai.djl.util.Platform):ai.djl.pytorch.jni.LibUtils$LibTorch");
    }

    /**
     * Loads one native library by absolute path.
     *
     * <p>If the {@code ai.djl.pytorch.native_helper} system property names a class, that class's
     * static {@code load(String)} method is invoked first. {@link System#load} is called in
     * every case.
     *
     * @param str absolute path of the library
     * @throws IllegalArgumentException if the native helper cannot be invoked
     */
    private static void loadNativeLibrary(String str) {
        logger.debug("Loading native library: {}", str);
        String helper = System.getProperty("ai.djl.pytorch.native_helper");
        if (helper != null && !helper.isEmpty()) {
            try {
                Class.forName(helper).getDeclaredMethod("load", String.class).invoke(null, str);
            } catch (ReflectiveOperationException e) {
                throw new IllegalArgumentException("Invalid native_helper: " + helper, e);
            }
        }
        System.load(str);
    }

    /**
     * Downloads the libtorch libraries for {@code platform} into the engine cache dir. Nothing is
     * downloaded if the libraries are already cached.
     *
     * <p>For CUDA flavors, the {@code files.txt} index is searched for a compatible CUDA build.
     * If none exists, the CPU flavor is used. Files are downloaded into a temp dir that is then
     * moved into place. The temp dir is always cleaned up, including on failure.
     *
     * @param platform the platform to download for
     * @return the cached libtorch
     * @throws IllegalStateException if the download fails or no files match the platform
     */
    private static LibTorch downloadPyTorch(Platform platform) {
        String version = platform.getVersion();
        String classifier = platform.getClassifier();
        boolean precxx11 =
                Boolean.getBoolean("PYTORCH_PRECXX11")
                        || Boolean.parseBoolean(System.getenv("PYTORCH_PRECXX11"));
        String precxx11Suffix = precxx11 ? "-precxx11" : "";
        String flavor = platform.getFlavor() + precxx11Suffix;

        Path cacheDir = Utils.getEngineCacheDir("pytorch");
        Path dir = cacheDir.resolve(version + '-' + flavor + '-' + classifier);
        if (Files.exists(dir.resolve(NATIVE_LIB_NAME))) {
            logger.debug("Using cache dir: {}", dir);
            return new LibTorch(dir.toAbsolutePath(), platform, flavor);
        }

        Matcher matcher = VERSION_PATTERN.matcher(version);
        if (!matcher.matches()) {
            throw new IllegalArgumentException("Unexpected version: " + version);
        }
        String link = "https://publish.djl.ai/pytorch/" + matcher.group(1);

        Path tmp = null;
        try (InputStream is = new URL(link + "/files.txt").openStream()) {
            Files.createDirectories(cacheDir);
            List<String> lines = Utils.readLines(is);
            if (flavor.startsWith("cu")) {
                flavor = matchCudaFlavor(lines, flavor, precxx11Suffix, classifier);
                dir = cacheDir.resolve(version + '-' + flavor + '-' + classifier);
                if (Files.exists(dir.resolve(NATIVE_LIB_NAME))) {
                    return new LibTorch(dir.toAbsolutePath(), platform, flavor);
                }
            }
            logger.debug("Using cache dir: {}", dir);
            tmp = Files.createTempDirectory(cacheDir, "tmp");
            String prefix = flavor + '/' + classifier + '/';
            boolean found = false;
            for (String line : lines) {
                if (!line.startsWith(prefix)) {
                    continue;
                }
                found = true;
                URL url = new URL(link + '/' + line);
                // Last path segment without the ".gz" extension (3 chars), URL-decoded.
                String fileName =
                        URLDecoder.decode(
                                line.substring(line.lastIndexOf('/') + 1, line.length() - 3),
                                "UTF-8");
                logger.info("Downloading {} ...", url);
                // Open the raw stream separately so it is closed even if the GZIP header is bad.
                try (InputStream raw = url.openStream();
                        InputStream gzip = new GZIPInputStream(raw)) {
                    Files.copy(gzip, tmp.resolve(fileName), StandardCopyOption.REPLACE_EXISTING);
                }
            }
            if (!found) {
                throw new IllegalStateException(
                        "No PyTorch native library matches your operating system: " + platform);
            }
            Utils.moveQuietly(tmp, dir);
            return new LibTorch(dir.toAbsolutePath(), platform, flavor);
        } catch (IOException e) {
            throw new IllegalStateException("Failed to download PyTorch native library", e);
        } finally {
            if (tmp != null) {
                Utils.deleteQuietly(tmp);
            }
        }
    }

    /**
     * Finds the first {@code files.txt} entry that matches
     * {@code <first 4 chars of flavor><digit><suffix>/<classifier>/native/lib/<libtorch>.gz}.
     *
     * @return the matched flavor, or {@code "cpu" + precxx11Suffix} when none matches
     */
    private static String matchCudaFlavor(
            List<String> lines, String flavor, String precxx11Suffix, String classifier) {
        Pattern pattern =
                Pattern.compile(
                        '('
                                + flavor.substring(0, 4)
                                + "\\d"
                                + precxx11Suffix
                                + ")/"
                                + classifier
                                + "/native/lib/"
                                + NATIVE_LIB_NAME
                                + ".gz");
        for (String line : lines) {
            Matcher m = pattern.matcher(line);
            if (m.matches()) {
                return m.group(1);
            }
        }
        logger.warn("No matching cuda flavor for {} found: {}.", classifier, flavor);
        return "cpu" + precxx11Suffix;
    }

    /**
     * Downloads the JNI library into {@code path} through a temp file in {@code path}'s folder.
     * The temp file is removed even on failure.
     *
     * @param path cache folder to create and download into
     * @param path2 final location of the JNI library
     * @param str API version, used as a URL segment
     * @param str2 PyTorch release version, used as a URL segment
     * @param str3 platform classifier
     * @param str4 build flavor
     * @throws IllegalStateException if the download fails
     */
    private static void downloadJniLib(
            Path path, Path path2, String str, String str2, String str3, String str4) {
        String url =
                "https://publish.djl.ai/pytorch/"
                        + str2
                        + "/jnilib/"
                        + str
                        + '/'
                        + str3
                        + '/'
                        + str4
                        + '/'
                        + JNI_LIB_NAME;
        logger.info("Downloading jni {} to cache ...", url);
        Path tmp = null;
        try (InputStream is = new URL(url).openStream()) {
            Files.createDirectories(path);
            tmp = Files.createTempFile(path, "jni", "tmp");
            Files.copy(is, tmp, StandardCopyOption.REPLACE_EXISTING);
            Utils.moveQuietly(tmp, path2);
        } catch (IOException e) {
            throw new IllegalStateException("Cannot download jni files: " + url, e);
        } finally {
            if (tmp != null) {
                Utils.deleteQuietly(tmp);
            }
        }
    }

    /** Returns the environment variable {@code key}, else the system property, else null. */
    private static String getEnvOrSystemProperty(String key) {
        String value = System.getenv(key);
        return value != null ? value : System.getProperty(key);
    }
}
