package chat.octet.model;

import chat.octet.model.beans.CompletionResult;
import chat.octet.model.beans.LlamaContextParams;
import chat.octet.model.beans.LlamaModelParams;
import chat.octet.model.beans.Status;
import chat.octet.model.enums.ModelType;
import chat.octet.model.exceptions.ModelException;
import chat.octet.model.parameters.GenerateParameter;
import chat.octet.model.parameters.ModelParameter;
import chat.octet.model.utils.PromptBuilder;
import com.google.common.base.Preconditions;
import com.google.common.collect.Maps;
import java.io.File;
import java.nio.file.Files;
import java.nio.file.LinkOption;
import java.util.Locale;
import java.util.Map;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:chat/octet/model/Model.class */
/**
 * Wrapper around a model loaded through {@link LlamaService}.
 *
 * <p>Constructing an instance loads the model file, creates a context and, when a LoRA path is
 * configured, applies the LoRA adapter. Chat sessions are tracked per user: one {@link Status} is
 * cached for each user id taken from {@link GenerateParameter#getUser()}.
 *
 * <p>Call {@link #close()} (or use try-with-resources) to drop the cached chat sessions and release
 * the resources held by {@link LlamaService}.
 */
public class Model implements AutoCloseable {
    private static final Logger log = LoggerFactory.getLogger(Model.class);
    private final ModelParameter modelParams;
    private final String modelName;
    private final String modelType;
    private final int lastTokensSize;
    /** Chat session state keyed by user id. */
    private final Map<String, Status> chatStatus = Maps.newConcurrentMap();

    /**
     * Creates a model from a model file path, using default values for every other parameter.
     *
     * @param str path of the model file
     */
    public Model(String str) {
        this(ModelParameter.builder().modelPath(str).build());
    }

    /**
     * Creates a model from the given parameters.
     *
     * @param modelParameter model parameters; the model path and model type must not be null
     * @throws NullPointerException if the parameters, the model path or the model type is null
     * @throws ModelException       if the model file or the configured LoRA file does not exist,
     *                              or if applying the LoRA adapter fails
     */
    public Model(ModelParameter modelParameter) {
        Preconditions.checkNotNull(modelParameter, "Model parameters cannot be null");
        Preconditions.checkNotNull(modelParameter.getModelPath(), "Model file path cannot be null");
        if (!Files.exists(new File(modelParameter.getModelPath()).toPath())) {
            throw new ModelException("Model file is not exists, please check the file path");
        }
        Preconditions.checkNotNull(modelParameter.getModelType(), "Model type cannot be null");

        // Validate the LoRA file before anything is loaded, so that a wrong path
        // does not leave an already loaded model and context behind.
        String loraPath = modelParameter.getLoraPath();
        boolean useLora = StringUtils.isNotBlank(loraPath);
        if (useLora && !Files.exists(new File(loraPath).toPath())) {
            throw new ModelException("Lora model file is not exists, please check the file path");
        }

        this.modelParams = modelParameter;
        this.modelName = modelParameter.getModelName();
        this.modelType = modelParameter.getModelType();

        LlamaService.loadLlamaModelFromFile(modelParameter.getModelPath(), getLlamaModelParameters(modelParameter));
        LlamaService.createNewContextWithModel(getLlamaContextParameters(modelParameter));

        // A negative value means "use the context size". The context size is queried only
        // after the context has been created.
        // NOTE(review): assumes LlamaService.getContextSize() reads the created context - confirm.
        int configuredLastTokens = modelParameter.getLastTokensSize();
        this.lastTokensSize = configuredLastTokens < 0 ? LlamaService.getContextSize() : configuredLastTokens;

        if (useLora) {
            int loraStatus = LlamaService.loadLoraModelFromFile(loraPath, modelParameter.getLoraScale(), modelParameter.getLoraBase(), modelParameter.getThreads());
            if (loraStatus != 0) {
                throw new ModelException(String.format("Failed to apply LoRA from lora path: %s to base path: %s", loraPath, modelParameter.getLoraBase()));
            }
        }
        if (modelParameter.isVerbose()) {
            log.info("system info: {}", LlamaService.getSystemInfo());
        }
        log.info("model parameters: {}", modelParameter);
    }

    /**
     * Builds the model loading parameters, starting from the defaults supplied by
     * {@link LlamaService} and overriding them with the values from {@code modelParameter}.
     *
     * @param modelParameter user supplied model parameters
     * @return parameters to pass when loading the model file
     */
    private LlamaModelParams getLlamaModelParameters(ModelParameter modelParameter) {
        LlamaModelParams params = LlamaService.getLlamaModelDefaultParams();
        params.gpuLayers = modelParameter.getGpuLayers();
        params.splitMode = modelParameter.getSplitMode();
        params.vocabOnly = modelParameter.isVocabOnly();
        // mmap is only switched on when no LoRA path is configured; otherwise the default is kept.
        if ((StringUtils.isBlank(modelParameter.getLoraPath()) && modelParameter.isMmap()) && LlamaService.isMmapSupported()) {
            params.mmap = true;
        }
        if (modelParameter.isMlock() && LlamaService.isMlockSupported()) {
            params.mlock = true;
        }
        // Optional settings: the defaults are kept when they are not configured.
        if (modelParameter.getMainGpu() != null) {
            params.mainGpu = modelParameter.getMainGpu().intValue();
        }
        if (modelParameter.getTensorSplit() != null) {
            params.tensorSplit = modelParameter.getTensorSplit();
        }
        return params;
    }

