package com.knuddels.jtokkit;

import com.knuddels.jtokkit.api.Encoding;
import com.knuddels.jtokkit.api.EncodingResult;
import com.knuddels.jtokkit.api.GptBytePairEncodingParams;
import com.knuddels.jtokkit.api.IntArrayList;
import java.nio.charset.StandardCharsets;
import java.util.Objects;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import opennlp.tools.parser.Parse;

/* JADX INFO: Access modifiers changed from: package-private */
/* loaded from: input_file:com/knuddels/jtokkit/GptBytePairEncoding.class */
/**
 * Byte-pair {@link Encoding} configured from a {@link GptBytePairEncodingParams}.
 *
 * <p>Input text is split into pieces with the configured regex {@link Pattern}. The UTF-8 bytes
 * of each piece are handed to a {@link TokenEncoder}. Special tokens are checked and decoded via
 * a {@link SpecialEncoder}.
 *
 * <p>NOTE(review): JADX reports that this class and several members were originally less
 * visible (package-private / private). Visibility is left as-is to avoid breaking callers.
 */
public class GptBytePairEncoding implements Encoding {
    /** Encodes regex-split pieces into tokens and decodes tokens back into bytes. */
    final TokenEncoder encoder;
    private final String name;
    /** Regex used to split input text into pieces that are encoded one after another. */
    private final Pattern pattern;
    /** Used to check input for special tokens and to decode tokens. */
    private final SpecialEncoder specialEncoder;

    /**
     * Outcome of an internal encoding pass.
     *
     * <p>Holds the produced tokens, the token count, whether the input was truncated, and the
     * index of the last input character covered by the tokens.
     */
    public static final class InternalResult {
        private final IntArrayList tokens;
        private final boolean truncated;
        private final int tokenCount;
        private final int lastProcessedCharacterIndex;

        /**
         * @param tokens the produced tokens
         * @param tokenCount the number of tokens, or a negative value to use {@code tokens.size()}
         * @param truncated whether part of the input was left unencoded
         * @param lastProcessedCharacterIndex index of the last encoded input character,
         *     {@code -1} when nothing was encoded
         */
        private InternalResult(
                IntArrayList tokens, int tokenCount, boolean truncated, int lastProcessedCharacterIndex) {
            this.tokens = tokens;
            this.truncated = truncated;
            this.tokenCount = tokenCount < 0 ? tokens.size() : tokenCount;
            this.lastProcessedCharacterIndex = lastProcessedCharacterIndex;
        }

        /**
         * Converts this result into the public {@link EncodingResult}.
         *
         * @return the encoding result carrying tokens, truncation flag and last character index
         * @throws IllegalStateException if the recorded token count differs from the token list size
         */
        public EncodingResult toEncodingResult() {
            if (tokens.size() != tokenCount) {
                throw new IllegalStateException("Token count does not match token list size (tokenCount="
                        + tokenCount + ", tokens size=" + tokens.size() + ")");
            }
            return new EncodingResult(tokens, truncated, lastProcessedCharacterIndex);
        }

        /**
         * @return the number of tokens produced by the encoding pass
         */
        public int toTokenCount() {
            return tokenCount;
        }
    }

    /**
     * Creates an encoding from the given parameters.
     *
     * @param params supplies the name, split pattern, token encoder data and special tokens
     */
    public GptBytePairEncoding(GptBytePairEncodingParams params) {
        this.name = params.getName();
        this.pattern = params.getPattern();
        this.encoder = new TokenEncoder(params.getEncoder());
        this.specialEncoder = new SpecialEncoder(params.getSpecialTokensEncoder());
    }

    @Override // com.knuddels.jtokkit.api.Encoding
    public IntArrayList encode(String text) {
        return encode(text, Integer.MAX_VALUE).getTokens();
    }

    @Override // com.knuddels.jtokkit.api.Encoding
    public EncodingResult encode(String text, int maxTokenCount) {
        return encodeInternal(text, maxTokenCount, true).toEncodingResult();
    }

    /**
     * Runs the special-token check and then encodes {@code text} as ordinary text.
     *
     * @param text the input; {@code null} yields an empty result
     * @param maxTokenCount token budget; {@link Integer#MAX_VALUE} means unlimited
     * @param keepEncodings whether tokens are collected into the result list
     *     (presumably only counted when {@code false} — see {@link TokenEncoder})
     */
    private InternalResult encodeInternal(String text, int maxTokenCount, boolean keepEncodings) {
        if (text == null) {
            return new InternalResult(new IntArrayList(0), -1, false, -1);
        }
        // Behaviour on special tokens (e.g. rejecting them) is defined by SpecialEncoder.
        specialEncoder.checkForSpecialTokens(text);
        return encodeOrdinaryInternal(text, maxTokenCount, keepEncodings);
    }

