package ai.djl.serving.wlm;

import ai.djl.serving.wlm.WorkerPoolConfig;
import ai.djl.serving.wlm.util.WlmCapacityException;
import ai.djl.serving.wlm.util.WlmException;
import ai.djl.serving.wlm.util.WlmShutdownException;
import ai.djl.serving.wlm.util.WorkerJob;
import java.util.Iterator;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.LinkedBlockingDeque;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:ai/djl/serving/wlm/WorkLoadManager.class */
/**
 * Keeps one {@link WorkerPool} per {@link WorkerPoolConfig} and routes jobs to those pools.
 *
 * <p>Pools are stored in a {@link ConcurrentHashMap}; all pools are created with the same cached,
 * daemon-thread {@link ExecutorService}.
 */
public class WorkLoadManager {
    private static final Logger logger = LoggerFactory.getLogger(WorkLoadManager.class);

    /** Executor handed to every worker pool; its threads are daemons so they never block JVM exit. */
    private final ExecutorService threadPool = Executors.newCachedThreadPool(runnable -> {
        Thread newThread = Executors.defaultThreadFactory().newThread(runnable);
        newThread.setDaemon(true);
        return newThread;
    });

    /** Registered worker pools, keyed by their configuration. */
    private final ConcurrentHashMap<WorkerPoolConfig<?, ?>, WorkerPool<?, ?>> workerPools = new ConcurrentHashMap<>();

    /**
     * Returns the worker pool registered for the given configuration, creating it if absent.
     *
     * @param workerPoolConfig the configuration to register
     * @param <I> the job input type
     * @param <O> the job output type
     * @return the existing pool for the configuration, or a newly created one
     */
    @SuppressWarnings("unchecked") // the pool stored under a config is always created from that config
    public <I, O> WorkerPool<I, O> registerWorkerPool(WorkerPoolConfig<I, O> workerPoolConfig) {
        return (WorkerPool<I, O>) this.workerPools.computeIfAbsent(
                workerPoolConfig, key -> new WorkerPool<>(workerPoolConfig, this.threadPool));
    }

    /**
     * Decreases the reference count of the pool registered for the given configuration and, once
     * the count drops to zero or below, shuts its workers down and removes the pool.
     *
     * <p>Does nothing (besides logging a warning) when no pool is registered for the configuration.
     *
     * @param workerPoolConfig the configuration whose pool should be released
     */
    public void unregisterWorkerPool(WorkerPoolConfig<?, ?> workerPoolConfig) {
        WorkerPool<?, ?> workerPool = this.workerPools.get(workerPoolConfig);
        if (workerPool == null) {
            // Never registered, or already removed by a previous call: nothing to release.
            logger.warn("No worker pool registered for model: {}", workerPoolConfig);
            return;
        }
        if (workerPool.decreaseRef() <= 0) {
            logger.info("Unloading model: {}", workerPoolConfig);
            workerPool.shutdownWorkers();
            this.workerPools.remove(workerPoolConfig);
        }
    }

    /**
     * Queues a job on the worker pool that belongs to the job's configuration.
     *
     * <p>This method reports every failure through the returned future instead of throwing:
     *
     * <ul>
     *   <li>{@link WlmException} when the configuration is not in {@code READY} status or has no
     *       registered pool,
     *   <li>{@link WlmShutdownException} when the pool reports a maximum of zero workers,
     *   <li>{@link WlmCapacityException} when the job cannot be queued.
     * </ul>
     *
     * <p>On success the future is handed to the queued {@link WorkerJob}; presumably the worker that
     * picks the job up completes it.
     *
     * @param job the job to run
     * @param <I> the job input type
     * @param <O> the job output type
     * @return a future for the job's result
     */
    public <I, O> CompletableFuture<O> runJob(Job<I, O> job) {
        CompletableFuture<O> completableFuture = new CompletableFuture<>();
        WorkerPoolConfig<I, O> wpc = job.getWpc();
        if (wpc.getStatus() != WorkerPoolConfig.Status.READY) {
            completableFuture.completeExceptionally(new WlmException("Model is not ready: " + wpc.getId()));
            return completableFuture;
        }
        WorkerPool<I, O> workerPool = getWorkerPool(wpc);
        if (workerPool == null) {
            // The pool was never registered or was unregistered concurrently; fail the future
            // like every other error path rather than throwing a NullPointerException.
            completableFuture.completeExceptionally(new WlmException("Model is not registered: " + wpc.getId()));
            return completableFuture;
        }
        int maxWorkers = workerPool.getMaxWorkers();
        if (maxWorkers == 0) {
            completableFuture.completeExceptionally(new WlmShutdownException("All model workers has been shutdown: " + wpc.getId()));
            return completableFuture;
        }
        LinkedBlockingDeque<WorkerJob<I, O>> jobQueue = workerPool.getJobQueue();
        // Evaluation order matters: offer() is only attempted when the two rejection checks pass.
        boolean lastSlotWhileBusy = jobQueue.remainingCapacity() == 1 && workerPool.isAllWorkerBusy();
        if (lastSlotWhileBusy || workerPool.isAllWorkerDied() || !jobQueue.offer(new WorkerJob<>(job, completableFuture))) {
            completableFuture.completeExceptionally(new WlmCapacityException("Worker queue capacity exceeded for model: " + wpc.getId()));
            // Still try to add a worker so that later requests have a chance to be served.
            scaleUp(workerPool, wpc, maxWorkers);
            return completableFuture;
        }
        int numRunningWorkers = getNumRunningWorkers(wpc);
        boolean backlogged = numRunningWorkers < maxWorkers && jobQueue.size() > wpc.getBatchSize() * 2;
        if (numRunningWorkers == 0 || backlogged) {
            scaleUp(workerPool, wpc, maxWorkers);
        }
        return completableFuture;
    }