    /**
     * Builds the context parameters, starting from the defaults supplied by
     * {@link LlamaService} and overriding them with the values from {@code modelParameter}.
     *
     * @param modelParameter user supplied model parameters
     * @return parameters to pass when creating the context
     */
    private LlamaContextParams getLlamaContextParameters(ModelParameter modelParameter) {
        LlamaContextParams params = LlamaService.getLlamaContextDefaultParams();
        params.seed = modelParameter.getSeed();
        params.ctx = modelParameter.getContextSize();
        params.batch = modelParameter.getBatchSize();
        params.threads = modelParameter.getThreads();
        // -1 means "same as the generation thread count".
        params.threadsBatch = modelParameter.getThreadsBatch() == -1 ? modelParameter.getThreads() : modelParameter.getThreadsBatch();
        params.ropeScalingType = modelParameter.getRopeScalingType();
        params.yarnExtFactor = modelParameter.getYarnExtFactor();
        params.yarnAttnFactor = modelParameter.getYarnAttnFactor();
        params.yarnBetaFast = modelParameter.getYarnBetaFast();
        params.yarnBetaSlow = modelParameter.getYarnBetaSlow();
        params.yarnOrigCtx = modelParameter.getYarnOrigCtx();
        params.ropeFreqBase = modelParameter.getRopeFreqBase();
        params.ropeFreqScale = modelParameter.getRopeFreqScale();
        params.mulMatQ = modelParameter.isMulMatQ();
        params.logitsAll = modelParameter.isLogitsAll();
        params.embedding = modelParameter.isEmbedding();
        params.offloadKqv = modelParameter.isOffloadKqv();
        return params;
    }

    /**
     * Removes the cached chat session of the given user and resets its status.
     * Does nothing when no session is cached for that user.
     *
     * @param str user id
     */
    public void removeChatStatus(String str) {
        // A single remove() avoids the check-then-act window of containsKey() + remove().
        Status removed = this.chatStatus.remove(str);
        if (removed != null) {
            removed.reset();
            log.info("Removed chat session, User: {}.", str);
        }
    }

    /**
     * Removes and resets every cached chat session.
     */
    public void removeAllChatStatus() {
        int size = this.chatStatus.size();
        if (size > 0) {
            this.chatStatus.keySet().forEach(this::removeChatStatus);
            log.info("Removed all chat sessions, size: {}.", size);
        }
    }

    /**
     * Runs a completion for the given text with default generation parameters.
     *
     * @param str input text
     * @return the completion result
     */
    public CompletionResult completions(String str) {
        return completions(GenerateParameter.builder().build(), str);
    }

    /**
     * Runs a completion for the given text.
     *
     * @param generateParameter generation parameters
     * @param str               input text
     * @return the completion result
     */
    public CompletionResult completions(GenerateParameter generateParameter, String str) {
        return generate(generateParameter, str).result();
    }

    /**
     * Creates a generator for the given text with default generation parameters.
     *
     * @param str input text
     * @return a generator for the text
     */
    public Generator generate(String str) {
        return generate(GenerateParameter.builder().build(), str);
    }

    /**
     * Creates a generator for the given text.
     *
     * <p>Note that the last-tokens size of {@code generateParameter} is overwritten with the
     * value configured for this model.
     *
     * @param generateParameter generation parameters, must not be null
     * @param str               input text, must not be null
     * @return a generator for the text
     * @throws NullPointerException if an argument is null
     */
    public Generator generate(GenerateParameter generateParameter, String str) {
        Preconditions.checkNotNull(generateParameter, "Generate parameter cannot be null");
        Preconditions.checkNotNull(str, "Text cannot be null");
        generateParameter.setLastTokensSize(this.lastTokensSize);
        return new Generator(generateParameter, str);
    }

    /**
     * Runs a chat completion for the question, with default generation parameters and no
     * system prompt.
     *
     * @param str question
     * @return the completion result
     */
    public CompletionResult chatCompletions(String str) {
        return chatCompletions(GenerateParameter.builder().build(), null, str);
    }

    /**
     * Runs a chat completion for the question without a system prompt.
     *
     * @param generateParameter generation parameters
     * @param str               question
     * @return the completion result
     */
    public CompletionResult chatCompletions(GenerateParameter generateParameter, String str) {
        return chatCompletions(generateParameter, null, str);
    }

