package com.github.handong0123.tensorflow.deploy.provider;

import com.github.handong0123.tensorflow.deploy.session.entity.ModelInput;
import com.github.handong0123.tensorflow.deploy.session.entity.ModelOutput;
import com.github.handong0123.tensorflow.deploy.session.model.TensorflowModelService;
import com.github.handong0123.tensorflow.deploy.session.model.TensorflowModelServiceImpl;
import com.google.common.util.concurrent.ThreadFactoryBuilder;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CancellationException;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:com/github/handong0123/tensorflow/deploy/provider/TensorflowProvider.class */
/**
 * Serves model predictions from a fixed set of worker threads that share one bounded request queue.
 *
 * <p>Each worker owns its own {@link TensorflowModelService}. When GPU ids are supplied, one worker
 * is created per id. Otherwise {@code threadNum} workers are created (default 3), each with GPU id
 * {@code "-1"}. {@link #predict(ModelInput)} enqueues a request and blocks until a worker answers or
 * the configured timeout (milliseconds) elapses.
 */
public class TensorflowProvider {
    private static final Logger LOG = LoggerFactory.getLogger(TensorflowProvider.class);
    private static final String DEFAULT_GPU_ID = "-1";
    private static final int DEFAULT_TIMEOUT = 300000;
    private static final int DEFAULT_QUEUE_SIZE = 5000;
    private static final int DEFAULT_THREAD_NUM = 3;

    /** Maximum time, in milliseconds, that {@link #predict} waits for a worker's answer. */
    private final int timeout;
    /** One worker per model-service instance; kept so {@link #modelReload()} can reach every model. */
    private final List<PredictionWorker> workers;
    /** Requests waiting for a free worker; bounded so {@link #predict} fails fast under overload. */
    private final LinkedBlockingQueue<PendingRequest> queue;
    /** Runs the worker loops; sized to exactly one thread per worker so none is left unscheduled. */
    private final ExecutorService executorService;

    /**
     * Creates a provider with the default worker count, no GPU ids and the default timeout.
     *
     * @param modelArg1 forwarded unchanged as the first argument of {@link TensorflowModelServiceImpl}
     * @param modelArg2 forwarded unchanged as the second argument of {@link TensorflowModelServiceImpl}
     */
    public TensorflowProvider(String modelArg1, String modelArg2) {
        this(DEFAULT_THREAD_NUM, modelArg1, modelArg2, null);
    }

    /**
     * Creates a provider with {@code threadNum} workers, no GPU ids and the default timeout.
     *
     * @param threadNum number of workers; values {@code <= 0} fall back to the default of 3
     * @param modelArg1 forwarded unchanged as the first argument of {@link TensorflowModelServiceImpl}
     * @param modelArg2 forwarded unchanged as the second argument of {@link TensorflowModelServiceImpl}
     */
    public TensorflowProvider(int threadNum, String modelArg1, String modelArg2) {
        this(threadNum, modelArg1, modelArg2, null);
    }

    /**
     * Creates a provider with the default timeout.
     *
     * @param threadNum number of workers when {@code gpuIdList} is blank; values {@code <= 0} fall
     *     back to the default of 3
     * @param modelArg1 forwarded unchanged as the first argument of {@link TensorflowModelServiceImpl}
     * @param modelArg2 forwarded unchanged as the second argument of {@link TensorflowModelServiceImpl}
     * @param gpuIdList comma-separated GPU ids, one worker per id; may be {@code null} or blank
     */
    public TensorflowProvider(int threadNum, String modelArg1, String modelArg2, String gpuIdList) {
        this(threadNum, modelArg1, modelArg2, gpuIdList, DEFAULT_TIMEOUT);
    }

    /**
     * Creates the provider and immediately starts one worker per model-service instance.
     *
     * @param threadNum number of workers when {@code gpuIdList} is blank (ignored otherwise); values
     *     {@code <= 0} fall back to the default of 3
     * @param modelArg1 forwarded unchanged as the first argument of {@link TensorflowModelServiceImpl}
     * @param modelArg2 forwarded unchanged as the second argument of {@link TensorflowModelServiceImpl}
     * @param gpuIdList comma-separated GPU ids, one worker per id (surrounding whitespace is
     *     trimmed); {@code null} or blank means {@code threadNum} workers with GPU id {@code "-1"}
     * @param timeout maximum time in milliseconds {@link #predict} waits for a result
     */
    public TensorflowProvider(int threadNum, String modelArg1, String modelArg2, String gpuIdList, int timeout) {
        this.timeout = timeout;
        this.queue = new LinkedBlockingQueue<>(DEFAULT_QUEUE_SIZE);

        String[] gpuIds = StringUtils.isBlank(gpuIdList) ? new String[0] : gpuIdList.split(",");
        boolean useDefaultGpu = gpuIds.length == 0;
        int workerCount = useDefaultGpu ? (threadNum <= 0 ? DEFAULT_THREAD_NUM : threadNum) : gpuIds.length;

        this.workers = new ArrayList<>(workerCount);
        // Worker loops never finish, so the pool must have exactly workerCount core threads.
        // Otherwise extra workers sit in the task queue forever or get rejected.
        this.executorService = new ThreadPoolExecutor(workerCount, workerCount, 0L, TimeUnit.MILLISECONDS,
                new LinkedBlockingQueue<>(),
                new ThreadFactoryBuilder().setNameFormat("tensorflow-provider-pool-%d").build());

        for (int index = 0; index < workerCount; index++) {
            String gpuId = useDefaultGpu ? DEFAULT_GPU_ID : gpuIds[index].trim();
            PredictionWorker worker =
                    new PredictionWorker(new TensorflowModelServiceImpl(modelArg1, modelArg2, gpuId), queue);
            workers.add(worker);
            executorService.execute(worker);
        }
    }

