package ai.djl.serving.wlm;

import ai.djl.Device;
import ai.djl.ModelException;
import ai.djl.serving.wlm.ModelInfo;
import ai.djl.serving.wlm.util.WorkerJob;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletionException;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.LinkedBlockingDeque;
import java.util.stream.Collectors;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * A pool of workers serving a single {@link ModelInfo}.
 *
 * <p>Workers are organized into one {@link WorkerGroup} per {@link Device}. Jobs are queued in a
 * bounded {@link LinkedBlockingDeque} whose capacity is taken from the model's queue size when the
 * pool is created.
 *
 * <p>Decompiled from {@code ai/djl/serving/wlm/WorkerPool.class}.
 *
 * @param <I> the model input type
 * @param <O> the model output type
 */
public class WorkerPool<I, O> {
    private static final Logger logger = LoggerFactory.getLogger(WorkerPool.class);

    private final ModelInfo<I, O> model;
    private final ExecutorService threadPool;
    private final Map<Device, WorkerGroup<I, O>> workerGroups = new ConcurrentHashMap<>();
    private final LinkedBlockingDeque<WorkerJob<I, O>> jobQueue;

    /**
     * Creates a worker pool for the given model.
     *
     * <p>NOTE(review): the decompiler reports this constructor was package-private in the original
     * source; it is kept public here so existing callers are unaffected.
     *
     * @param modelInfo the model served by this pool; its queue size bounds the job queue
     * @param executorService the executor exposed through {@link #getThreadPool()}
     */
    public WorkerPool(ModelInfo<I, O> modelInfo, ExecutorService executorService) {
        this.model = modelInfo;
        this.threadPool = executorService;
        this.jobQueue = new LinkedBlockingDeque<>(modelInfo.getQueueSize());
    }

    /**
     * Returns the model served by this pool.
     *
     * <p>NOTE(review): package-private in the original source according to the decompiler.
     *
     * @return the model served by this pool
     */
    public ModelInfo<I, O> getModel() {
        return model;
    }

    /**
     * Returns the executor supplied at construction.
     *
     * <p>NOTE(review): package-private in the original source according to the decompiler.
     *
     * @return the executor service
     */
    public ExecutorService getThreadPool() {
        return threadPool;
    }

    /**
     * Returns the live map of worker groups keyed by device (not a copy).
     *
     * @return the worker groups of this pool
     */
    public Map<Device, WorkerGroup<I, O>> getWorkerGroups() {
        return workerGroups;
    }

    /**
     * Returns a snapshot of all workers across all device groups.
     *
     * <p>The returned list is a new list; modifying it does not affect the pool.
     *
     * @return a new list containing every worker of every group
     */
    public List<WorkerThread<I, O>> getWorkers() {
        return workerGroups.values().stream()
                .flatMap(group -> group.workers.stream())
                .collect(Collectors.toList());
    }

    /**
     * Returns the shared job queue of this pool.
     *
     * @return the job queue
     */
    public LinkedBlockingDeque<WorkerJob<I, O>> getJobQueue() {
        return jobQueue;
    }

    /**
     * Returns the sum of the configured maximum worker counts of all groups.
     *
     * @return the total maximum number of workers
     */
    public int getMaxWorkers() {
        return workerGroups.values().stream().mapToInt(group -> group.maxWorkers).sum();
    }

    /**
     * Checks whether every group has at least its configured minimum number of workers.
     *
     * @return {@code true} if no group is below its minimum worker count
     */
    public boolean isFullyScaled() {
        for (WorkerGroup<I, O> group : workerGroups.values()) {
            if (group.getMinWorkers() > group.getWorkers().size()) {
                return false;
            }
        }
        return true;
    }

    /**
     * Loads the model on a device and scales that device's worker group.
     *
     * <p>If the model is not {@code READY} after loading, a warning is logged and no workers are
     * created or scaled.
     *
     * @param deviceName the device name, resolved through {@code ModelInfo.withDefaultDevice}
     * @param minWorkers the minimum number of workers for the device group
     * @param maxWorkers the maximum number of workers for the device group
     * @throws CompletionException if loading the model fails with a {@link ModelException} or
     *     {@link IOException}
     */
    public void initWorkers(String deviceName, int minWorkers, int maxWorkers) {
        Device device = model.withDefaultDevice(deviceName);
        logger.info("initWorkers for {} ({}): {}, {}", model, device, minWorkers, maxWorkers);
        synchronized (model) {
            try {
                model.load(device);
            } catch (ModelException | IOException e) {
                throw new CompletionException(e);
            }
            if (model.getStatus() != ModelInfo.Status.READY) {
                logger.warn("Cannot scale workers while model is not READY: {}", model);
                // Do not create workers for a model that failed to reach READY.
                return;
            }
        }
        cleanup();
        WorkerGroup<I, O> group =
                workerGroups.computeIfAbsent(device, d -> new WorkerGroup<>(this, d));
        group.configureWorkers(minWorkers, maxWorkers);
        doScaleWorker(group);

        // NOTE(review): this updates the model's queue size only; this pool's jobQueue was
        // created in the constructor and is not resized here.
        String queueSize = model.getModel(device).getProperty("job_queue_size");
        if (queueSize != null && !queueSize.isEmpty()) {
            model.setQueueSize(Integer.parseInt(queueSize));
        }
        log();
    }

