package ai.djl.serving.wlm;

import ai.djl.Device;
import ai.djl.ModelException;
import ai.djl.engine.EngineException;
import ai.djl.repository.zoo.ModelNotFoundException;
import ai.djl.util.JsonUtils;
import ai.djl.util.Utils;
import ai.djl.util.cuda.CudaUtils;
import com.google.gson.JsonSyntaxException;
import com.google.gson.annotations.SerializedName;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.net.HttpURLConnection;
import java.net.URI;
import java.net.URLConnection;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.LinkOption;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Properties;
import java.util.Set;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.stream.Stream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:ai/djl/serving/wlm/LmiUtils.class */
/**
 * Helpers used by the work-load manager to detect which engine should serve a model and to prepare
 * TensorRT-LLM artifacts when that backend is selected.
 */
public final class LmiUtils {
    private static final Logger logger = LoggerFactory.getLogger(LmiUtils.class);

    /** Config file names probed, in priority order, when looking for a model's Hugging Face config. */
    private static final String[] HF_CONFIG_FILE_NAMES = {"config.json", "model_index.json"};

    /**
     * The subset of a Hugging Face {@code config.json} (or diffusers {@code model_index.json}) that
     * is used for engine detection. Instances are created by Gson deserialization.
     */
    public static final class HuggingFaceModelConfig {

        @SerializedName("model_type")
        private String modelType;

        @SerializedName("architectures")
        private List<String> configArchitectures;

        @SerializedName("auto_map")
        private Map<String, String> autoMap;

        @SerializedName("_diffusers_version")
        private String diffusersVersion;

        // Derived cache (union of architectures and auto_map keys); transient so Gson never fills it
        // from the JSON document.
        private transient Set<String> allArchitectures;

        HuggingFaceModelConfig() {
        }

        /**
         * Returns the declared {@code model_type}. If it is absent but {@code _diffusers_version} is
         * present, returns {@code "stable-diffusion"}; otherwise returns {@code null}.
         *
         * @return the model type, or {@code null} if it cannot be determined
         */
        public String getModelType() {
            if (modelType != null) {
                return modelType;
            }
            return diffusersVersion == null ? null : "stable-diffusion";
        }

        /**
         * Returns the union of the {@code architectures} list and the keys of {@code auto_map}. The
         * set is computed on first access and cached.
         *
         * @return the (possibly empty) set of architecture names
         */
        public Set<String> getArchitectures() {
            if (allArchitectures == null) {
                determineAllArchitectures();
            }
            return allArchitectures;
        }

        private void determineAllArchitectures() {
            Set<String> result = new HashSet<>();
            if (configArchitectures != null) {
                result.addAll(configArchitectures);
            }
            if (autoMap != null) {
                result.addAll(autoMap.keySet());
            }
            allArchitectures = result;
        }
    }

    private LmiUtils() {
    }

    /**
     * Infers the engine that should serve the model.
     *
     * <p>If no Hugging Face config can be located, returns {@code "MPI"} when TensorRT-LLM rolling
     * batch is in effect (see {@link #isTrtLLMRollingBatch(Properties)}) and {@code "Python"}
     * otherwise. When a config is found, {@code LmiConfigRecommender.configure} is invoked with it
     * and the model's {@code engine} property is returned afterwards (presumably populated by the
     * recommender — confirm there).
     *
     * @param modelInfo the model being loaded
     * @return the engine name
     * @throws ModelException if the model config cannot be resolved or parsed
     */
    public static String inferLmiEngine(ModelInfo<?, ?> modelInfo) throws ModelException {
        Properties properties = modelInfo.getProperties();
        HuggingFaceModelConfig modelConfig = getHuggingFaceModelConfig(modelInfo);
        if (modelConfig == null) {
            String engine = isTrtLLMRollingBatch(properties) ? "MPI" : "Python";
            logger.info("No config.json found, use {} engine.", engine);
            return engine;
        }
        LmiConfigRecommender.configure(modelInfo, properties, modelConfig);
        logger.info(
                "Detected engine: {}, rolling_batch: {}, tensor_parallel_degree {}, for modelType: {}",
                properties.getProperty("engine"),
                properties.getProperty("option.rolling_batch"),
                properties.getProperty("option.tensor_parallel_degree"),
                modelConfig.getModelType());
        return properties.getProperty("engine");
    }

