package chat.octet.model;

import chat.octet.model.beans.ChatMessage;
import chat.octet.model.beans.CompletionResult;
import chat.octet.model.beans.LlamaContextParams;
import chat.octet.model.beans.LlamaModelParams;
import chat.octet.model.beans.Status;
import chat.octet.model.components.criteria.impl.StoppingWordCriteria;
import chat.octet.model.components.processor.impl.CustomBiasLogitsProcessor;
import chat.octet.model.components.prompt.ChatTemplateFormatter;
import chat.octet.model.components.prompt.DefaultChatTemplateFormatter;
import chat.octet.model.exceptions.ModelException;
import chat.octet.model.functions.Function;
import chat.octet.model.parameters.GenerateParameter;
import chat.octet.model.parameters.ModelParameter;
import com.google.common.base.Preconditions;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import java.io.File;
import java.nio.file.Files;
import java.nio.file.LinkOption;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:chat/octet/model/Model.class */
/**
 * Entry point for loading a model through {@link LlamaService} and running
 * text completion or chat generation against it.
 *
 * <p>Construction loads the model file (and optionally applies a LoRA adapter);
 * {@link #close()} drops all cached chat sessions and asks {@link LlamaService}
 * to release what it holds. Chat sessions are cached per "user" or
 * "user:session" key in a concurrent map.
 */
public class Model implements AutoCloseable {
    private static final Logger log = LoggerFactory.getLogger(Model.class);

    private final ModelParameter modelParams;
    private final String modelName;
    private final String modelType;
    private final ChatTemplateFormatter chatFormatter;
    /** Cached chat session state, keyed by "user" or "user:session". */
    private final Map<String, Status> chatStatus = Maps.newConcurrentMap();

    /**
     * Loads a model from the given file path using default model parameters.
     *
     * @param str path of the model file
     */
    public Model(String str) {
        this(ModelParameter.builder().modelPath(str).build());
    }

    /**
     * Loads a model (and, when a LoRA path is configured, applies the adapter).
     *
     * @param modelParameter model settings; must not be null and must carry a model path
     * @throws NullPointerException if the parameters or the model path are null
     * @throws ModelException       if the model file or LoRA file does not exist,
     *                              or the LoRA adapter cannot be applied
     */
    public Model(ModelParameter modelParameter) {
        Preconditions.checkNotNull(modelParameter, "Model parameters cannot be null");
        Preconditions.checkNotNull(modelParameter.getModelPath(), "Model file path cannot be null");
        if (!fileExists(modelParameter.getModelPath())) {
            throw new ModelException("Model file is not exists, please check the file path");
        }
        // Validate the LoRA file BEFORE loading the base model, so a bad path
        // fails fast instead of after the model has already been loaded.
        boolean useLora = StringUtils.isNotBlank(modelParameter.getLoraPath());
        if (useLora && !fileExists(modelParameter.getLoraPath())) {
            throw new ModelException("Lora model file is not exists, please check the file path");
        }

        this.modelParams = modelParameter;
        LlamaService.loadLlamaModelFromFile(modelParameter.getModelPath(), getLlamaModelParameters(modelParameter), getLlamaContextParameters(modelParameter));
        if (useLora) {
            if (LlamaService.loadLoraModelFromFile(modelParameter.getLoraPath(), modelParameter.getLoraScale(), modelParameter.getLoraBase(), modelParameter.getThreads()) != 0) {
                // The constructor is about to fail, so nobody will ever be able to call close().
                // NOTE(review): assumes release() frees the model loaded above - confirm against LlamaService.
                LlamaService.release();
                throw new ModelException(String.format("Failed to apply LoRA from lora path: %s to base path: %s", modelParameter.getLoraPath(), modelParameter.getLoraBase()));
            }
        }

        this.modelName = LlamaService.llamaModelMeta("general.name");
        this.modelType = LlamaService.llamaModelMeta("general.architecture");
        // Only build the default formatter (and read the chat template metadata) when no custom one is supplied.
        ChatTemplateFormatter customFormatter = modelParameter.getChatTemplateFormatter();
        this.chatFormatter = customFormatter != null
                ? customFormatter
                : new DefaultChatTemplateFormatter(this.modelType, LlamaService.llamaModelMeta("tokenizer.chat_template"));

        log.info(LlamaService.getSystemInfo());
        log.info(toString());
        log.info("Model loaded successfully.");
    }

    /** Returns whether a file exists at the given path string. */
    private static boolean fileExists(String path) {
        return Files.exists(new File(path).toPath());
    }