    @Override // com.knuddels.jtokkit.api.Encoding
    public IntArrayList encodeOrdinary(String text) {
        return encodeOrdinary(text, Integer.MAX_VALUE).getTokens();
    }

    @Override // com.knuddels.jtokkit.api.Encoding
    public EncodingResult encodeOrdinary(String text, int maxTokenCount) {
        return encodeOrdinaryInternal(text, maxTokenCount, true).toEncodingResult();
    }

    /**
     * Encodes {@code text} without the special-token check.
     *
     * <p>With a finite budget and kept encodings, trailing tokens are dropped until the decoded
     * tokens form a prefix of {@code text}. This guards against a budget cut that splits a
     * multi-byte UTF-8 character.
     *
     * @param text the input; {@code null} yields an empty result
     * @param maxTokenCount token budget; {@link Integer#MAX_VALUE} means unlimited
     * @param keepEncodings whether tokens are collected into the result list
     * @return the encoding outcome
     */
    private InternalResult encodeOrdinaryInternal(String text, int maxTokenCount, boolean keepEncodings) {
        if (text == null) {
            return new InternalResult(new IntArrayList(0), -1, false, -1);
        }
        IntArrayList out = new IntArrayList();
        int tokenCount = encodeOrdinaryInternal(text, maxTokenCount, keepEncodings, out);

        if (keepEncodings && maxTokenCount != Integer.MAX_VALUE) {
            // Try the longest prefix first. Size 0 decodes to "" and therefore always matches.
            for (int size = out.size(); size >= 0; size--) {
                IntArrayList prefix = copyPrefix(out, size);
                String decoded = decode(prefix);
                if (text.startsWith(decoded)) {
                    boolean truncated = text.length() > decoded.length();
                    return new InternalResult(prefix, -1, truncated, decoded.length() - 1);
                }
            }
        }
        return new InternalResult(out, tokenCount, false, text.length() - 1);
    }

    /**
     * Splits {@code text} with the configured pattern and encodes each piece's UTF-8 bytes
     * until the token budget is reached or no pieces remain.
     *
     * @param text the input, non-null
     * @param maxTokenCount token budget forwarded to the token encoder
     * @param keepEncodings forwarded to the token encoder
     * @param out receives the produced tokens
     * @return the number of tokens reported by the token encoder
     */
    int encodeOrdinaryInternal(String text, int maxTokenCount, boolean keepEncodings, IntArrayList out) {
        int tokenCount = 0;
        // Scratch list shared across pieces; its contents are managed by TokenEncoder.
        IntArrayList scratch = new IntArrayList();
        Matcher matcher = pattern.matcher(text);
        while (tokenCount < maxTokenCount && matcher.find()) {
            byte[] piece = matcher.group().getBytes(StandardCharsets.UTF_8);
            tokenCount += encoder.addTokensAndGetCount(maxTokenCount, keepEncodings, piece, out, scratch);
        }
        return tokenCount;
    }

    /** Returns a new list holding the first {@code size} tokens of {@code tokens}. */
    private static IntArrayList copyPrefix(IntArrayList tokens, int size) {
        IntArrayList prefix = new IntArrayList(size);
        for (int i = 0; i < size; i++) {
            prefix.add(tokens.get(i));
        }
        return prefix;
    }

    @Override // com.knuddels.jtokkit.api.Encoding
    public int countTokens(String text) {
        return encodeInternal(text, Integer.MAX_VALUE, false).toTokenCount();
    }

    @Override // com.knuddels.jtokkit.api.Encoding
    public int countTokensOrdinary(String text) {
        return encodeOrdinaryInternal(text, Integer.MAX_VALUE, false).toTokenCount();
    }

    @Override // com.knuddels.jtokkit.api.Encoding
    public String decode(IntArrayList tokens) {
        return new String(decodeBytes(tokens), StandardCharsets.UTF_8);
    }

    @Override // com.knuddels.jtokkit.api.Encoding
    public byte[] decodeBytes(IntArrayList tokens) {
        // 10 bytes per token is only a sizing guess for the buffer.
        ByteArrayList bytes = new ByteArrayList(10 * tokens.size());
        for (int i = 0; i < tokens.size(); i++) {
            for (byte b : decodeToken(tokens.get(i))) {
                bytes.add(b);
            }
        }
        return bytes.toArray();
    }

    @Override // com.knuddels.jtokkit.api.Encoding
    public String getName() {
        return name;
    }

    /**
     * Decodes a single token into its bytes.
     *
     * @throws NullPointerException if the token is unknown to the token encoder
     */
    private byte[] decodeToken(int token) {
        byte[] decoded = encoder.decodeToken(token, specialEncoder);
        // Supplier overload: the message is built only on failure, not once per decoded token.
        return Objects.requireNonNull(decoded, () -> "Unknown token for decoding: " + token);
    }
}
