package ai.djl.serving.wlm;

import ai.djl.serving.wlm.LmiUtils;
import ai.djl.util.Utils;
import ai.djl.util.cuda.CudaUtils;
import java.util.Map;
import java.util.Objects;
import java.util.Properties;
import java.util.Set;
import java.util.stream.Stream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Recommends default LMI serving properties based on the model's HuggingFace configuration and the
 * features advertised through {@code SERVING_FEATURES}.
 *
 * <p>Most recommendations are only applied when the corresponding property is absent (or, for the
 * rolling-batch implementation, set to {@code auto}). Exceptions: T5 with the {@code trtllm}
 * feature always gets dynamic batching with rolling batch disabled, and any {@code batch_size > 1}
 * forces {@code option.rolling_batch=disable}.
 */
public final class LmiConfigRecommender {

    private static final Logger logger = LoggerFactory.getLogger(LmiConfigRecommender.class);

    /** Rolling-batch implementation chosen per HuggingFace {@code model_type} when set to auto. */
    private static final Map<String, String> MODEL_TO_ROLLING_BATCH =
            Map.ofEntries(
                    Map.entry("falcon", "lmi-dist"),
                    Map.entry("gpt-neox", "lmi-dist"),
                    Map.entry("t5", "lmi-dist"),
                    Map.entry("llama", "lmi-dist"),
                    Map.entry("mpt", "lmi-dist"),
                    Map.entry("gpt-bigcode", "lmi-dist"),
                    Map.entry("aquila", "lmi-dist"),
                    Map.entry("baichuan", "lmi-dist"),
                    Map.entry("bloom", "lmi-dist"),
                    Map.entry("chatglm", "lmi-dist"),
                    Map.entry("deci", "lmi-dist"),
                    Map.entry("gemma", "lmi-dist"),
                    Map.entry("gpt2", "lmi-dist"),
                    Map.entry("gptj", "lmi-dist"),
                    Map.entry("internlm2", "lmi-dist"),
                    Map.entry("mistral", "lmi-dist"),
                    Map.entry("mixtral", "lmi-dist"),
                    Map.entry("opt", "lmi-dist"),
                    Map.entry("phi", "lmi-dist"),
                    Map.entry("qwen", "lmi-dist"),
                    Map.entry("qwen2", "lmi-dist"),
                    Map.entry("stablelm", "lmi-dist"),
                    Map.entry("dbrx", "lmi-dist"),
                    Map.entry("starcoder2", "lmi-dist"));

    /**
     * Architecture-name suffixes that mark a model as text generation, and therefore eligible for
     * rolling batch.
     */
    private static final Set<String> OPTIMIZED_TASK_ARCHITECTURES =
            Set.of("ForCausalLM", "LMHeadModel", "ForConditionalGeneration");

    /** Batch size forced on T5 + trtllm when the user did not configure one. */
    private static final int T5_TRTLLM_DEFAULT_BATCH_SIZE = 32;

    /** Default {@code option.max_rolling_batch_size}. */
    private static final int DEFAULT_MAX_ROLLING_BATCH_SIZE = 32;

    /** {@code option.max_rolling_batch_size} used for vllm / lmi-dist and token-bounded trtllm. */
    private static final int LARGE_MAX_ROLLING_BATCH_SIZE = 256;

    private LmiConfigRecommender() {}

    /**
     * Applies the recommended LMI settings to {@code properties} for the given model.
     *
     * <p>The steps run in a fixed order because each one reads values written by an earlier one.
     * The dynamic-batch step may write {@code batch_size}, which the rolling-batch step reads. The
     * engine step reads the resolved {@code option.rolling_batch}.
     *
     * @param modelInfo the model being configured; its {@code batchSize} may be updated
     * @param properties the serving properties, updated in place
     * @param huggingFaceModelConfig the model's HuggingFace configuration
     * @throws NumberFormatException if {@code batch_size} is present but not an integer
     */
    public static void configure(
            ModelInfo<?, ?> modelInfo,
            Properties properties,
            LmiUtils.HuggingFaceModelConfig huggingFaceModelConfig) {
        String features = Utils.getEnvOrSystemProperty("SERVING_FEATURES");
        setDynamicBatch(properties, huggingFaceModelConfig, modelInfo, features);
        setRollingBatch(properties, huggingFaceModelConfig, features);
        setEngine(properties, huggingFaceModelConfig, features);
        setTensorParallelDegree(properties);
    }