    /**
     * Rescales workers.
     *
     * <p>When {@code deviceName} is non-null this delegates to {@link #initWorkers}. Otherwise
     * every existing group is reconfigured with the given limits and scaled.
     *
     * @param deviceName the device to scale, or {@code null} for all existing groups
     * @param minWorkers the minimum number of workers
     * @param maxWorkers the maximum number of workers
     */
    public void scaleWorkers(String deviceName, int minWorkers, int maxWorkers) {
        if (deviceName != null) {
            initWorkers(deviceName, minWorkers, maxWorkers);
            return;
        }
        cleanup();
        for (WorkerGroup<I, O> group : workerGroups.values()) {
            group.configureWorkers(minWorkers, maxWorkers);
            doScaleWorker(group);
        }
        log();
    }

    /**
     * Brings a group's fixed-pool workers to exactly its minimum worker count.
     *
     * <p>Missing fixed-pool threads are added. Fixed-pool threads beyond the minimum are shut down
     * with {@code WORKER_SCALED_DOWN}. Non-fixed-pool threads are left untouched.
     */
    private void doScaleWorker(WorkerGroup<I, O> group) {
        int minWorkers = group.getMinWorkers();
        List<WorkerThread<I, O>> fixedPoolThreads = new ArrayList<>();
        for (WorkerThread<I, O> thread : group.getWorkers()) {
            if (thread.isFixPoolThread()) {
                fixedPoolThreads.add(thread);
            }
        }
        int current = fixedPoolThreads.size();
        if (current < minWorkers) {
            group.addThreads(minWorkers - current, true);
        } else {
            for (WorkerThread<I, O> extra : fixedPoolThreads.subList(minWorkers, current)) {
                extra.shutdown(WorkerState.WORKER_SCALED_DOWN);
            }
        }
    }

    /**
     * Shuts down every worker with state {@code WORKER_SCALED_DOWN} while holding the model's lock.
     *
     * <p>The workers are not removed from their groups here. {@link #cleanup()} only removes
     * workers in the {@code WORKER_STOPPED} or {@code WORKER_ERROR} state.
     */
    public void shutdownWorkers() {
        synchronized (model) {
            for (WorkerThread<I, O> thread : getWorkers()) {
                thread.shutdown(WorkerState.WORKER_SCALED_DOWN);
            }
        }
    }

    /** Removes workers in the {@code WORKER_STOPPED} or {@code WORKER_ERROR} state from all groups. */
    public void cleanup() {
        for (WorkerGroup<I, O> group : workerGroups.values()) {
            group.workers.removeIf(
                    thread -> {
                        WorkerState state = thread.getState();
                        return state == WorkerState.WORKER_STOPPED
                                || state == WorkerState.WORKER_ERROR;
                    });
        }
    }

    /**
     * Shuts down the whole pool.
     *
     * <p>This closes the model, stops every worker with {@code WORKER_STOPPED}, removes all groups,
     * and cancels (with interruption) the futures of all jobs still in the queue.
     */
    public void shutdown() {
        model.close();
        for (WorkerGroup<I, O> group : workerGroups.values()) {
            for (WorkerThread<I, O> thread : group.workers) {
                thread.shutdown(WorkerState.WORKER_STOPPED);
            }
        }
        workerGroups.clear();
        for (WorkerJob<I, O> job : jobQueue) {
            job.getFuture().cancel(true);
        }
    }

    /**
     * Adds one non-fixed-pool worker to the group with the widest {@code max - min} range, if that
     * group is below its maximum.
     *
     * <p>Ties are resolved the same way as a stable ascending sort taking the last element.
     *
     * <p>NOTE(review): package-private in the original source according to the decompiler.
     */
    public void addThreads() {
        List<WorkerGroup<I, O>> groups = new ArrayList<>(workerGroups.values());
        if (groups.isEmpty()) {
            logger.warn("No worker pool available.");
            return;
        }
        // List.sort is stable, so among equal ranges the last group in iteration order wins.
        groups.sort(
                Comparator.comparingInt(group -> group.getMaxWorkers() - group.getMinWorkers()));
        WorkerGroup<I, O> target = groups.get(groups.size() - 1);
        if (target.getMaxWorkers() > target.workers.size()) {
            target.addThreads(1, false);
        }
    }

    /** Logs at debug level every worker id, tagged as fixed-pool or temporary. */
    private void log() {
        if (!logger.isDebugEnabled()) {
            return;
        }
        StringBuilder summary = new StringBuilder();
        for (WorkerThread<I, O> thread : getWorkers()) {
            summary.append(thread.getWorkerId())
                    .append(thread.isFixPoolThread() ? "-fixedPool\n" : "-tmpPool\n");
        }
        logger.debug("worker pool for model {}:\n {}", model, summary);
    }
}