    /**
     * Copies the model-level settings onto the service's default model parameters.
     *
     * @param modelParameter user-supplied model settings
     * @return the default parameters with the user settings applied
     */
    private LlamaModelParams getLlamaModelParameters(ModelParameter modelParameter) {
        LlamaModelParams params = LlamaService.getLlamaModelDefaultParams();
        params.gpuLayers = modelParameter.getGpuLayers();
        params.splitMode = modelParameter.getSplitMode();
        params.mainGpu = modelParameter.getMainGpu();
        if (modelParameter.getTensorSplit() != null) {
            params.tensorSplit = modelParameter.getTensorSplit();
        }
        params.vocabOnly = modelParameter.isVocabOnly();
        // mmap is only switched on when no LoRA path is configured, it was requested, and the service supports it;
        // otherwise the service default is left untouched.
        boolean mmapRequested = StringUtils.isBlank(modelParameter.getLoraPath()) && modelParameter.isMmap();
        if (mmapRequested && LlamaService.isMmapSupported()) {
            params.mmap = true;
        }
        if (modelParameter.isMlock() && LlamaService.isMlockSupported()) {
            params.mlock = true;
        }
        params.checkTensors = modelParameter.isCheckTensors();
        params.numaStrategy = modelParameter.getNumaStrategy();
        return params;
    }

    /**
     * Copies the context-level settings onto the service's default context parameters.
     *
     * @param modelParameter user-supplied model settings
     * @return the default context parameters with the user settings applied
     */
    private LlamaContextParams getLlamaContextParameters(ModelParameter modelParameter) {
        LlamaContextParams params = LlamaService.getLlamaContextDefaultParams();
        params.seed = modelParameter.getSeed();
        params.ctx = modelParameter.getContextSize();
        params.batch = modelParameter.getBatchSize();
        params.ubatch = modelParameter.getUbatch();
        params.seqMax = modelParameter.getSeqMax();
        params.threads = modelParameter.getThreads();
        // A batch thread count of -1 means "use the same value as threads".
        params.threadsBatch = modelParameter.getThreadsBatch() == -1 ? modelParameter.getThreads() : modelParameter.getThreadsBatch();
        params.ropeScalingType = modelParameter.getRopeScalingType();
        params.poolingType = modelParameter.getPoolingType();
        params.yarnExtFactor = modelParameter.getYarnExtFactor();
        params.yarnAttnFactor = modelParameter.getYarnAttnFactor();
        params.yarnBetaFast = modelParameter.getYarnBetaFast();
        params.yarnBetaSlow = modelParameter.getYarnBetaSlow();
        params.yarnOrigCtx = modelParameter.getYarnOrigCtx();
        params.defragThold = modelParameter.getDefragThold();
        params.ropeFreqBase = modelParameter.getRopeFreqBase();
        params.ropeFreqScale = modelParameter.getRopeFreqScale();
        params.logitsAll = modelParameter.isLogitsAll();
        params.embedding = modelParameter.isEmbedding();
        params.offloadKqv = modelParameter.isOffloadKqv();
        params.flashAttn = modelParameter.isFlashAttn();
        return params;
    }

    /**
     * Removes cached chat sessions matching the given key.
     *
     * <p>A session is removed when its key equals {@code str}. If {@code str} has the
     * form "user:session", every cached key ending with that session part is removed
     * as well. A key without a session part only removes the exact match.
     *
     * @param str session key, either "user" or "user:session"
     * @throws NullPointerException if {@code str} is null
     */
    public void removeChatStatus(String str) {
        Preconditions.checkNotNull(str, "Session cannot be null");
        // split() drops trailing empty parts, so "user:" yields one element; guard on length rather than contains(":").
        String[] parts = str.split(":");
        String sessionSuffix = parts.length > 1 ? parts[1] : "";
        for (String key : this.chatStatus.keySet()) {
            // An empty suffix must not take part in matching: endsWith("") is always true
            // and would wipe every user's sessions.
            boolean suffixMatch = !sessionSuffix.isEmpty() && key.endsWith(sessionSuffix);
            if (key.equals(str) || suffixMatch) {
                discardSession(key);
            }
        }
    }

    /** Removes all cached chat sessions, resetting each removed status. */
    public void removeAllChatStatus() {
        int size = this.chatStatus.size();
        if (size > 0) {
            for (String key : this.chatStatus.keySet()) {
                discardSession(key);
            }
            log.info("Removed all chat sessions, size: {}.", size);
        }
    }