    /**
     * Asks the pool to add threads if fewer than {@code maxWorkers} workers are currently running.
     *
     * <p>The count is re-read while holding the pool's monitor so that concurrent callers do not
     * both decide to scale up from the same stale count.
     */
    private <I, O> void scaleUp(WorkerPool<I, O> workerPool, WorkerPoolConfig<I, O> workerPoolConfig, int maxWorkers) {
        synchronized (workerPool) {
            int numRunningWorkers = getNumRunningWorkers(workerPoolConfig);
            if (numRunningWorkers < maxWorkers) {
                logger.info("Scaling up workers for model {} to {} ", workerPoolConfig, numRunningWorkers + 1);
                workerPool.addThreads();
            }
        }
    }

    /**
     * Counts the workers of the given configuration's pool that are not stopped, errored, or
     * scaled down. The pool's {@code cleanup()} is invoked before counting.
     *
     * @param workerPoolConfig the configuration whose workers should be counted
     * @return the number of running workers, or {@code 0} when no pool is registered
     */
    public int getNumRunningWorkers(WorkerPoolConfig<?, ?> workerPoolConfig) {
        WorkerPool<?, ?> workerPool = this.workerPools.get(workerPoolConfig);
        if (workerPool == null) {
            return 0;
        }
        workerPool.cleanup();
        int running = 0;
        for (WorkerThread<?, ?> workerThread : workerPool.getWorkers()) {
            // Read the state once so all three comparisons see the same value.
            WorkerState state = workerThread.getState();
            if (state != WorkerState.WORKER_STOPPED
                    && state != WorkerState.WORKER_ERROR
                    && state != WorkerState.WORKER_SCALED_DOWN) {
                running++;
            }
        }
        return running;
    }

    /**
     * Looks up a worker pool by the id of its configuration.
     *
     * @param str the configuration id to search for; must not be {@code null}
     * @param <I> the job input type
     * @param <O> the job output type
     * @return the first pool whose configuration id equals {@code str}, or {@code null} if none
     */
    @SuppressWarnings("unchecked") // caller chooses I/O; the map cannot express the per-entry types
    public <I, O> WorkerPool<I, O> getWorkerPoolById(String str) {
        for (Map.Entry<WorkerPoolConfig<?, ?>, WorkerPool<?, ?>> entry : this.workerPools.entrySet()) {
            if (str.equals(entry.getKey().getId())) {
                return (WorkerPool<I, O>) entry.getValue();
            }
        }
        return null;
    }

    /**
     * Returns the worker pool registered for the given configuration.
     *
     * @param workerPoolConfig the configuration to look up
     * @param <I> the job input type
     * @param <O> the job output type
     * @return the registered pool, or {@code null} if the configuration has none
     */
    @SuppressWarnings("unchecked") // the pool stored under a config is always created from that config
    public <I, O> WorkerPool<I, O> getWorkerPool(WorkerPoolConfig<I, O> workerPoolConfig) {
        return (WorkerPool<I, O>) this.workerPools.get(workerPoolConfig);
    }

    /**
     * Shuts down the shared thread pool and every registered worker pool, then forgets all pools.
     */
    public void close() {
        this.threadPool.shutdownNow();
        for (WorkerPool<?, ?> workerPool : this.workerPools.values()) {
            workerPool.shutdown();
        }
        this.workerPools.clear();
    }
}