    /**
     * Resolves {@code option.rolling_batch}.
     *
     * <p>Dynamic batching ({@code batch_size > 1}) always disables rolling batch. Otherwise the
     * value is only resolved when absent or {@code auto`}; an explicit user choice is kept.
     */
    private static void setRollingBatch(
            Properties properties, LmiUtils.HuggingFaceModelConfig modelConfig, String features) {
        if (parseIntProperty(properties, "batch_size", "1") > 1) {
            properties.setProperty("option.rolling_batch", "disable");
            return;
        }
        String rollingBatch = properties.getProperty("option.rolling_batch", "auto");
        if (!"auto".equals(rollingBatch)) {
            return;
        }
        if (!isTextGenerationModel(modelConfig)) {
            rollingBatch = "disable";
        } else if (isVLLMEnabled(features) && isLmiDistEnabled(features)) {
            String modelType = modelConfig.getModelType();
            // Map.ofEntries maps are null-hostile: looking up a null key throws NPE, so configs
            // without a model_type keep "auto" instead of crashing.
            if (modelType != null) {
                rollingBatch = MODEL_TO_ROLLING_BATCH.getOrDefault(modelType, "auto");
            }
        } else if (LmiUtils.isTrtLLMRollingBatch(properties)) {
            rollingBatch = "trtllm";
        }
        properties.setProperty("option.rolling_batch", rollingBatch);
    }

    /**
     * Chooses {@code engine} when the user did not set one.
     *
     * <p>lmi-dist and trtllm rolling batch, as well as T5 with the trtllm feature, run on the
     * {@code MPI} engine and also get {@code option.mpi_mode=true}. Everything else uses
     * {@code Python}.
     */
    private static void setEngine(
            Properties properties, LmiUtils.HuggingFaceModelConfig modelConfig, String features) {
        if (properties.containsKey("engine")) {
            return;
        }
        String rollingBatch = properties.getProperty("option.rolling_batch");
        boolean useMpi =
                "lmi-dist".equals(rollingBatch)
                        || "trtllm".equals(rollingBatch)
                        || isT5TrtLLM(modelConfig, features);
        if (useMpi) {
            properties.setProperty("option.mpi_mode", "true");
        }
        properties.setProperty("engine", useMpi ? "MPI" : "Python");
    }

    /**
     * Sets {@code option.tensor_parallel_degree} from the {@code TENSOR_PARALLEL_DEGREE}
     * environment variable when the property is absent. The value {@code max} (also the default)
     * is replaced by the GPU count reported by {@link CudaUtils#getGpuCount()}.
     */
    private static void setTensorParallelDegree(Properties properties) {
        if (properties.containsKey("option.tensor_parallel_degree")) {
            return;
        }
        String degree = Utils.getenv("TENSOR_PARALLEL_DEGREE", "max");
        if ("max".equals(degree)) {
            degree = String.valueOf(CudaUtils.getGpuCount());
        }
        properties.setProperty("option.tensor_parallel_degree", degree);
    }

    /**
     * For T5 with the trtllm feature, switches to the trtllm Python backend with dynamic batching.
     *
     * <p>Rolling batch is disabled. If {@code batch_size} is unset or 0, it defaults to
     * {@value #T5_TRTLLM_DEFAULT_BATCH_SIZE} on both the properties and {@code modelInfo}.
     */
    private static void setDynamicBatch(
            Properties properties,
            LmiUtils.HuggingFaceModelConfig modelConfig,
            ModelInfo<?, ?> modelInfo,
            String features) {
        if (!isT5TrtLLM(modelConfig, features)) {
            return;
        }
        properties.setProperty("trtllm_python_backend", String.valueOf(true));
        properties.setProperty("option.rolling_batch", "disable");
        if (parseIntProperty(properties, "batch_size", "0") == 0) {
            modelInfo.batchSize = T5_TRTLLM_DEFAULT_BATCH_SIZE;
            properties.setProperty("batch_size", String.valueOf(T5_TRTLLM_DEFAULT_BATCH_SIZE));
        }
    }

