package ai.djl.serving.wlm;

import ai.djl.inference.Predictor;
import ai.djl.modality.Input;
import ai.djl.modality.Output;
import ai.djl.translate.TranslateException;
import java.util.List;
import java.util.concurrent.LinkedBlockingDeque;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* JADX INFO: Access modifiers changed from: package-private */
/* loaded from: input_file:ai/djl/serving/wlm/WorkerThread.class */
/**
 * A {@link Runnable} worker that repeatedly takes a batch of requests from a
 * {@link BatchAggregator}, runs the batch through a {@link Predictor} and hands the result (or an
 * error notification) back to the aggregator.
 *
 * <p>The loop in {@link #run()} ends when the worker is no longer marked as running, when the
 * aggregator reports that it is finished, or when an exception other than
 * {@link TranslateException} escapes the loop body. Instances are created through
 * {@link #builder()}.
 *
 * <p>{@link #shutdown(WorkerState)} interrupts the thread recorded by {@link #run()}, so it can be
 * used from another thread to stop a worker that is blocked.
 */
public final class WorkerThread implements Runnable {
    private static final Logger logger = LoggerFactory.getLogger(WorkerThread.class);

    /** Maximum number of model-name characters embedded in the worker's thread name. */
    private static final int MAX_MODEL_NAME_LENGTH = 25;

    private final String workerName;
    private final Predictor<Input, Output> predictor;
    private final AtomicBoolean running;
    private final BatchAggregator aggregator;
    private final int gpuId;
    /** Thread currently executing {@link #run()}; {@code null} when the worker is not running. */
    private final AtomicReference<Thread> currentThread;
    private WorkerState state;
    private final int workerId;
    private final long startTime;
    private final boolean fixPoolThread;

    /** Builder for {@link WorkerThread} instances. */
    public static class Builder {
        private ModelInfo model;
        private BatchAggregator aggregator;
        private LinkedBlockingDeque<Job> jobQueue;
        private int gpuId = -1;
        private boolean fixPoolThread = true;
        private GpuAssignmentStrategy gpuAssignmentStrategy;

        Builder() {
        }

        /**
         * Returns this builder; used by the setters to support call chaining.
         *
         * @return this builder
         */
        protected Builder self() {
            return this;
        }

        /**
         * Fills in values that were not set explicitly, just before the worker is constructed.
         *
         * <p>When no aggregator was supplied, a {@link PermanentBatchAggregator} is created for a
         * fixed-pool worker and a {@link TemporaryBatchAggregator} otherwise. When a GPU
         * assignment strategy is present, the GPU id it returns replaces any id set through
         * {@link #optGpuId(int)}.
         */
        protected void preBuildProcessing() {
            if (this.aggregator == null) {
                if (this.fixPoolThread) {
                    this.aggregator = new PermanentBatchAggregator(this.model, this.jobQueue);
                } else {
                    this.aggregator = new TemporaryBatchAggregator(this.model, this.jobQueue);
                }
            }
            if (this.gpuAssignmentStrategy != null) {
                this.gpuId = this.gpuAssignmentStrategy.nextGpuId();
            }
        }

        /**
         * Checks that the mandatory settings are present.
         *
         * @throws IllegalArgumentException if the model is missing, or if neither a job queue nor
         *     an aggregator has been set
         */
        protected void validate() {
            if (this.model == null) {
                throw new IllegalArgumentException("model must not be null");
            }
            if (this.jobQueue == null && this.aggregator == null) {
                throw new IllegalArgumentException("one of jobQueue or BatchAggregator have to be set.");
            }
        }

        /**
         * Validates the settings, applies defaults and creates the worker.
         *
         * @return a new {@link WorkerThread}
         * @throws IllegalArgumentException if {@link #validate()} rejects the settings
         */
        public WorkerThread build() {
            validate();
            preBuildProcessing();
            return new WorkerThread(this);
        }

        /**
         * Sets the model the worker serves (mandatory).
         *
         * @param modelInfo the model information
         * @return this builder
         */
        public Builder setModel(ModelInfo modelInfo) {
            this.model = modelInfo;
            return self();
        }

        /**
         * Sets the aggregator to use instead of the one created by default.
         *
         * @param batchAggregator the aggregator
         * @return this builder
         */
        public Builder optAggregator(BatchAggregator batchAggregator) {
            this.aggregator = batchAggregator;
            return self();
        }

        /**
         * Sets the job queue handed to the default aggregator.
         *
         * @param jobQueue the queue of jobs
         * @return this builder
         */
        public Builder setJobQueue(LinkedBlockingDeque<Job> jobQueue) {
            this.jobQueue = jobQueue;
            return self();
        }

        /**
         * Sets the GPU id reported by the worker; defaults to {@code -1}.
         *
         * @param gpuId the GPU id
         * @return this builder
         */
        public Builder optGpuId(int gpuId) {
            this.gpuId = gpuId;
            return self();
        }

        /**
         * Sets whether the worker is a fixed-pool worker; defaults to {@code true}.
         *
         * @param fixPoolThread {@code true} for a fixed-pool worker
         * @return this builder
         */
        public Builder optFixPoolThread(boolean fixPoolThread) {
            this.fixPoolThread = fixPoolThread;
            return self();
        }

        /**
         * Sets a strategy that supplies the GPU id at build time.
         *
         * @param gpuAssignmentStrategy the strategy
         * @return this builder
         */
        public Builder optGpuAssignmentStrategy(GpuAssignmentStrategy gpuAssignmentStrategy) {
            this.gpuAssignmentStrategy = gpuAssignmentStrategy;
            return self();
        }
    }

