package ai.djl.serving.wlm;

import ai.djl.ModelException;
import ai.djl.modality.Input;
import ai.djl.modality.Output;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelNotFoundException;
import ai.djl.repository.zoo.ModelZoo;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.serving.http.BadRequestException;
import ai.djl.serving.http.DescribeModelResponse;
import ai.djl.serving.util.ConfigManager;
import java.io.IOException;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
import java.util.concurrent.ConcurrentHashMap;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Registry of the models served by this process and the entry point for routing jobs to them.
 *
 * <p>Registered models are kept in a name-keyed {@link ConcurrentHashMap}; worker management is
 * delegated to a {@link WorkLoadManager} built from the same {@link ConfigManager}. The shared
 * instance is created by {@link #init(ConfigManager)} and obtained via {@link #getInstance()}.
 */
public final class ModelManager {

    private static final Logger logger = LoggerFactory.getLogger(ModelManager.class);

    private static ModelManager modelManager;

    private final ConfigManager configManager;
    private final WorkLoadManager wlm;
    private final ConcurrentHashMap<String, ModelInfo> models = new ConcurrentHashMap<>();
    // Concurrent set, consistent with `models`: it is mutated by unregisterModel and read by
    // describeModel, which may be invoked from different request threads.
    private final Set<String> startupModels = ConcurrentHashMap.newKeySet();

    private ModelManager(ConfigManager configManager) {
        this.configManager = configManager;
        this.wlm = new WorkLoadManager(configManager);
    }

    /**
     * Creates the shared {@code ModelManager} instance, replacing any previous one.
     *
     * @param configManager the configuration used for job queue sizing and the workload manager
     */
    public static void init(ConfigManager configManager) {
        modelManager = new ModelManager(configManager);
    }

    /**
     * Returns the shared instance created by {@link #init(ConfigManager)}.
     *
     * @return the shared instance, or {@code null} if {@code init} has not been called
     */
    public static ModelManager getInstance() {
        return modelManager;
    }

    /**
     * Asynchronously loads a model and registers it under {@code modelName}.
     *
     * <p>Loading runs via {@link CompletableFuture#supplyAsync(java.util.function.Supplier)}
     * without an explicit executor, i.e. on the common fork-join pool.
     *
     * @param modelName the name to register the model under
     * @param modelUrl the URL handed to the DJL {@code Criteria} builder ({@code optModelUrls})
     * @param batchSize forwarded to {@link ModelInfo} (last constructor argument)
     * @param maxBatchDelay forwarded to {@link ModelInfo} (second-to-last constructor argument)
     * @param maxIdleTime forwarded to {@link ModelInfo} (fifth constructor argument)
     * @return a future that completes with the registered {@link ModelInfo}, or completes
     *     exceptionally with a {@link BadRequestException} if the name is already registered, or
     *     with a {@link CompletionException} wrapping the {@link ModelException} / {@link
     *     IOException} raised while loading
     */
    public CompletableFuture<ModelInfo> registerModel(
            String modelName, String modelUrl, int batchSize, int maxBatchDelay, int maxIdleTime) {
        return CompletableFuture.supplyAsync(
                () -> {
                    // Fast path: avoid loading (possibly large) model artifacts only to discard
                    // them. putIfAbsent below still guards against concurrent registrations.
                    if (models.containsKey(modelName)) {
                        throw new BadRequestException(
                                "Model " + modelName + " is already registered.");
                    }
                    try {
                        Criteria<Input, Output> criteria =
                                Criteria.builder()
                                        .setTypes(Input.class, Output.class)
                                        .optModelUrls(modelUrl)
                                        .build();
                        ZooModel<Input, Output> model = ModelZoo.loadModel(criteria);
                        ModelInfo modelInfo =
                                new ModelInfo(
                                        modelName,
                                        modelUrl,
                                        model,
                                        configManager.getJobQueueSize(),
                                        maxIdleTime,
                                        maxBatchDelay,
                                        batchSize);
                        if (models.putIfAbsent(modelName, modelInfo) != null) {
                            // Lost a registration race: release the freshly loaded model.
                            model.close();
                            throw new BadRequestException(
                                    "Model " + modelName + " is already registered.");
                        }
                        logger.info("Model {} loaded.", modelInfo.getModelName());
                        return modelInfo;
                    } catch (ModelException | IOException e) {
                        throw new CompletionException(e);
                    }
                });
    }

    /**
     * Removes a model from the registry, scales its workers to zero and closes it.
     *
     * @param modelName the registered model name
     * @return {@code true} if the model was registered and has been removed, {@code false} if no
     *     model with that name was found
     */
    public boolean unregisterModel(String modelName) {
        ModelInfo removed = models.remove(modelName);
        if (removed == null) {
            logger.warn("Model not found: {}", modelName);
            return false;
        }
        // Notify the workload manager of the zero-worker configuration before closing the model.
        ModelInfo scaledDown = removed.scaleWorkers(0, 0);
        wlm.modelChanged(scaledDown);
        startupModels.remove(modelName);
        scaledDown.close();
        logger.info("Model {} unregistered.", modelName);
        return true;
    }

    /**
     * Replaces the stored {@link ModelInfo} for an already-registered model and notifies the
     * workload manager.
     *
     * @param modelInfo the updated model information
     * @throws AssertionError if no model with {@code modelInfo.getModelName()} is registered
     */
    public void triggerModelUpdated(ModelInfo modelInfo) {
        String modelName = modelInfo.getModelName();
        if (!models.containsKey(modelName)) {
            throw new AssertionError("Model not found: " + modelName);
        }
        logger.debug("updateModel: {}", modelName);
        models.put(modelName, modelInfo);
        wlm.modelChanged(modelInfo);
    }

    /**
     * Returns the live (not copied) map of registered models keyed by name.
     *
     * @return the registered models
     */
    public Map<String, ModelInfo> getModels() {
        return models;
    }

    /**
     * Returns the live, thread-safe set of names of models flagged as loaded at startup.
     *
     * @return the startup model names
     */
    public Set<String> getStartupModels() {
        return startupModels;
    }

    /**
     * Routes a job to the workload manager for the job's model.
     *
     * @param job the job to enqueue
     * @return the result of {@link WorkLoadManager#addJob}
     * @throws ModelNotFoundException if the job's model is not registered
     */
    public boolean addJob(Job job) throws ModelNotFoundException {
        String modelName = job.getModelName();
        ModelInfo modelInfo = models.get(modelName);
        if (modelInfo == null) {
            throw new ModelNotFoundException("Model not found: " + modelName);
        }
        return wlm.addJob(modelInfo, job);
    }

    /**
     * Builds a description of a registered model, including its configuration and workers.
     *
     * <p>The status is {@code "Healthy"} when the number of running workers is at least the
     * model's minimum worker count, otherwise {@code "Unhealthy"}.
     *
     * @param modelName the registered model name
     * @return the populated response
     * @throws ModelNotFoundException if the model is not registered
     */
    public DescribeModelResponse describeModel(String modelName) throws ModelNotFoundException {
        ModelInfo modelInfo = models.get(modelName);
        if (modelInfo == null) {
            throw new ModelNotFoundException("Model not found: " + modelName);
        }
        DescribeModelResponse response = new DescribeModelResponse();
        response.setModelName(modelName);
        response.setModelUrl(modelInfo.getModelUrl());
        response.setBatchSize(modelInfo.getBatchSize());
        response.setMaxBatchDelay(modelInfo.getMaxBatchDelay());
        response.setMaxWorkers(modelInfo.getMaxWorkers());
        response.setMinWorkers(modelInfo.getMinWorkers());
        response.setMaxIdleTime(modelInfo.getMaxIdleTime());
        response.setLoadedAtStartup(startupModels.contains(modelName));
        boolean healthy = wlm.getNumRunningWorkers(modelName) >= modelInfo.getMinWorkers();
        response.setStatus(healthy ? "Healthy" : "Unhealthy");
        for (WorkerThread worker : wlm.getWorkers(modelName)) {
            response.addWorker(
                    worker.getWorkerId(),
                    worker.getStartTime(),
                    worker.isRunning(),
                    worker.getGpuId());
        }
        return response;
    }

    /**
     * Asynchronously summarizes worker health across all registered models.
     *
     * <p>Sums running workers and required minimum workers over every model, then reports
     * {@code "Partial Healthy"} when some but fewer than required are running, {@code
     * "Unhealthy"} when none are running but some are required, and {@code "Healthy"} otherwise.
     *
     * @return a future completing with the aggregate status string
     */
    public CompletableFuture<String> workerStatus() {
        return CompletableFuture.supplyAsync(
                () -> {
                    int running = 0;
                    int required = 0;
                    for (ModelInfo info : models.values()) {
                        required += info.getMinWorkers();
                        running += wlm.getNumRunningWorkers(info.getModelName());
                    }
                    if (running > 0 && running < required) {
                        return "Partial Healthy";
                    }
                    if (running == 0 && required > 0) {
                        return "Unhealthy";
                    }
                    return "Healthy";
                });
    }
}
