package ai.djl.serving.wlm;

import ai.djl.Device;
import ai.djl.ModelException;
import ai.djl.modality.Input;
import ai.djl.modality.Output;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelNotFoundException;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.serving.http.BadRequestException;
import ai.djl.serving.http.DescribeModelResponse;
import ai.djl.serving.util.ConfigManager;
import java.io.IOException;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
import java.util.concurrent.ConcurrentHashMap;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:ai/djl/serving/wlm/ModelManager.class */
public final class ModelManager {

    private static final Logger logger = LoggerFactory.getLogger(ModelManager.class);

    /** Process-wide instance; {@code null} until {@link #init(ConfigManager)} has been called. */
    private static ModelManager modelManager;

    private final ConfigManager configManager;
    private final WorkLoadManager wlm = new WorkLoadManager();
    /** Registered endpoints keyed by model name; each endpoint holds one or more model versions. */
    private final Map<String, Endpoint> endpoints = new ConcurrentHashMap<>();
    /** Names of models flagged as loaded at startup. NOTE: a plain HashSet, not synchronized. */
    private final Set<String> startupModels = new HashSet<>();

    private ModelManager(ConfigManager configManager) {
        this.configManager = configManager;
    }

    /**
     * Creates the process-wide {@code ModelManager}, replacing any previous instance.
     *
     * @param configManager configuration source (used here for the job queue size)
     */
    public static void init(ConfigManager configManager) {
        modelManager = new ModelManager(configManager);
    }

    /**
     * Returns the instance created by {@link #init(ConfigManager)}.
     *
     * @return the singleton, or {@code null} if {@code init} has not been called
     */
    public static ModelManager getInstance() {
        return modelManager;
    }

    /**
     * Asynchronously loads a model and registers it under its model name.
     *
     * <p>The load runs via {@link CompletableFuture#supplyAsync(java.util.function.Supplier)}
     * without an explicit executor, i.e. on the common fork-join pool.
     *
     * <p>NOTE(review): the meaning of {@code batchSize}, {@code maxBatchDelay} and
     * {@code maxIdleTime} is inferred from the argument positions passed to the
     * {@code ModelInfo} constructor — confirm against that class.
     *
     * @param modelName name of the endpoint the model is registered under
     * @param version version identifier passed through to {@code ModelInfo}
     * @param modelUrl URL(s) the model is loaded from
     * @param engineName engine passed to {@code Criteria.Builder.optEngine}
     * @param gpuId GPU index to load on, or {@code -1} to not set a device explicitly
     * @param batchSize passed to {@code ModelInfo} (presumed batch size)
     * @param maxBatchDelay passed to {@code ModelInfo} (presumed max batch delay)
     * @param maxIdleTime passed to {@code ModelInfo} (presumed max idle time)
     * @return a future completing with the registered {@link ModelInfo}; it completes
     *     exceptionally with a {@link CompletionException} wrapping a {@link ModelException}
     *     or {@link IOException} if loading fails, or with a {@link BadRequestException} if
     *     the endpoint rejects the model as already registered
     */
    public CompletableFuture<ModelInfo> registerModel(
            String modelName,
            String version,
            String modelUrl,
            String engineName,
            int gpuId,
            int batchSize,
            int maxBatchDelay,
            int maxIdleTime) {
        return CompletableFuture.supplyAsync(
                () -> {
                    try {
                        Criteria.Builder<Input, Output> builder =
                                Criteria.builder()
                                        .setTypes(Input.class, Output.class)
                                        .optModelUrls(modelUrl)
                                        .optEngine(engineName);
                        if (gpuId != -1) {
                            builder.optDevice(Device.gpu(gpuId));
                        }
                        ZooModel<Input, Output> model = builder.build().loadModel();
                        ModelInfo modelInfo =
                                new ModelInfo(
                                        modelName,
                                        version,
                                        modelUrl,
                                        model,
                                        configManager.getJobQueueSize(),
                                        maxIdleTime,
                                        maxBatchDelay,
                                        batchSize);
                        Endpoint endpoint =
                                endpoints.computeIfAbsent(modelName, k -> new Endpoint());
                        if (endpoint.add(modelInfo)) {
                            logger.info("Model {} loaded.", modelName);
                            return modelInfo;
                        }
                        // The endpoint refused the model: release what was just loaded.
                        model.close();
                        throw new BadRequestException(
                                "Model " + modelInfo + " is already registered.");
                    } catch (ModelException | IOException e) {
                        throw new CompletionException(e);
                    }
                });
    }

    /**
     * Unregisters one version of a model, or every version when {@code version} is null.
     *
     * <p>Each removed model has its workers scaled to zero, the workload manager notified,
     * and is then closed. The endpoint itself is dropped from the registry once it no longer
     * holds any models.
     *
     * @param modelName endpoint name
     * @param version version to remove, or {@code null} to remove all versions
     * @return {@code true} if something was unregistered, {@code false} if not found
     */
    public boolean unregisterModel(String modelName, String version) {
        Endpoint endpoint = endpoints.get(modelName);
        if (endpoint == null) {
            logger.warn("Model not found: {}", modelName);
            return false;
        }
        if (version == null) {
            startupModels.remove(modelName);
            for (ModelInfo model : endpoint.getModels()) {
                stopModel(model);
            }
            // Every version is now closed; previously the closed models stayed in the endpoint
            // and the endpoint was never removed, so lookups kept returning dead models.
            endpoints.remove(modelName);
            logger.info("Model {} unregistered.", modelName);
            return true;
        }
        ModelInfo removed = endpoint.remove(version);
        if (removed == null) {
            logger.warn("Model not found: {}:{}", modelName, version);
            return false;
        }
        startupModels.remove(modelName);
        stopModel(removed);
        if (endpoint.getModels().isEmpty()) {
            endpoints.remove(modelName);
        }
        return true;
    }

