package chat.octet.model;

import chat.octet.model.beans.CompletionResult;
import chat.octet.model.beans.Status;
import chat.octet.model.beans.Token;
import chat.octet.model.enums.FinishReason;
import chat.octet.model.exceptions.DecodeException;
import chat.octet.model.exceptions.GenerationException;
import chat.octet.model.parameters.GenerateParameter;
import java.nio.charset.StandardCharsets;
import java.text.MessageFormat;
import java.util.Arrays;
import java.util.Iterator;
import java.util.NoSuchElementException;
import java.util.Spliterator;
import javax.annotation.Nonnull;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:chat/octet/model/Generator.class */
public class Generator implements Iterable<Token> {
    private static final Logger log = LoggerFactory.getLogger(Generator.class);

    /** Token stream for this generation. It is single-use (see {@link #iterator()}). */
    private final Inference inference;
    /** Caller-supplied status to update once generation ends, or {@code null} for one-shot generation. */
    private final Status chatStatus;

    /**
     * Iterator that samples tokens one at a time through {@link LlamaService}.
     *
     * <p>The prompt is tokenized and batch-decoded in the constructor. Each {@link #next()} call then
     * samples one token, appends it to the status and decides whether generation should stop.
     */
    public static class Inference implements Iterator<Token> {
        /** Initial size of the buffer that receives a token piece from the native layer. */
        private static final int PIECE_BUFFER_SIZE = 64;

        private final GenerateParameter generateParams;
        private final Status status;
        /** Accumulates the bytes of a UTF-8 character that arrives split across several tokens. */
        private final byte[] multiByteTokenBuffer;
        /** Expected byte length of the pending multi-byte character, or 0 when none is pending. */
        private int multiByteTokenLength;
        /** Number of bytes of the pending multi-byte character received so far. */
        private int multiByteTokenIndex;
        private boolean finished;
        private final int maxNewTokenSize;
        private final int contextSize;

        /**
         * Returns the status this inference works on.
         *
         * @return the status this inference created in its constructor
         */
        protected Status getStatus() {
            return this.status;
        }

        /**
         * Tokenizes and decodes the prompt so that sampling can start.
         *
         * @param generateParameter generation settings (sampling, limits, grammar, stopping criteria)
         * @param str               prompt text; a blank prompt decodes just the BOS token
         * @param status            existing status to start from, or {@code null} to start fresh
         * @throws IllegalArgumentException if the prompt tokens do not fit in the context window
         * @throws DecodeException          if batch decoding the prompt fails
         */
        protected Inference(GenerateParameter generateParameter, String str, Status status) {
            this.finished = false;
            this.generateParams = generateParameter;
            this.multiByteTokenBuffer = new byte[8];
            this.contextSize = LlamaService.getContextSize();
            // The caller's Status is not mutated directly; a new Status is built from it
            // (presumably a copy — confirm Status(Status) semantics).
            this.status = status == null ? new Status() : new Status(status);
            // A blank prompt still needs one token to decode, so fall back to BOS.
            int[] promptTokens = StringUtils.isNotBlank(str)
                    ? LlamaService.tokenize(str, true, true)
                    : new int[]{LlamaService.getTokenBOS()};
            if (promptTokens.length >= this.contextSize) {
                throw new IllegalArgumentException(MessageFormat.format("Requested tokens ({0}) exceed context window of {1}.", Integer.valueOf(promptTokens.length), Integer.valueOf(this.contextSize)));
            }
            if (generateParameter.isVerbosePrompt()) {
                log.info("Print prompt text:\n{}", str);
            }
            this.status.appendTokens(promptTokens);
            // A non-positive limit means "use whatever context space remains after the input".
            this.maxNewTokenSize = generateParameter.getMaxNewTokenSize() <= 0
                    ? this.contextSize - this.status.getInputLength()
                    : generateParameter.getMaxNewTokenSize();
            // Grammar load failure is logged, not fatal: generation continues unconstrained.
            if (StringUtils.isNotBlank(generateParameter.getGrammarRules()) && !LlamaService.loadLlamaGrammar(generateParameter.getGrammarRules())) {
                log.error("Grammar rule parsing failed, Please check the grammar rule format.");
            }
            log.debug("Generate starting, input token size: {}, past token size: {}.", Integer.valueOf(promptTokens.length), Integer.valueOf(this.status.getPastTokenSize()));
            decodePrompt();
        }