    /**
     * Returns whether TensorRT-LLM rolling batch is in effect: either {@code option.rolling_batch}
     * is {@code trtllm}, or it is unset/{@code auto} and the {@code SERVING_FEATURES} env/system
     * property contains {@code trtllm}.
     *
     * @param properties the model properties
     * @return {@code true} if TensorRT-LLM rolling batch applies
     */
    public static boolean isTrtLLMRollingBatch(Properties properties) {
        String rollingBatch = properties.getProperty("option.rolling_batch");
        if ("trtllm".equals(rollingBatch)) {
            return true;
        }
        if (rollingBatch != null && !"auto".equals(rollingBatch)) {
            // An explicit, non-auto rolling batch choice overrides the container feature flags.
            return false;
        }
        String features = Utils.getEnvOrSystemProperty("SERVING_FEATURES");
        return features != null && features.contains("trtllm");
    }

    /**
     * Returns whether rolling batch is enabled, i.e. {@code option.rolling_batch} is set to anything
     * other than {@code disable}.
     *
     * @param properties the model properties
     * @return {@code true} if rolling batch is enabled
     */
    public static boolean isRollingBatchEnabled(Properties properties) {
        String rollingBatch = properties.getProperty("option.rolling_batch");
        return rollingBatch != null && !"disable".equals(rollingBatch);
    }

    /**
     * Returns whether the model needs a TensorRT-LLM conversion step: either TensorRT-LLM rolling
     * batch is in effect or the {@code trtllm_python_backend} property is present.
     *
     * @param modelInfo the model being loaded
     * @return {@code true} if {@link #convertTrtLLM(ModelInfo)} should be run
     */
    public static boolean needConvert(ModelInfo<?, ?> modelInfo) {
        return isTrtLLMRollingBatch(modelInfo.getProperties())
                || modelInfo.getProperties().containsKey("trtllm_python_backend");
    }

    /**
     * Prepares TensorRT-LLM artifacts for the model and, if new artifacts are built, points
     * {@code modelInfo.downloadDir} at them.
     *
     * <p>With {@code trtllm_python_backend} set, artifacts are always built (or reused from the
     * cache). Otherwise {@code option.rolling_batch} is forced to {@code trtllm} and artifacts are
     * built only if the model directory is not already a valid Triton model repository.
     *
     * @param modelInfo the model being loaded
     * @throws IOException if listing the model directory or running the conversion fails
     */
    public static void convertTrtLLM(ModelInfo<?, ?> modelInfo) throws IOException {
        Path modelPath;
        String modelIdOrPath = null;
        if (modelInfo.downloadDir != null) {
            modelPath = modelInfo.downloadDir;
        } else {
            modelPath = modelInfo.modelDir;
            modelIdOrPath = modelInfo.prop.getProperty("option.model_id");
            if (modelIdOrPath != null && Files.isDirectory(Paths.get(modelIdOrPath))) {
                modelPath = Paths.get(modelIdOrPath);
            }
        }
        if (modelIdOrPath == null) {
            modelIdOrPath = modelPath.toString();
        }

        String tpDegree = modelInfo.prop.getProperty("option.tensor_parallel_degree");
        if (tpDegree == null) {
            tpDegree = Utils.getenv("TENSOR_PARALLEL_DEGREE", "max");
        }
        if ("max".equals(tpDegree)) {
            tpDegree = String.valueOf(CudaUtils.getGpuCount());
        }

        if (modelInfo.prop.containsKey("trtllm_python_backend")) {
            modelInfo.downloadDir = buildTrtLlmArtifacts(modelInfo.modelDir, modelIdOrPath, tpDegree);
            return;
        }
        modelInfo.prop.setProperty("option.rolling_batch", "trtllm");
        if (!isValidTrtLlmModelRepo(modelPath)) {
            modelInfo.downloadDir = buildTrtLlmArtifacts(modelInfo.modelDir, modelIdOrPath, tpDegree);
        }
    }