    /**
     * Runs a chat completion for the question.
     *
     * @param generateParameter generation parameters
     * @param str               system prompt, may be null
     * @param str2              question
     * @return the completion result
     */
    public CompletionResult chatCompletions(GenerateParameter generateParameter, String str, String str2) {
        return chat(generateParameter, str, str2).result();
    }

    /**
     * Creates a chat generator for the question, with default generation parameters and no
     * system prompt.
     *
     * @param str question
     * @return a generator for the chat prompt
     */
    public Generator chat(String str) {
        return chat(GenerateParameter.builder().build(), null, str);
    }

    /**
     * Creates a chat generator for the question with default generation parameters.
     *
     * @param str  system prompt, may be null
     * @param str2 question
     * @return a generator for the chat prompt
     */
    public Generator chat(String str, String str2) {
        return chat(GenerateParameter.builder().build(), str, str2);
    }

    /**
     * Creates a chat generator for the question without a system prompt.
     *
     * @param generateParameter generation parameters
     * @param str               question
     * @return a generator for the chat prompt
     */
    public Generator chat(GenerateParameter generateParameter, String str) {
        return chat(generateParameter, null, str);
    }

    /**
     * Creates a chat generator for the question, bound to the chat session of the user named in
     * {@code generateParameter}. A session is created on the first call for a user.
     *
     * <p>The first non-blank system prompt seen for a session is stored as its initial system
     * prompt. When the same prompt is passed again it is left out of the formatted prompt.
     * The last-tokens size of {@code generateParameter} is overwritten with the value configured
     * for this model.
     *
     * @param generateParameter generation parameters, must not be null and must carry a user id
     * @param str               system prompt, may be null
     * @param str2              question, must not be null
     * @return a generator for the formatted chat prompt
     * @throws NullPointerException     if the parameters, the question or the user id is null
     * @throws IllegalArgumentException if the model type is not a {@link ModelType} constant
     */
    public Generator chat(GenerateParameter generateParameter, String str, String str2) {
        Preconditions.checkNotNull(generateParameter, "Generate parameter cannot be null");
        Preconditions.checkNotNull(str2, "Question cannot be null");
        Preconditions.checkNotNull(generateParameter.getUser(), "User id cannot be null");
        generateParameter.setLastTokensSize(this.lastTokensSize);

        String user = generateParameter.getUser();
        Status session = this.chatStatus.get(user);
        if (session == null) {
            // putIfAbsent keeps exactly one session per user when callers race, and the
            // session used below can never be null, even if it is removed concurrently.
            Status created = new Status();
            Status raced = this.chatStatus.putIfAbsent(user, created);
            if (raced == null) {
                session = created;
                log.debug("Create new chat session, User: {} id: {}, chat session cache size: {}.", user, created.getId(), this.chatStatus.size());
            } else {
                session = raced;
            }
        }

        String systemPrompt = str;
        if (StringUtils.isNotBlank(systemPrompt)) {
            if (systemPrompt.equals(session.getInitialSystemPrompt())) {
                // Same prompt as the one stored for the session: do not pass it again.
                systemPrompt = null;
            } else if (StringUtils.isBlank(session.getInitialSystemPrompt())) {
                session.setInitialSystemPrompt(systemPrompt);
            }
        }

        // Locale.ROOT keeps the enum lookup independent of the default locale
        // (for example, the Turkish locale upper-cases "i" to a dotted capital I).
        ModelType type = ModelType.valueOf(this.modelType.toUpperCase(Locale.ROOT));
        return new Generator(generateParameter, PromptBuilder.format(type, systemPrompt, str2), session);
    }

    /**
     * Logs the sampling metrics reported by {@link LlamaService} when the model was created
     * with the verbose flag; does nothing otherwise.
     */
    public void metrics() {
        if (this.modelParams.isVerbose()) {
            // NOTE(review): the meaning of the boolean argument is not visible here - confirm.
            log.info("Metrics: {}", LlamaService.getSamplingMetrics(true).toString());
        }
    }

    /**
     * Removes all chat sessions and releases the resources held by {@link LlamaService}.
     * The release calls are made even when removing the chat sessions fails.
     */
    @Override // java.lang.AutoCloseable
    public void close() {
        try {
            removeAllChatStatus();
        } finally {
            try {
                LlamaService.release();
            } finally {
                LlamaService.llamaBackendFree();
            }
        }
    }

    @Override
    public String toString() {
        return "LlamaModel (modelParams=" + this.modelParams + ')';
    }

    /**
     * @return the parameters this model was created with
     */
    public ModelParameter getModelParams() {
        return this.modelParams;
    }

    /**
     * @return the model name taken from the model parameters
     */
    public String getModelName() {
        return this.modelName;
    }

    /**
     * @return the model type taken from the model parameters
     */
    public String getModelType() {
        return this.modelType;
    }
}