    /** Removes one session from the cache, resets its status if it was still present, and logs the removal. */
    private void discardSession(String key) {
        Status removed = this.chatStatus.remove(key);
        if (removed != null) {
            removed.reset();
        }
        log.info("Removed chat session, session: {}.", key);
    }

    /**
     * Runs a completion with default generation parameters and returns the final result.
     *
     * @param str prompt text
     * @return the completion result
     */
    public CompletionResult completions(String str) {
        return completions(GenerateParameter.builder().build(), str);
    }

    /**
     * Runs a completion and returns the final result.
     *
     * @param generateParameter generation settings
     * @param str               prompt text
     * @return the completion result
     */
    public CompletionResult completions(GenerateParameter generateParameter, String str) {
        return generate(generateParameter, str).result();
    }

    /**
     * Creates a generator for the prompt using default generation parameters.
     *
     * @param str prompt text
     * @return a generator over the prompt
     */
    public Generator generate(String str) {
        return generate(GenerateParameter.builder().build(), str);
    }

    /**
     * Creates a generator for the prompt.
     *
     * @param generateParameter generation settings; must not be null
     * @param str               prompt text; must not be null
     * @return a generator over the prompt
     * @throws NullPointerException if either argument is null
     */
    public Generator generate(GenerateParameter generateParameter, String str) {
        Preconditions.checkNotNull(generateParameter, "Generate parameter cannot be null");
        Preconditions.checkNotNull(str, "Text cannot be null");
        registerBiasAndStoppingWords(generateParameter);
        return new Generator(generateParameter, str);
    }

    /**
     * Starts a chat with a single user question and default generation parameters.
     *
     * @param str user question
     * @return a generator over the formatted chat prompt
     */
    public Generator chat(String str) {
        return chat(GenerateParameter.builder().build(), (String) null, str);
    }

    /**
     * Starts a chat with an optional system prompt and default generation parameters.
     *
     * @param str  system prompt; ignored when blank
     * @param str2 user question
     * @return a generator over the formatted chat prompt
     */
    public Generator chat(String str, String str2) {
        return chat(GenerateParameter.builder().build(), str, str2);
    }

    /**
     * Starts a chat with an optional system prompt.
     *
     * @param generateParameter generation settings
     * @param str               system prompt; ignored when blank
     * @param str2              user question; must not be null
     * @return a generator over the formatted chat prompt
     * @throws NullPointerException if the user question is null
     */
    public Generator chat(GenerateParameter generateParameter, String str, String str2) {
        Preconditions.checkNotNull(str2, "User question cannot be null");
        List<ChatMessage> messages = Lists.newArrayList();
        if (StringUtils.isNotBlank(str)) {
            messages.add(ChatMessage.toSystem(str));
        }
        messages.add(ChatMessage.toUser(str2));
        return chat(generateParameter, messages);
    }

    /**
     * Starts a chat from a list of messages.
     *
     * @param generateParameter generation settings
     * @param list              chat messages
     * @return a generator over the formatted chat prompt
     */
    public Generator chat(GenerateParameter generateParameter, List<ChatMessage> list) {
        return chat(generateParameter, list, null, null);
    }

    /**
     * Starts a chat from a list of messages with callable functions.
     *
     * @param generateParameter generation settings
     * @param list              chat messages
     * @param list2             functions handed to the chat template formatter; may be null
     * @return a generator over the formatted chat prompt
     */
    public Generator chat(GenerateParameter generateParameter, List<ChatMessage> list, List<Function> list2) {
        return chat(generateParameter, list, list2, null);
    }

    /**
     * Starts a chat from a list of messages.
     *
     * <p>When session caching is enabled on the parameters, the per-session {@link Status}
     * is looked up (or created) and passed to the generator. When prompt caching is also
     * enabled and the first system message equals the one cached for the session, that
     * system message is left out of the formatted prompt. The caller's list is never modified.
     *
     * @param generateParameter generation settings; must not be null
     * @param list              chat messages; must not be null
     * @param list2             functions handed to the chat template formatter; may be null
     * @param map               extra values handed to the chat template formatter; may be null
     * @return a generator over the formatted chat prompt
     * @throws NullPointerException     if the parameters or messages are null, or session
     *                                  caching is enabled without a user
     * @throws IllegalArgumentException if the only message is a system message
     */
    public Generator chat(GenerateParameter generateParameter, List<ChatMessage> list, List<Function> list2, Map<String, Object> map) {
        Preconditions.checkNotNull(generateParameter, "Generate parameter cannot be null");
        Preconditions.checkNotNull(list, "Chat messages cannot be null");
        if (list.size() == 1 && ChatMessage.ChatRole.SYSTEM == list.get(0).getRole()) {
            throw new IllegalArgumentException("Chat messages cannot be only one system message");
        }
        registerBiasAndStoppingWords(generateParameter);

        // Work on a private copy: the system prompt cache may drop a message, and the
        // caller's list may be immutable or reused for later calls.
        List<ChatMessage> messages = Lists.newArrayList(list);
        Status status = null;
        if (generateParameter.isSessionCache()) {
            status = resolveSessionStatus(generateParameter);
            if (generateParameter.isPromptCache()) {
                applySystemPromptCache(messages, status);
            }
        }
        return new Generator(generateParameter, this.chatFormatter.format(messages, list2, true, map), status);
    }