    /**
     * Resolves the location of the model's Hugging Face config file.
     *
     * <ul>
     *   <li>{@code s3://} model id: calls {@code modelInfo.downloadS3()} and looks for
     *       {@code config.json}, then {@code model_index.json}, in {@code modelInfo.downloadDir}.
     *   <li>Other non-null model id: stores it as {@code option.model_id}. If it names a local
     *       directory, the config files are looked up there. Otherwise it is treated as a Hugging
     *       Face Hub repo id: the {@code config.json} URL is returned if it answers HTTP 200, and the
     *       {@code model_index.json} URL is returned otherwise (without probing it).
     *   <li>{@code null} model id: the config files are looked up in {@code modelInfo.modelDir}.
     * </ul>
     *
     * @param modelInfo the model being loaded
     * @param modelId the {@code option.model_id} value, may be {@code null}
     * @return the config URI, or {@code null} if no local config file was found
     * @throws ModelException if downloading from S3 fails
     * @throws IOException if probing the Hugging Face Hub fails
     */
    public static URI generateHuggingFaceConfigUri(ModelInfo<?, ?> modelInfo, String modelId)
            throws ModelException, IOException {
        if (modelId == null) {
            return findLocalConfig(modelInfo.modelDir);
        }
        if (modelId.startsWith("s3://")) {
            modelInfo.downloadS3();
            return findLocalConfig(modelInfo.downloadDir);
        }
        modelInfo.prop.setProperty("option.model_id", modelId);
        Path localDir = Paths.get(modelId);
        if (Files.isDirectory(localDir)) {
            return findLocalConfig(localDir);
        }
        return resolveHubConfigUri(modelId);
    }

    /** Returns the URI of the first config file (see {@link #HF_CONFIG_FILE_NAMES}) in dir, or null. */
    private static URI findLocalConfig(Path dir) {
        for (String name : HF_CONFIG_FILE_NAMES) {
            Path file = dir.resolve(name);
            if (Files.isRegularFile(file)) {
                return file.toUri();
            }
        }
        return null;
    }

    /** Probes the Hub for config.json, falling back to the model_index.json URL on a non-200 reply. */
    private static URI resolveHubConfigUri(String modelId) throws IOException {
        URI configUri = URI.create("https://huggingface.co/" + modelId + "/raw/main/config.json");
        HttpURLConnection conn = (HttpURLConnection) configUri.toURL().openConnection();
        try {
            String token = Utils.getenv("HF_TOKEN");
            if (token != null) {
                conn.setRequestProperty("Authorization", "Bearer " + token);
            }
            if (conn.getResponseCode() == HttpURLConnection.HTTP_OK) {
                return configUri;
            }
        } finally {
            // Release the probe connection; the config is re-fetched separately by the caller.
            conn.disconnect();
        }
        return URI.create("https://huggingface.co/" + modelId + "/raw/main/model_index.json");
    }

    /**
     * Locates and parses the model's Hugging Face config.
     *
     * @return the parsed config, or {@code null} if no config file could be located (or the file
     *     parsed to {@code null}, e.g. because it is empty)
     * @throws ModelException if the config cannot be read or is not valid JSON
     */
    private static HuggingFaceModelConfig getHuggingFaceModelConfig(ModelInfo<?, ?> modelInfo)
            throws ModelException {
        String modelId = modelInfo.prop.getProperty("option.model_id");
        try {
            URI configUri = generateHuggingFaceConfigUri(modelInfo, modelId);
            if (configUri == null) {
                return null;
            }
            URLConnection conn = configUri.toURL().openConnection();
            String token = Utils.getenv("HF_TOKEN");
            if (token != null) {
                conn.setRequestProperty("Authorization", "Bearer " + token);
            }
            try (InputStream is = conn.getInputStream()) {
                return JsonUtils.GSON.fromJson(Utils.toString(is), HuggingFaceModelConfig.class);
            }
        } catch (IOException | JsonSyntaxException e) {
            throw new ModelNotFoundException("Invalid huggingface model id: " + modelId, e);
        }
    }