    private WorkerThread(Builder builder) {
        this.running = new AtomicBoolean(true);
        this.currentThread = new AtomicReference<>();
        // The id must exist before the name is built: the name embeds it. Building the name
        // first made every worker name end in "-0".
        this.workerId = new WorkerIdGenerator().generate();
        this.workerName = buildWorkerName(builder.model, this.workerId);
        this.aggregator = builder.aggregator;
        this.gpuId = builder.gpuId;
        this.startTime = System.currentTimeMillis();
        this.predictor = builder.model.getModel().newPredictor();
        this.fixPoolThread = builder.fixPoolThread;
    }

    /**
     * Runs the request loop on the calling thread until the worker stops.
     *
     * <p>The calling thread is renamed to the worker name. A {@link TranslateException} from a
     * prediction is logged and reported to the aggregator, and the loop continues. Any other
     * throwable ends the loop. On exit the worker is always shut down with
     * {@link WorkerState#WORKER_STOPPED}, and an error is reported for a batch that was taken
     * from the aggregator but not completed.
     */
    @Override // java.lang.Runnable
    public void run() {
        Thread thread = Thread.currentThread();
        thread.setName(this.workerName);
        this.currentThread.set(thread);
        this.state = WorkerState.WORKER_STARTED;
        // Non-null only while a batch has been taken but not yet fully handled.
        List<Input> pending = null;
        try {
            while (isRunning() && !this.aggregator.isFinished()) {
                pending = this.aggregator.getRequest();
                if (pending != null && !pending.isEmpty()) {
                    try {
                        List<Output> reply = this.predictor.batchPredict(pending);
                        this.aggregator.sendResponse(reply);
                    } catch (TranslateException e) {
                        logger.warn("Failed to predict", e);
                        this.aggregator.sendError();
                    }
                }
                pending = null;
            }
        } catch (InterruptedException e) {
            // The worker is terminating; the interrupt is consumed here on purpose so that the
            // cleanup below does not run with the interrupt flag set.
            logger.debug("Shutting down the thread .. Scaling down.");
        } catch (Throwable t) {
            logger.error("Server error", t);
        } finally {
            // Log via the local reference: shutdown() called from another thread may already
            // have cleared this.currentThread, and dereferencing it would throw here.
            logger.debug("Shutting down worker thread .. {}", thread.getName());
            // Clearing the reference first keeps shutdown() from interrupting this thread.
            this.currentThread.set(null);
            shutdown(WorkerState.WORKER_STOPPED);
            if (pending != null) {
                // NOTE(review): assumes sendError() is harmless when the aggregator has already
                // been notified (e.g. by a concurrent shutdown()) - confirm against
                // BatchAggregator.
                this.aggregator.sendError();
            }
        }
    }

    /**
     * Returns the id generated for this worker.
     *
     * @return the worker id
     */
    public int getWorkerId() {
        return this.workerId;
    }

    /**
     * Returns whether the worker is still marked as running.
     *
     * @return {@code true} until {@link #shutdown(WorkerState)} has been called
     */
    public boolean isRunning() {
        return this.running.get();
    }

    /**
     * Returns the GPU id assigned at build time.
     *
     * @return the GPU id, {@code -1} if none was set
     */
    public int getGpuId() {
        return this.gpuId;
    }

    /**
     * Returns the time at which this worker was constructed.
     *
     * @return the construction time in milliseconds, as given by
     *     {@link System#currentTimeMillis()}
     */
    public long getStartTime() {
        return this.startTime;
    }

    /**
     * Returns the current state of the worker.
     *
     * @return the state, {@code null} before {@link #run()} or {@link #shutdown(WorkerState)} has
     *     set one
     */
    public WorkerState getState() {
        return this.state;
    }

    /**
     * Stops the worker and closes its predictor.
     *
     * <p>If a thread is still registered as executing {@link #run()}, it is interrupted and the
     * aggregator is told to report an error. {@link #run()} itself calls this method while
     * cleaning up, so the method may run more than once for one worker.
     *
     * @param workerState the state to record, unless the worker is already
     *     {@link WorkerState#WORKER_SCALED_DOWN}
     */
    public void shutdown(WorkerState workerState) {
        this.running.set(false);
        setState(workerState);
        Thread thread = this.currentThread.getAndSet(null);
        if (thread != null) {
            thread.interrupt();
            this.aggregator.sendError();
        }
        this.predictor.close();
    }

    /**
     * Builds the worker name {@code W-<model name>-<worker id>}, truncating the model name to
     * {@link #MAX_MODEL_NAME_LENGTH} characters.
     */
    private static String buildWorkerName(ModelInfo modelInfo, int workerId) {
        String modelName = modelInfo.getModelName();
        if (modelName.length() > MAX_MODEL_NAME_LENGTH) {
            modelName = modelName.substring(0, MAX_MODEL_NAME_LENGTH);
        }
        return "W-" + modelName + '-' + workerId;
    }

    /**
     * Logs the requested transition and records the new state. Once the state is
     * {@link WorkerState#WORKER_SCALED_DOWN} it is no longer changed.
     *
     * @param workerState the requested state
     */
    void setState(WorkerState workerState) {
        logger.debug("{} State change {} -> {}", this.workerName, this.state, workerState);
        if (this.state != WorkerState.WORKER_SCALED_DOWN) {
            this.state = workerState;
        }
    }

    /**
     * Returns whether this worker was built as a fixed-pool worker.
     *
     * @return {@code true} for a fixed-pool worker
     */
    public boolean isFixPoolThread() {
        return this.fixPoolThread;
    }

    /**
     * Creates a builder for a {@link WorkerThread}.
     *
     * @return a new builder
     */
    public static Builder builder() {
        return new Builder();
    }
}