        /**
         * Creates an inference with a fresh status.
         *
         * @param generateParameter generation settings
         * @param str               prompt text
         */
        protected Inference(GenerateParameter generateParameter, String str) {
            this(generateParameter, str, null);
        }

        /**
         * Batch-decodes the not-yet-decoded input tokens and advances the past-token counter.
         *
         * @throws DecodeException if the native decode returns a non-zero code
         */
        private void decodePrompt() {
            int returnCode = LlamaService.batchDecode(this.status.getId(), this.status.getInputIds(), this.status.getInputLength(), this.status.getPastTokenSize());
            if (returnCode != 0) {
                throw new DecodeException(MessageFormat.format("Failed to decode, return code: {0}.", Integer.valueOf(returnCode)));
            }
            int decodedSize = this.status.getInputLength() - this.status.getPastTokenSize();
            this.status.addPastTokensSize(decodedSize);
            log.debug("Batch decode prompt completed, decode token size: {}, sequence id: {}.", Integer.valueOf(decodedSize), Integer.valueOf(this.status.getId()));
        }

        /**
         * Decides whether generation stops after {@code token}, and records the reason on the token.
         *
         * @param token  the token just sampled and appended to the status
         * @param logits the logits the token was sampled from
         * @return {@code true} if generation must stop
         */
        private boolean breakOrContinue(Token token, float[] logits) {
            if (token.getId() == LlamaService.getTokenEOS()) {
                token.updateFinishReason(FinishReason.FINISHED);
                return true;
            }
            if (this.generateParams.getStoppingCriteriaList() != null && this.generateParams.getStoppingCriteriaList().criteria(this.status.getInputIds(), logits, new Object[0])) {
                token.updateFinishReason(FinishReason.STOP);
                return true;
            }
            if (this.status.getInputLength() >= this.contextSize) {
                token.updateFinishReason(FinishReason.TRUNCATED);
                log.warn("Context size has been exceeded. Truncate and reset the context cache, sequence id: {}.", Integer.valueOf(this.status.getId()));
                return true;
            }
            if (this.status.getGenerateTokens().size() < this.maxNewTokenSize) {
                return false;
            }
            token.updateFinishReason(FinishReason.LENGTH);
            return true;
        }

        /**
         * Converts a token id to text, buffering bytes of UTF-8 characters that span several tokens.
         *
         * @param tokenId the sampled token id
         * @return the decoded text, or {@code ""} while a multi-byte character is still incomplete
         */
        private String tokenToText(int tokenId) {
            byte[] piece = new byte[PIECE_BUFFER_SIZE];
            int length = LlamaService.tokenToPiece(tokenId, piece, piece.length);
            if (length < 0) {
                // NOTE(review): assumes the llama.cpp llama_token_to_piece convention, where a
                // negative result is the negated required buffer size — confirm for LlamaService.
                piece = new byte[-length];
                length = LlamaService.tokenToPiece(tokenId, piece, piece.length);
                if (length < 0) {
                    return "";
                }
            }
            byte first = piece[0];
            // Anything other than a single non-ASCII byte (a negative Java byte) is complete text.
            if (length != 1 || first >= 0) {
                return new String(piece, 0, length, StandardCharsets.UTF_8);
            }
            if (this.multiByteTokenLength == 0) {
                this.multiByteTokenLength = TokenDecoder.getUtf8ByteLength(first);
                if (this.multiByteTokenLength <= 0 || this.multiByteTokenLength > this.multiByteTokenBuffer.length) {
                    // An unusable length would overflow the buffer, so emit this byte on its own.
                    resetMultiByteBuffer();
                    return new String(piece, 0, 1, StandardCharsets.UTF_8);
                }
            }
            this.multiByteTokenBuffer[this.multiByteTokenIndex++] = first;
            if (this.multiByteTokenIndex != this.multiByteTokenLength) {
                return "";
            }
            String text = new String(this.multiByteTokenBuffer, 0, this.multiByteTokenLength, StandardCharsets.UTF_8);
            resetMultiByteBuffer();
            return text;
        }

        /** Discards any pending multi-byte character state. */
        private void resetMultiByteBuffer() {
            this.multiByteTokenIndex = 0;
            this.multiByteTokenLength = 0;
            Arrays.fill(this.multiByteTokenBuffer, (byte) 0);
        }

        @Override
        public boolean hasNext() {
            return !this.finished;
        }