    /**
     * Runs {@code /opt/djl/partition/trt_llm_partition.py} to build TensorRT-LLM artifacts into a
     * cache directory keyed by a hash of the model path and tensor parallel degree. An existing
     * directory is reused as-is. On any failure, the partially built directory is deleted and the
     * conversion process is terminated.
     *
     * @param propertiesDir directory passed as {@code --properties_dir}
     * @param modelPath model id or path passed as {@code --model_path}
     * @param tpDegree tensor parallel degree passed as {@code --tensor_parallel_degree}
     * @return the artifact directory
     * @throws IOException if the process cannot be started, its output cannot be read, or the
     *     thread is interrupted while waiting for it
     */
    private static Path buildTrtLlmArtifacts(Path propertiesDir, String modelPath, String tpDegree)
            throws IOException {
        logger.info("Converting model to TensorRT-LLM artifacts");
        // NOTE(review): no separator between the two values, so e.g. ("m1", "2") and ("m", "12")
        // hash identically. Kept as-is so existing cache directories remain valid.
        String hash = Utils.hash(modelPath + tpDegree);
        String downloadDir = Utils.getenv("SERVING_DOWNLOAD_DIR", (String) null);
        Path cacheRoot = downloadDir == null ? Utils.getCacheDir() : Paths.get(downloadDir);
        Path repoDir = cacheRoot.resolve("trtllm").resolve(hash);
        if (Files.exists(repoDir)) {
            logger.info("TensorRT-LLM artifacts already converted: {}", repoDir);
            return repoDir;
        }

        Process process = null;
        boolean success = false;
        try {
            process =
                    new ProcessBuilder(
                                    "python",
                                    "/opt/djl/partition/trt_llm_partition.py",
                                    "--properties_dir",
                                    propertiesDir.toAbsolutePath().toString(),
                                    "--trt_llm_model_repo",
                                    repoDir.toString(),
                                    "--tensor_parallel_degree",
                                    tpDegree,
                                    "--model_path",
                                    modelPath)
                            .redirectErrorStream(true)
                            .start();
            try (BufferedReader reader =
                    new BufferedReader(
                            new InputStreamReader(process.getInputStream(), StandardCharsets.UTF_8))) {
                String line;
                while ((line = reader.readLine()) != null) {
                    logger.info("convert_py: {}", line);
                }
            }
            if (process.waitFor() != 0) {
                throw new EngineException("Model conversion process failed!");
            }
            logger.info("TensorRT-LLM artifacts built successfully");
            success = true;
            return repoDir;
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IOException("Failed to build TensorRT-LLM artifacts", e);
        } finally {
            if (!success) {
                // Stop a still-running converter so it cannot recreate the directory we delete.
                if (process != null) {
                    process.destroyForcibly();
                }
                Utils.deleteQuietly(repoDir);
            }
        }
    }

    /**
     * Maps the compute capability of GPU 0, together with the max memory reported for
     * {@code Device.gpu()}, to an AWS GPU machine-type prefix ({@code g4}, {@code p4d},
     * {@code p4de}, {@code g5}, {@code g6}, {@code g6e} or {@code p5}).
     *
     * @return the machine type, or {@code null} if the compute capability is not recognized
     */
    static String getAWSGpuMachineType() {
        String computeCapability = CudaUtils.getComputeCapability(0);
        double memoryGib = CudaUtils.getGpuMemory(Device.gpu()).getMax() / 1024.0 / 1024.0 / 1024.0;
        if ("7.5".equals(computeCapability)) {
            return "g4";
        }
        if ("8.0".equals(computeCapability)) {
            // Same compute capability; the memory threshold separates the two instance types.
            return memoryGib > 45.0 ? "p4de" : "p4d";
        }
        if ("8.6".equals(computeCapability)) {
            return "g5";
        }
        if ("8.9".equals(computeCapability)) {
            return memoryGib > 25.0 ? "g6e" : "g6";
        }
        if ("9.0".equals(computeCapability)) {
            return "p5";
        }
        logger.warn("Could not identify GPU arch {}", computeCapability);
        return null;
    }

    /**
     * Returns whether any immediate sub-directory of {@code path} contains both a
     * {@code config.pbtxt} and a {@code tokenizer_config.json} file. Every matching sub-directory is
     * logged.
     *
     * @param path the directory to inspect
     * @return {@code true} if at least one Triton model directory was found
     * @throws IOException if the directory cannot be listed
     */
    static boolean isValidTrtLlmModelRepo(Path path) throws IOException {
        boolean found = false;
        try (Stream<Path> children = Files.list(path)) {
            for (Path child : (Iterable<Path>) children::iterator) {
                if (Files.isDirectory(child)
                        && Files.isRegularFile(child.resolve("config.pbtxt"))
                        && Files.isRegularFile(child.resolve("tokenizer_config.json"))) {
                    logger.info("Found triton model: {}", child);
                    found = true;
                }
            }
        }
        return found;
    }
}
