package ai.djl.serving.wlm;

import ai.djl.serving.util.ConfigManager;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.LinkedBlockingDeque;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* JADX INFO: Access modifiers changed from: package-private */
/* loaded from: input_file:ai/djl/serving/wlm/WorkLoadManager.class */
/**
 * Keeps one {@link WorkerPool} per model name, routes jobs into the pool's queue and starts or
 * stops {@link WorkerThread}s when a model's configuration changes or its queue is full.
 *
 * <p>Scaling operations for the same model name are serialized through a dedicated per-model lock
 * object; see {@link #addJob(ModelInfo, Job)} and {@link #modelChanged(ModelInfo)}.
 */
public class WorkLoadManager {
    private static final Logger logger = LoggerFactory.getLogger(WorkLoadManager.class);

    private final GpuAssignmentStrategy gpuAssignmentStrategy;
    private final ExecutorService threadPool = Executors.newCachedThreadPool();
    private final ConcurrentHashMap<String, WorkerPool> workerPools = new ConcurrentHashMap<>();

    /**
     * One lock object per model name. Dedicated objects are used instead of the name String
     * itself, because equal names held in different String instances would not exclude each
     * other and interned Strings could share a monitor with unrelated code. Entries are never
     * removed, so the map holds one small object per model name ever seen.
     */
    private final ConcurrentHashMap<String, Object> modelLocks = new ConcurrentHashMap<>();

    /** The workers and the job queue that belong to a single model. */
    public static final class WorkerPool {
        private final List<WorkerThread> workers = Collections.synchronizedList(new ArrayList<>());
        private final LinkedBlockingDeque<Job> jobQueue;
        private final String modelName;

        /**
         * Creates an empty pool whose job queue is bounded by the model's queue size.
         *
         * @param modelInfo the model this pool serves; supplies the queue size and model name
         */
        public WorkerPool(ModelInfo modelInfo) {
            this.jobQueue = new LinkedBlockingDeque<>(modelInfo.getQueueSize());
            this.modelName = modelInfo.getModelName();
        }

        /**
         * Returns the live, synchronized list of workers of this pool.
         *
         * @return the backing list itself (not a copy); callers iterating it must synchronize on
         *     the returned list
         */
        public List<WorkerThread> getWorkers() {
            return this.workers;
        }

        /**
         * Returns the queue jobs for this model are offered to.
         *
         * @return the bounded job queue of this pool
         */
        public LinkedBlockingDeque<Job> getJobQueue() {
            return this.jobQueue;
        }

        /** Writes one line per worker (id plus fixed/temporary marker) to the debug log. */
        public void log() {
            if (!logger.isDebugEnabled()) {
                return;
            }
            StringBuilder description = new StringBuilder();
            this.workers.forEach(
                    worker ->
                            description
                                    .append(worker.getWorkerId())
                                    .append(worker.isFixPoolThread() ? "-fixedPool\n" : "-tmpPool\n"));
            logger.debug("worker pool for model {}:\n {}", this.modelName, description.toString());
        }

        /** Drops every worker whose state is {@code WORKER_STOPPED} or {@code WORKER_ERROR}. */
        public void cleanup() {
            this.workers.removeIf(
                    worker -> {
                        WorkerState state = worker.getState();
                        return state == WorkerState.WORKER_STOPPED || state == WorkerState.WORKER_ERROR;
                    });
        }
    }

    /**
     * Creates a manager that assigns GPUs with a {@link RoundRobinGpuAssignmentStrategy}.
     *
     * @param configManager configuration handed to the GPU assignment strategy
     */
    public WorkLoadManager(ConfigManager configManager) {
        this.gpuAssignmentStrategy = new RoundRobinGpuAssignmentStrategy(configManager);
    }

    /**
     * Returns the workers registered for a model.
     *
     * @param str the model name
     * @return the live worker list of the model's pool, or an empty list when no pool exists
     */
    public List<WorkerThread> getWorkers(String str) {
        WorkerPool pool = this.workerPools.get(str);
        if (pool == null) {
            return Collections.<WorkerThread>emptyList();
        }
        return pool.getWorkers();
    }

    /**
     * Offers a job to the model's queue, trying to start one additional worker when the queue is
     * full.
     *
     * <p>If the first non-blocking offer fails, one scale-up attempt is made and the job is offered
     * again, waiting up to the model's max batch delay (in milliseconds).
     *
     * @param modelInfo the model the job is for
     * @param job the job to enqueue
     * @return {@code true} if the job was enqueued; {@code false} if the model has no running
     *     worker, the queue stayed full, or the calling thread was interrupted while waiting
     */
    public boolean addJob(ModelInfo modelInfo, Job job) {
        WorkerPool pool = getWorkerPoolForModel(modelInfo);
        if (getNumRunningWorkers(modelInfo.getModelName()) <= 0) {
            return false;
        }
        LinkedBlockingDeque<Job> queue = pool.getJobQueue();
        if (queue.offer(job)) {
            return true;
        }
        try {
            synchronized (lockFor(modelInfo.getModelName())) {
                scaleUpWorkers(modelInfo, pool);
                return queue.offer(job, modelInfo.getMaxBatchDelay(), TimeUnit.MILLISECONDS);
            }
        } catch (InterruptedException e) {
            // Restore the flag so callers further up the stack can still observe the interrupt.
            Thread.currentThread().interrupt();
            logger.info("Worker Queue Capacity Exceeded. cannot add to worker queue in appropriate time. You can configure max batch delay time for this model.");
            return false;
        }
    }