    /** Scales a model's workers to zero, notifies the workload manager, then closes it. */
    private void stopModel(ModelInfo model) {
        model.scaleWorkers(0, 0);
        wlm.modelChanged(model);
        model.close();
    }

    /**
     * Notifies the workload manager that a model's settings changed.
     *
     * @param modelInfo the updated model
     * @return the same {@code modelInfo}, for chaining
     */
    public ModelInfo triggerModelUpdated(ModelInfo modelInfo) {
        logger.debug("updateModel: {}", modelInfo.getModelName());
        wlm.modelChanged(modelInfo);
        return modelInfo;
    }

    /**
     * Returns the live registry of endpoints (not a copy).
     *
     * @return map of model name to endpoint
     */
    public Map<String, Endpoint> getEndpoints() {
        return endpoints;
    }

    /**
     * Looks up a registered model.
     *
     * @param modelName endpoint name
     * @param version specific version to fetch, or {@code null} for any version
     * @param useNext when {@code version} is null: if true, delegate to {@code Endpoint.next()}
     *     (presumably rotates across versions — confirm in Endpoint); otherwise return the
     *     first model in the endpoint's list
     * @return the matching model, or {@code null} if none
     */
    public ModelInfo getModel(String modelName, String version, boolean useNext) {
        Endpoint endpoint = endpoints.get(modelName);
        if (endpoint == null) {
            return null;
        }
        if (version != null) {
            return endpoint.get(version);
        }
        if (endpoint.getModels().isEmpty()) {
            return null;
        }
        return useNext ? endpoint.next() : endpoint.getModels().get(0);
    }

    /**
     * Returns the live set of model names flagged as loaded at startup (not a copy).
     *
     * @return mutable startup-model set
     */
    public Set<String> getStartupModels() {
        return startupModels;
    }

    /**
     * Submits a job to the workload manager.
     *
     * @param job the job to queue
     * @return the result of {@code WorkLoadManager.addJob}
     * @throws ModelNotFoundException as propagated from the workload manager
     */
    public boolean addJob(Job job) throws ModelNotFoundException {
        return wlm.addJob(job);
    }

    /**
     * Builds a description of a model's configuration and its workers.
     *
     * <p>Status is {@code "Healthy"} when the number of running workers is at least the
     * model's minimum worker count, otherwise {@code "Unhealthy"}.
     *
     * @param modelName endpoint name
     * @param version version to describe, or {@code null} for the endpoint's first model
     * @return the populated response
     * @throws ModelNotFoundException if no matching model is registered
     */
    public DescribeModelResponse describeModel(String modelName, String version)
            throws ModelNotFoundException {
        ModelInfo model = getModel(modelName, version, false);
        if (model == null) {
            throw new ModelNotFoundException("Model not found: " + modelName);
        }
        DescribeModelResponse resp = new DescribeModelResponse();
        resp.setModelName(modelName);
        resp.setModelUrl(model.getModelUrl());
        resp.setBatchSize(model.getBatchSize());
        resp.setMaxBatchDelay(model.getMaxBatchDelay());
        resp.setMaxWorkers(model.getMaxWorkers());
        resp.setMinWorkers(model.getMinWorkers());
        resp.setMaxIdleTime(model.getMaxIdleTime());
        resp.setLoadedAtStartup(startupModels.contains(modelName));
        boolean healthy = wlm.getNumRunningWorkers(model) >= model.getMinWorkers();
        resp.setStatus(healthy ? "Healthy" : "Unhealthy");
        for (WorkerThread worker : wlm.getWorkers(model)) {
            resp.addWorker(
                    worker.getWorkerId(),
                    worker.getStartTime(),
                    worker.isRunning(),
                    worker.getGpuId());
        }
        return resp;
    }

    /**
     * Asynchronously summarizes worker health across all registered models.
     *
     * <p>Sums each model's minimum worker count and running worker count, then reports
     * {@code "Partial Healthy"} if some but fewer than required are running,
     * {@code "Unhealthy"} if none are running while some are required, and
     * {@code "Healthy"} otherwise.
     *
     * @return a future completing with the status string
     */
    public CompletableFuture<String> workerStatus() {
        return CompletableFuture.supplyAsync(
                () -> {
                    int running = 0;
                    int required = 0;
                    for (Endpoint endpoint : endpoints.values()) {
                        for (ModelInfo model : endpoint.getModels()) {
                            required += model.getMinWorkers();
                            running += wlm.getNumRunningWorkers(model);
                        }
                    }
                    if (running > 0 && running < required) {
                        return "Partial Healthy";
                    }
                    if (running == 0 && required > 0) {
                        return "Unhealthy";
                    }
                    return "Healthy";
                });
    }
}