    /**
     * Runs a prediction on the next free worker, blocking up to the configured timeout.
     *
     * <p>Side effect: assigns a fresh request id to {@code modelInput} via {@code setReqUuid}.
     *
     * @param modelInput the request to evaluate
     * @return the model output, or {@code null} if the queue is full, the wait timed out, the request
     *     was cancelled by {@link #shutdown()}, the model returned {@code null}, or prediction failed
     *     (failures are logged by the worker)
     * @throws InterruptedException if the calling thread is interrupted while waiting
     */
    public ModelOutput predict(ModelInput modelInput) throws InterruptedException {
        PendingRequest request = putMessage(modelInput);
        if (request == null) {
            LOG.error("Put message failed, maybe queue is full");
            return null;
        }
        return getMessage(request);
    }

    /** Asks every worker's model service to reload its model. */
    public void modelReload() {
        workers.forEach(worker -> worker.getModelService().modelReload());
    }

    /**
     * Stops all worker threads and cancels requests that are still queued, so their callers return
     * {@code null} promptly. Requests submitted after shutdown are never processed and time out.
     */
    public void shutdown() {
        executorService.shutdownNow();
        List<PendingRequest> abandoned = new ArrayList<>();
        queue.drainTo(abandoned);
        abandoned.forEach(request -> request.output.cancel(false));
    }

    /**
     * Tags the input with a new request id and tries to enqueue it without blocking.
     *
     * @return the queued request, or {@code null} if the queue is full
     */
    private PendingRequest putMessage(ModelInput modelInput) {
        modelInput.setReqUuid(new Object());
        PendingRequest request = new PendingRequest(modelInput);
        return queue.offer(request) ? request : null;
    }

    /**
     * Waits up to {@link #timeout} ms for the request's result.
     *
     * @return the result, or {@code null} on timeout, cancellation or prediction failure
     * @throws InterruptedException if the calling thread is interrupted while waiting
     */
    private ModelOutput getMessage(PendingRequest request) throws InterruptedException {
        try {
            return request.output.get(timeout, TimeUnit.MILLISECONDS);
        } catch (TimeoutException e) {
            LOG.info("Get Message Timeout");
            return null;
        } catch (CancellationException e) {
            LOG.info("Request cancelled before a result was produced");
            return null;
        } catch (ExecutionException e) {
            // The worker already logged the failure with its stack trace.
            return null;
        } finally {
            // No-op if a result already arrived. Otherwise workers skip this request if it is still
            // queued, and a late result is discarded instead of being retained.
            request.output.cancel(false);
        }
    }

    /** A queued prediction paired with the future its caller is waiting on. */
    private static final class PendingRequest {
        private final ModelInput input;
        private final CompletableFuture<ModelOutput> output = new CompletableFuture<>();

        PendingRequest(ModelInput input) {
            this.input = input;
        }
    }

    /**
     * Worker loop bound to a single model service. It takes requests from the shared queue until
     * interrupted, and a failing prediction does not stop the loop.
     */
    private static final class PredictionWorker implements Runnable {
        private final TensorflowModelService modelService;
        private final LinkedBlockingQueue<PendingRequest> queue;

        PredictionWorker(TensorflowModelService modelService, LinkedBlockingQueue<PendingRequest> queue) {
            this.modelService = modelService;
            this.queue = queue;
        }

        TensorflowModelService getModelService() {
            return modelService;
        }

        @Override
        public void run() {
            while (!Thread.currentThread().isInterrupted()) {
                PendingRequest request;
                try {
                    request = queue.take();
                } catch (InterruptedException e) {
                    // Interruption is the stop signal (see shutdown()): restore the flag and exit.
                    Thread.currentThread().interrupt();
                    return;
                }
                // The caller stopped waiting while this request was queued; skip work nobody will read.
                if (request.output.isDone()) {
                    continue;
                }
                try {
                    request.output.complete(modelService.predict(request.input));
                } catch (Exception e) {
                    LOG.error("Model prediction failed", e);
                    request.output.completeExceptionally(e);
                    if (e instanceof InterruptedException) {
                        Thread.currentThread().interrupt();
                    }
                }
            }
        }
    }
}