    /**
     * Sets {@code option.max_rolling_batch_size} when the user did not set it.
     *
     * <p>The default is {@value #DEFAULT_MAX_ROLLING_BATCH_SIZE}. It becomes
     * {@value #LARGE_MAX_ROLLING_BATCH_SIZE} for vllm and lmi-dist. It also becomes
     * {@value #LARGE_MAX_ROLLING_BATCH_SIZE} for trtllm (explicit, or {@code auto} with the trtllm
     * feature) when {@code option.max_num_tokens} is set.
     *
     * @param properties the serving properties, updated in place
     */
    public static void setRollingBatchSize(Properties properties) {
        if (properties.containsKey("option.max_rolling_batch_size")) {
            return;
        }
        String rollingBatch = properties.getProperty("option.rolling_batch");
        int maxRollingBatchSize = DEFAULT_MAX_ROLLING_BATCH_SIZE;
        if ("vllm".equals(rollingBatch) || "lmi-dist".equals(rollingBatch)) {
            maxRollingBatchSize = LARGE_MAX_ROLLING_BATCH_SIZE;
        }
        boolean isTrtLlm =
                "trtllm".equals(rollingBatch)
                        || ("auto".equals(rollingBatch)
                                && isTrtLLMEnabled(
                                        Utils.getEnvOrSystemProperty("SERVING_FEATURES")));
        if (isTrtLlm && properties.containsKey("option.max_num_tokens")) {
            maxRollingBatchSize = LARGE_MAX_ROLLING_BATCH_SIZE;
        }
        properties.setProperty("option.max_rolling_batch_size", String.valueOf(maxRollingBatchSize));
    }

    /**
     * Parses an integer property, ignoring surrounding whitespace.
     *
     * <p>{@link Properties#load} keeps trailing blanks in values, so {@code "4 "} must still parse.
     *
     * @throws NumberFormatException if the trimmed value is not an integer
     */
    private static int parseIntProperty(Properties properties, String key, String defaultValue) {
        return Integer.parseInt(properties.getProperty(key, defaultValue).trim());
    }

    private static boolean isVLLMEnabled(String features) {
        return features != null && features.contains("vllm");
    }

    private static boolean isLmiDistEnabled(String features) {
        return features != null && features.contains("lmi-dist");
    }

    private static boolean isTrtLLMEnabled(String features) {
        return features != null && features.contains("trtllm");
    }

    /** Returns whether the trtllm feature is available and the model type is {@code t5}. */
    private static boolean isT5TrtLLM(
            LmiUtils.HuggingFaceModelConfig modelConfig, String features) {
        return isTrtLLMEnabled(features) && "t5".equals(modelConfig.getModelType());
    }

    /**
     * Returns whether any declared architecture ends with one of
     * {@link #OPTIMIZED_TASK_ARCHITECTURES}. Otherwise it logs a warning and returns false. A
     * missing ({@code null}) architecture list is treated as "no match".
     */
    private static boolean isTextGenerationModel(LmiUtils.HuggingFaceModelConfig modelConfig) {
        if (modelConfig.getArchitectures() != null) {
            for (String arch : modelConfig.getArchitectures()) {
                if (OPTIMIZED_TASK_ARCHITECTURES.stream().anyMatch(arch::endsWith)) {
                    return true;
                }
            }
        }
        logger.warn(
                "The model task architecture {} is not supported for optimized inference. LMI will"
                        + " attempt to load the model using HuggingFace Accelerate. Optimized"
                        + " inference performance is only available for the following task"
                        + " architectures: {}",
                modelConfig.getArchitectures(),
                OPTIMIZED_TASK_ARCHITECTURES);
        return false;
    }
}