    /**
     * Starts one additional non-fixed worker unless the model already runs its maximum number of
     * workers.
     *
     * @param modelInfo the model to scale; supplies the worker limit
     * @param workerPool the pool the new worker is added to
     */
    private void scaleUpWorkers(ModelInfo modelInfo, WorkerPool workerPool) {
        int running = getNumRunningWorkers(modelInfo.getModelName());
        int maxWorkers = modelInfo.getMaxWorkers();
        if (running >= maxWorkers) {
            logger.warn("scale up capacity of {} workers reached. Unable to scale up worker pool.", maxWorkers);
            return;
        }
        logger.debug("scaling up workers for model {} to {} ", modelInfo, running + 1);
        addThreads(workerPool, modelInfo, 1, false);
    }

    /**
     * Counts the workers of a model that are neither stopped, failed nor scaled down. Stopped and
     * failed workers are removed from the pool as a side effect.
     *
     * @param str the model name
     * @return the number of running workers, or 0 when the model has no pool
     */
    public int getNumRunningWorkers(String str) {
        WorkerPool pool = this.workerPools.get(str);
        if (pool == null) {
            return 0;
        }
        pool.cleanup();
        int running = 0;
        // Iterate a snapshot: other threads may add or remove workers concurrently.
        for (WorkerThread worker : snapshot(pool.getWorkers())) {
            WorkerState state = worker.getState();
            if (state != WorkerState.WORKER_STOPPED
                    && state != WorkerState.WORKER_ERROR
                    && state != WorkerState.WORKER_SCALED_DOWN) {
                running++;
            }
        }
        return running;
    }

    /**
     * Adjusts the number of fixed-pool workers of a model to its configured minimum.
     *
     * <p>Missing fixed workers are started; surplus fixed workers are removed from the pool and
     * shut down with state {@code WORKER_SCALED_DOWN}. When the minimum is 0 the pool is also
     * unregistered from this manager.
     *
     * @param modelInfo the model whose configuration changed
     */
    public void modelChanged(ModelInfo modelInfo) {
        synchronized (lockFor(modelInfo.getModelName())) {
            int minWorkers = modelInfo.getMinWorkers();
            // computeIfAbsent never yields null here, so no null check is needed.
            WorkerPool pool = getWorkerPoolForModel(modelInfo);
            pool.cleanup();
            if (minWorkers == 0) {
                this.workerPools.remove(modelInfo.getModelName());
            }
            List<WorkerThread> workers = pool.getWorkers();
            List<WorkerThread> fixedWorkers =
                    snapshot(workers).stream()
                            .filter(WorkerThread::isFixPoolThread)
                            .collect(Collectors.toList());
            int fixedCount = fixedWorkers.size();
            if (fixedCount < minWorkers) {
                addThreads(pool, modelInfo, minWorkers - fixedCount, true);
            } else {
                for (WorkerThread surplus : fixedWorkers.subList(minWorkers, fixedCount)) {
                    workers.remove(surplus);
                    surplus.shutdown(WorkerState.WORKER_SCALED_DOWN);
                }
            }
            pool.log();
        }
    }

    /** Returns the pool registered for the model's name, creating and registering it if absent. */
    private WorkerPool getWorkerPoolForModel(ModelInfo modelInfo) {
        return this.workerPools.computeIfAbsent(
                modelInfo.getModelName(), name -> new WorkerPool(modelInfo));
    }

    /** Returns the lock object guarding scaling operations for the given model name. */
    private Object lockFor(String modelName) {
        return this.modelLocks.computeIfAbsent(modelName, name -> new Object());
    }

    /** Copies a synchronized worker list while holding its monitor, so the copy can be iterated freely. */
    private static List<WorkerThread> snapshot(List<WorkerThread> workers) {
        synchronized (workers) {
            return new ArrayList<>(workers);
        }
    }

    /**
     * Builds {@code count} workers for the model, registers them in the pool and submits them to
     * the thread pool.
     *
     * <p>The workers consume the job queue of the very pool they are added to, so a worker can
     * never be listed in one pool while reading from another pool's queue.
     *
     * @param pool the pool receiving the new workers
     * @param modelInfo the model the workers serve
     * @param count how many workers to start
     * @param fixedPool value passed to the builder's fix-pool-thread option
     */
    private void addThreads(WorkerPool pool, ModelInfo modelInfo, int count, boolean fixedPool) {
        for (int started = 0; started < count; started++) {
            WorkerThread worker =
                    WorkerThread.builder()
                            .setModel(modelInfo)
                            .setJobQueue(pool.getJobQueue())
                            .optGpuAssignmentStrategy(this.gpuAssignmentStrategy)
                            .optFixPoolThread(fixedPool)
                            .build();
            pool.getWorkers().add(worker);
            this.threadPool.submit(worker);
        }
    }
}