        /**
         * Samples the next token, appends it to the status and updates the finished flag.
         *
         * @return the sampled token; its finish reason is set if generation stopped on it
         * @throws NoSuchElementException if generation has already finished
         */
        @Override
        public Token next() {
            if (this.finished) {
                throw new NoSuchElementException("Generation has already finished, sequence id: " + this.status.getId() + ".");
            }
            float[] logits = LlamaService.getLogits(this.status.getLogitsIndex());
            if (this.generateParams.getLogitsProcessorList() != null) {
                logits = this.generateParams.getLogitsProcessorList().processor(this.status.getInputIds(), logits, new Object[0]);
            }
            // Window of recent tokens handed to the sampler with the penalty settings; null when disabled.
            int lastTokensSize = this.generateParams.getLastTokensSize();
            int[] lastTokens = null;
            if (lastTokensSize != 0) {
                lastTokens = this.status.subInputIds(Math.max(0, this.status.getInputLength() - lastTokensSize));
            }
            int tokenId = LlamaService.sampling(logits, lastTokens, lastTokensSize, this.generateParams.getRepeatPenalty(), this.generateParams.getFrequencyPenalty(), this.generateParams.getPresencePenalty(), this.generateParams.isPenalizeNl(), this.generateParams.getMirostatMode().ordinal(), this.generateParams.getMirostatTAU(), this.generateParams.getMirostatETA(), this.generateParams.getTemperature(), this.generateParams.getTopK(), this.generateParams.getTopP(), this.generateParams.getTsf(), this.generateParams.getTypical(), this.status.getId(), this.status.getPastTokenSize());
            Token token = new Token(tokenId, LlamaService.getLlamaTokenType(tokenId), tokenToText(tokenId));
            this.status.appendNextToken(token);
            this.finished = breakOrContinue(token, logits);
            return token;
        }

        /** Resets this inference's status (logged as a cache clear for its sequence id). */
        public void clearCache() {
            this.status.reset();
            log.debug("Cache clear completed, sequence id: {}.", Integer.valueOf(this.status.getId()));
        }
    }

    /**
     * Creates a generator whose final state is copied back into {@code status}.
     *
     * @param generateParameter generation settings
     * @param str               prompt text
     * @param status            chat status to update after generation, or {@code null} for one-shot use
     */
    public Generator(GenerateParameter generateParameter, String str, Status status) {
        this.chatStatus = status;
        this.inference = new Inference(generateParameter, str, status);
    }

    /**
     * Creates a one-shot generator; its status is cleared once generation ends.
     *
     * @param generateParameter generation settings
     * @param str               prompt text
     */
    public Generator(GenerateParameter generateParameter, String str) {
        // Delegate instead of constructing a throwaway instance, which left this one uninitialized.
        this(generateParameter, str, null);
    }

    /**
     * Returns the token iterator. Every call returns the same single-use iterator.
     *
     * @return the underlying inference iterator
     */
    @Override
    @Nonnull
    public Iterator<Token> iterator() {
        return this.inference;
    }

    /**
     * Not supported.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public Spliterator<Token> spliterator() {
        throw new UnsupportedOperationException("Unsupported operation.");
    }

    /**
     * Streams every generated token's text to {@code System.out}, then syncs or clears the status.
     *
     * @throws GenerationException if generation fails
     */
    public void output() {
        try {
            for (Iterator<Token> it = iterator(); it.hasNext(); ) {
                System.out.print(it.next().getText());
            }
        } catch (Exception e) {
            throw new GenerationException("Generate next token error ", e);
        } finally {
            releaseStatus();
        }
    }

    /**
     * Runs generation to completion and collects the text.
     *
     * <p>Like {@link #output()}, this syncs the chat status (or clears the one-shot status) afterwards.
     *
     * @return the generated content and the finish reason of the last token
     */
    public CompletionResult result() {
        StringBuilder content = new StringBuilder();
        FinishReason finishReason = FinishReason.UNKNOWN;
        try {
            while (this.inference.hasNext()) {
                Token token = this.inference.next();
                content.append(token.getText());
                finishReason = token.getFinishReason();
            }
        } finally {
            releaseStatus();
        }
        return CompletionResult.builder().content(content.toString()).finishReason(finishReason).build();
    }

    /** Copies the inference status back to the chat status, or clears it for one-shot generation. */
    private void releaseStatus() {
        if (this.chatStatus != null) {
            this.chatStatus.copyToStatus(this.inference.getStatus());
        } else {
            this.inference.clearCache();
        }
    }
}