    /**
     * Adds the logit-bias processor and the stopping-word criteria described by the
     * parameters to the parameters' own processor/criteria lists.
     *
     * <p>NOTE(review): this appends to the caller's parameter object on every call, so
     * reusing one GenerateParameter across calls presumably accumulates duplicates - confirm.
     */
    private static void registerBiasAndStoppingWords(GenerateParameter generateParameter) {
        if (generateParameter.getLogitBias() != null && !generateParameter.getLogitBias().isEmpty()) {
            generateParameter.getLogitsProcessorList().add(new CustomBiasLogitsProcessor(generateParameter.getLogitBias(), LlamaService.getVocabSize()));
        }
        if (generateParameter.getStoppingWord() != null) {
            generateParameter.getStoppingCriteriaList().add(new StoppingWordCriteria(generateParameter.getStoppingWord()));
        }
    }

    /** Returns the cached status for the parameters' user/session key, creating it on first use. */
    private Status resolveSessionStatus(GenerateParameter generateParameter) {
        Preconditions.checkNotNull(generateParameter.getUser(), "Chat user cannot be null, please set user in generate parameter.");
        String sessionKey = StringUtils.isBlank(generateParameter.getSession())
                ? generateParameter.getUser()
                : generateParameter.getUser() + ":" + generateParameter.getSession();
        // computeIfAbsent replaces the containsKey/put/get sequence, which could observe a
        // concurrent removal between put and get and hand back null.
        boolean[] created = {false};
        Status status = this.chatStatus.computeIfAbsent(sessionKey, key -> {
            created[0] = true;
            return new Status();
        });
        if (created[0]) {
            log.debug("Create new chat session, session: {} id: {}, chat session cache size: {}.", sessionKey, status.getId(), this.chatStatus.size());
        }
        return status;
    }

    /**
     * Drops the first system message from {@code messages} when its content equals the
     * system prompt already cached on {@code status}; otherwise stores it as the new cached prompt.
     * Messages without a non-blank system prompt are left untouched.
     */
    private static void applySystemPromptCache(List<ChatMessage> messages, Status status) {
        ChatMessage systemMessage = messages.stream()
                .filter(message -> ChatMessage.ChatRole.SYSTEM == message.getRole())
                .findFirst()
                .orElse(null);
        if (systemMessage == null || StringUtils.isBlank(systemMessage.getContent())) {
            return;
        }
        if (systemMessage.getContent().equals(status.getSystemPromptCache())) {
            messages.remove(systemMessage);
        } else {
            status.setSystemPromptCache(systemMessage.getContent());
        }
    }

    /** Logs the sampling metrics reported by the service when the model is configured as verbose. */
    public void metrics() {
        if (this.modelParams.isVerbose()) {
            log.info("Metrics: {}", LlamaService.getSamplingMetrics(true));
        }
    }

    /** Removes all cached chat sessions, then asks the service to release its resources and free the backend. */
    @Override // java.lang.AutoCloseable
    public void close() {
        removeAllChatStatus();
        LlamaService.release();
        LlamaService.llamaBackendFree();
        log.info("Closed model and context resources.");
    }

    /** Returns a one-line description containing the model name, type and parameters. */
    @Override
    public String toString() {
        return "model name: " + this.modelName + ", model type: " + this.modelType + ", model parameters: " + this.modelParams;
    }

    /** Returns the parameters this model was created with. */
    public ModelParameter getModelParams() {
        return this.modelParams;
    }

    /** Returns the value of the "general.name" model metadata entry read at load time. */
    public String getModelName() {
        return this.modelName;
    }

    /** Returns the value of the "general.architecture" model metadata entry read at load time. */
    public String getModelType() {
        return this.modelType;
    }
}
