package com.eshore.tensorflow;

import com.eshore.framework.OnEnable;
import com.eshore.framework.StandardComponent;
import com.eshore.framework.StandardProperty;
import com.eshore.nlp.common.CharacterMap;
import com.eshore.utils.CallbackException;
import com.eshore.utils.TextIOUtils;
import com.eshore.writable.WritableInteger;
import com.hankcs.algorithm.AhoCorasickDoubleArrayTrie;
import java.io.IOException;
import java.io.InputStream;
import java.text.Normalizer;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.PriorityQueue;
import java.util.function.Consumer;

@StandardComponent("BERT模型")
/* loaded from: input_file:com/eshore/tensorflow/BertTokenizer.class */
public class BertTokenizer {

    // Maximum number of word pieces per sequence. The Tokenizer built in onEnable() allocates
    // this many Span slots and this many + 2 id slots. (The identifier's spelling is kept as-is
    // because the getter and the property binding use it.)
    @StandardProperty(name = "单次序列最长长度", description = "单次序列最长长度", defaultValue = "510")
    private int sequeceLength;

    // When true, process() lower-cases each non-Chinese character and strips combining marks
    // from it before matching it against the vocabulary.
    @StandardProperty(name = "英文转小写", description = "是否英文转小写", defaultValue = "true")
    private boolean toLowerCase;

    // Source of the vocabulary text: onEnable() hands it a callback that receives the
    // InputStream and reads the vocabulary from it.
    @StandardProperty(name = "text of vocabulary", description = "text of vocabulary", defaultValue = "vocab.txt")
    private Consumer<Consumer<InputStream>> vocabText;
    // Token text -> id, filled by onEnable().
    private Map<String, Integer> vocab;
    // Reusable tokenizer state; created in onEnable() and reset at the start of every process() call.
    protected Tokenizer tokenizer;
    // Id -> token text (inverse of vocab), filled by onEnable().
    private String[] vocabulary;
    // Trie over vocabulary entries with any leading "##" removed. Each value is a Span whose
    // start holds the id of the plain entry and whose end holds the id of the "##" entry
    // (-1 where that form is absent).
    private AhoCorasickDoubleArrayTrie<Span> vocabTrie = new AhoCorasickDoubleArrayTrie<>();
    // Per-character mapping applied by normalize() and process().
    private CharacterMap characterMap = new CharacterMap();

    /* JADX INFO: Access modifiers changed from: protected */
    /* loaded from: input_file:com/eshore/tensorflow/BertTokenizer$Span.class */
    /**
     * A mutable pair of ints.
     *
     * <p>The tokenizer uses it as a half-open source range {@code [start, end)} for each emitted
     * word piece; the vocabulary trie reuses it to hold two vocabulary ids (plain form in
     * {@code start}, "##" form in {@code end}).
     */
    public static class Span {
        /** First value of the pair (range start, or plain-form id). */
        public int start;
        /** Second value of the pair (range end, or "##"-form id). */
        public int end;

        /** Creates a span with both values at zero. */
        protected Span() {
        }

        /** Renders the pair as {@code [start-end)}. */
        @Override
        public String toString() {
            StringBuilder text = new StringBuilder();
            text.append('[').append(this.start);
            text.append('-').append(this.end);
            text.append(')');
            return text.toString();
        }
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /* loaded from: input_file:com/eshore/tensorflow/BertTokenizer$Tokenizer.class */
    /**
     * Mutable, reusable tokenization state for one {@link BertTokenizer#process} call.
     *
     * <p>Callers feed characters through {@link #put(char)} (single-character tokens) and
     * {@link #append(char)} (characters of a word that is later split by {@link #flush()}),
     * and close each segment with {@link #end()}. Ids are written into {@link #inputIDs}
     * starting at index 1; index 0 holds "[CLS]".
     *
     * <p>No bounds checks are performed: emitting more pieces than the capacity given to the
     * constructor fails with an {@link ArrayIndexOutOfBoundsException}.
     */
    public class Tokenizer {
        /**
         * Per-position mask buffer, same length as {@link #inputIDs}.
         * NOTE(review): the constructor fills it with 1.0f, {@link #reset()} zeroes it and nothing
         * in this class writes to it afterwards - presumably a collaborator fills it; confirm.
         */
        float[] inputMask;
        /** Vocabulary ids of the emitted tokens, "[CLS]" first. */
        int[] inputIDs;
        /** Segment id of each position in {@link #inputIDs}. */
        int[] segmentIDs;
        /** Source range of each emitted word piece; entry k belongs to {@code inputIDs[k + 1]}. */
        public Span[] wordPieces;
        /** Characters of the word currently being collected, not yet split into pieces. */
        StringBuilder buffer;
        /** Index of the last slot written in {@link #inputIDs}. */
        int currentIndex;
        /** Running source offset used to fill {@link #wordPieces}. */
        int sourceIndex;
        /** Segment id assigned to tokens emitted from now on. */
        int segmentID;

        /**
         * Allocates all buffers.
         *
         * @param capacity number of word-piece slots; the id, mask and segment buffers get two
         *                 extra slots
         */
        Tokenizer(int capacity) {
            this.inputIDs = new int[capacity + 2];
            this.inputMask = new float[this.inputIDs.length];
            Arrays.fill(this.inputMask, 1.0f);
            this.segmentIDs = new int[this.inputIDs.length];
            this.wordPieces = new Span[capacity];
            for (int slot = 0; slot < this.wordPieces.length; slot++) {
                this.wordPieces[slot] = new Span();
            }
            this.buffer = new StringBuilder(capacity);
        }

        /** @return number of id slots written so far, i.e. {@code currentIndex + 1} */
        public int getLength() {
            return this.currentIndex + 1;
        }

        /** Clears every buffer and counter and writes "[CLS]" into slot 0. */
        void reset() {
            Arrays.fill(this.inputIDs, 0);
            Arrays.fill(this.segmentIDs, 0);
            Arrays.fill(this.inputMask, 0.0f);
            this.inputIDs[0] = BertTokenizer.this.vocab.get("[CLS]").intValue();
            this.currentIndex = 0;
            this.sourceIndex = 0;
            this.segmentID = 0;
            this.segmentIDs[0] = 0;
            this.buffer.setLength(0);
        }

        /**
         * Splits the buffered word into vocabulary pieces and emits them, then empties the buffer.
         *
         * <p>All trie hits are visited in order of start position, longest first. A hit is taken
         * when it starts at or after the end of the previously taken hit and the trie value has
         * an id for the required form: the plain id ({@code Span.start}) while nothing has been
         * taken yet, the "##" id ({@code Span.end}) afterwards. Characters not covered by a taken
         * hit are emitted as one "[UNK]" each.
         */
        public void flush() {
            if (this.buffer.length() == 0) {
                return;
            }
            List<AhoCorasickDoubleArrayTrie.Hit<Span>> hits = BertTokenizer.this.vocabTrie.parseText(this.buffer);
            // Typed queue: with raw types the comparator's parameters are Object and the field
            // accesses below do not compile.
            PriorityQueue<AhoCorasickDoubleArrayTrie.Hit<Span>> ordered = new PriorityQueue<>((left, right) -> {
                int byBegin = Integer.compare(left.begin, right.begin);
                // Same start: longer match first.
                return byBegin != 0 ? byBegin : Integer.compare(right.end, left.end);
            });
            ordered.addAll(hits);
            int consumed = 0;
            while (!ordered.isEmpty()) {
                AhoCorasickDoubleArrayTrie.Hit<Span> hit = ordered.poll();
                if (hit.begin < consumed) {
                    // Overlaps a piece that was already emitted.
                    continue;
                }
                int pieceId = consumed == 0 ? hit.value.start : hit.value.end;
                if (pieceId != -1) {
                    increase(pieceId, consumed, hit.begin, hit.end);
                    consumed = hit.end;
                }
            }
            // Tail of the word that no hit covered.
            for (int rest = consumed; rest < this.buffer.length(); rest++) {
                increase(unknownId());
            }
            this.buffer.setLength(0);
        }

        /** Looks up the id of "[UNK]"; fails with a NullPointerException if the vocabulary lacks it. */
        private int unknownId() {
            return BertTokenizer.this.vocab.get("[UNK]").intValue();
        }

        /**
         * Emits one "[UNK]" per buffer position in {@code [from, begin)}, then the piece
         * {@code id} covering {@code [begin, end)}.
         */
        private void increase(int id, int from, int begin, int end) {
            for (int gap = from; gap < begin; gap++) {
                increase(unknownId());
            }
            increase(id, end - begin);
        }

        /** Emits a token that covers exactly one source position. */
        private void increase(int id) {
            increase(id, 1);
        }

        /**
         * Records the source range of the next piece, then stores its id and segment id in the
         * next free slot and advances the source offset by {@code width}.
         */
        private void increase(int id, int width) {
            Span piece = this.wordPieces[this.currentIndex];
            piece.start = this.sourceIndex;
            piece.end = this.sourceIndex + width;
            this.currentIndex++;
            this.inputIDs[this.currentIndex] = id;
            this.segmentIDs[this.currentIndex] = this.segmentID;
            this.sourceIndex += width;
        }

        /**
         * Flushes any buffered word, then emits {@code c} as a token of its own, using "[UNK]"
         * when the vocabulary has no entry for the single character.
         */
        public void put(char c) {
            flush();
            Integer id = BertTokenizer.this.vocab.get(String.valueOf(c));
            if (id == null) {
                id = BertTokenizer.this.vocab.get("[UNK]");
            }
            increase(id.intValue());
        }

        /**
         * Flushes any buffered word and closes the current segment with "[SEP]". The separator
         * carries the id of the segment it closes; later tokens get the next segment id.
         */
        public void end() {
            flush();
            this.currentIndex++;
            this.inputIDs[this.currentIndex] = BertTokenizer.this.vocab.get("[SEP]").intValue();
            this.segmentIDs[this.currentIndex] = this.segmentID;
            this.segmentID++;
        }

        /** Adds a character to the word that the next {@link #flush()} will split. */
        public void append(char c) {
            this.buffer.append(c);
        }

        /** Advances the source offset by one without emitting a token. */
        public void increase() {
            this.sourceIndex++;
        }
    }

    /**
     * Loads the vocabulary and builds the lookup structures and the reusable tokenizer.
     *
     * <p>Each vocabulary line receives the next id from a running counter and is stored in
     * {@code vocab}. Every line except a single Chinese character is also registered for the
     * trie under its text without a leading "##": the id goes into {@code Span.start} for a
     * plain entry and into {@code Span.end} for a "##" entry, with -1 marking an absent form.
     * Single Chinese characters are left out of the trie because {@code process} emits them
     * through a direct vocabulary lookup.
     *
     * <p>NOTE(review): this relies on {@code WritableInteger.increase()} returning the same value
     * the preceding {@code get()} returned - confirm against that class.
     *
     * @throws IOException if reading the vocabulary text fails
     */
    @OnEnable
    private void onEnable() throws IOException {
        Map<String, Span> pieces = new HashMap<>();
        this.vocab = new HashMap<>();
        try {
            this.vocabText.accept(inputStream -> {
                WritableInteger nextId = new WritableInteger();
                try {
                    TextIOUtils.processText(inputStream, line -> {
                        // An empty line still consumes an id (below) so later ids stay aligned,
                        // but it has no first character to inspect and nothing to match.
                        if (!line.isEmpty() && (line.length() > 1 || !isChinese(line.charAt(0)))) {
                            boolean continuation = line.startsWith("##");
                            String key = continuation ? line.substring(2) : line;
                            Span ids = pieces.get(key);
                            if (ids == null) {
                                ids = new Span();
                                ids.start = -1;
                                ids.end = -1;
                                pieces.put(key, ids);
                            }
                            if (continuation) {
                                ids.end = nextId.get();
                            } else {
                                ids.start = nextId.get();
                            }
                        }
                        this.vocab.put(line, Integer.valueOf(nextId.increase()));
                        return true;
                    });
                } catch (IOException e) {
                    throw new CallbackException(e);
                }
            });
            // Size by the highest id rather than by the number of distinct tokens: a repeated
            // line keeps only its last id in the map, so ids can exceed vocab.size().
            int slots = 0;
            for (Integer assigned : this.vocab.values()) {
                slots = Math.max(slots, assigned.intValue() + 1);
            }
            this.vocabulary = new String[slots];
            this.vocab.forEach((token, assigned) -> {
                this.vocabulary[assigned.intValue()] = token;
            });
            this.vocabTrie.build(pieces);
            this.tokenizer = new Tokenizer(this.sequeceLength);
        } catch (CallbackException e) {
            throw ((IOException) e.getInnerException(IOException.class));
        }
    }

    /** @return the configured maximum number of word pieces per sequence */
    public int getSequeceLength() {
        return sequeceLength;
    }

    /**
     * Runs the text through the component's character map.
     *
     * @param str text to normalize
     * @return whatever {@code CharacterMap.normalize} produces for {@code str}
     */
    public String normalize(String str) {
        final String normalized = characterMap.normalize(str);
        return normalized;
    }

    /**
     * Tokenizes one text as a single segment and returns the token strings.
     *
     * @param str text to tokenize
     * @return the tokens for every id written, including the leading "[CLS]" and the closing "[SEP]"
     */
    public String[] tokenize(String str) {
        process(str);
        final String[] tokens = id2token();
        return tokens;
    }

    /** @return the token strings for the ids currently held by the tokenizer */
    protected String[] id2token() {
        final Tokenizer state = this.tokenizer;
        return id2token(state.inputIDs, state.getLength());
    }

    /**
     * Maps the first {@code i} ids of {@code iArr} back to their vocabulary strings.
     *
     * @param iArr id buffer
     * @param i    number of leading ids to convert
     * @return a new array of {@code i} token strings
     */
    private String[] id2token(int[] iArr, int i) {
        String[] tokens = new String[i];
        int position = 0;
        while (position < i) {
            tokens[position] = vocabulary[iArr[position]];
            position++;
        }
        return tokens;
    }

    /**
     * Tokenizes the given texts into the shared tokenizer buffers, one segment per text.
     *
     * <p>Each character is first mapped through the character map and then handled as follows:
     * <ul>
     *   <li>0 or a space ends the current word;</li>
     *   <li>a Chinese character becomes a token of its own;</li>
     *   <li>a non-spacing mark is skipped but still advances the source offset;</li>
     *   <li>anything else is optionally case-folded, then emitted on its own if it is
     *       punctuation or appended to the current word otherwise.</li>
     * </ul>
     * Every text is closed with "[SEP]".
     *
     * <p>NOTE(review): spaces do not advance the tokenizer's source offset while skipped marks
     * do, and a mark inside a word advances the offset before that word is flushed. Confirm
     * that consumers of the recorded source ranges expect this.
     *
     * @param strArr texts to tokenize; results are read through the buffer and length getters
     */
    public void process(String... strArr) {
        this.tokenizer.reset();
        for (String text : strArr) {
            for (int pos = 0; pos < text.length(); pos++) {
                char c = this.characterMap.get(text.charAt(pos));
                if (c == 0 || c == ' ') {
                    this.tokenizer.flush();
                } else if (isChinese(c)) {
                    this.tokenizer.put(c);
                } else if (Character.getType(c) == Character.NON_SPACING_MARK) {
                    this.tokenizer.increase();
                } else {
                    if (this.toLowerCase) {
                        c = foldCase(c);
                    }
                    if (isPunctuation(c)) {
                        this.tokenizer.put(c);
                    } else {
                        this.tokenizer.append(c);
                    }
                }
            }
            this.tokenizer.end();
        }
    }

    /**
     * Lower-cases a character and strips its accents: returns the first character of the
     * NFD-decomposed lower-case form that is not a non-spacing mark, or {@code c} itself if
     * there is none.
     */
    private static char foldCase(char c) {
        // Locale.ROOT keeps the mapping independent of the JVM default locale; under a Turkish
        // locale the no-argument toLowerCase() turns 'I' into dotless 'ı'.
        String lowered = String.valueOf(c).toLowerCase(Locale.ROOT);
        String decomposed = Normalizer.normalize(lowered, Normalizer.Form.NFD);
        for (int k = 0; k < decomposed.length(); k++) {
            char candidate = decomposed.charAt(k);
            if (Character.getType(candidate) != Character.NON_SPACING_MARK) {
                return candidate;
            }
        }
        return c;
    }

    /**
     * Tells whether a character is emitted as a stand-alone punctuation token.
     *
     * <p>True for the four ASCII symbol ranges {@code !../}, {@code :..@}, {@code [..`} and
     * {@code {..~}}, and for any character whose Unicode general category is one of the
     * punctuation categories.
     */
    private static boolean isPunctuation(char c) {
        boolean asciiSymbol = (c >= '!' && c <= '/')
                || (c >= ':' && c <= '@')
                || (c >= '[' && c <= '`')
                || (c >= '{' && c <= '~');
        if (asciiSymbol) {
            return true;
        }
        switch (Character.getType(c)) {
            case Character.CONNECTOR_PUNCTUATION:
            case Character.DASH_PUNCTUATION:
            case Character.START_PUNCTUATION:
            case Character.END_PUNCTUATION:
            case Character.INITIAL_QUOTE_PUNCTUATION:
            case Character.FINAL_QUOTE_PUNCTUATION:
            case Character.OTHER_PUNCTUATION:
                return true;
            default:
                return false;
        }
    }

    /**
     * Tells whether a character lies in one of three code-point ranges:
     * U+4E00..U+9FFF, U+3400..U+4DBF or U+F900..U+FAFF.
     */
    private static boolean isChinese(char c) {
        boolean mainRange = c >= 0x4E00 && c <= 0x9FFF;
        boolean secondRange = c >= 0x3400 && c <= 0x4DBF;
        boolean thirdRange = c >= 0xF900 && c <= 0xFAFF;
        return mainRange || secondRange || thirdRange;
    }

    /** @return the number of ids written by the most recent {@code process} call, as reported by the tokenizer */
    public int getLength() {
        final int written = tokenizer.getLength();
        return written;
    }

    /**
     * @return the tokenizer's id buffer itself, not a copy; only the first {@link #getLength()}
     *         entries hold ids from the latest run
     */
    public int[] getInputIdBuffer() {
        final Tokenizer state = tokenizer;
        return state.inputIDs;
    }

    /** @return the tokenizer's segment-id buffer itself, not a copy */
    public int[] getSegmentIdBuffer() {
        final Tokenizer state = tokenizer;
        return state.segmentIDs;
    }

    /**
     * @param i vocabulary id
     * @return the token text stored for that id
     */
    public String getWord(int i) {
        final String token = vocabulary[i];
        return token;
    }

    /**
     * @param str token text
     * @return the vocabulary id of {@code str}
     * @throws NullPointerException if the vocabulary has no entry for {@code str}
     */
    public int getId(String str) {
        final Integer id = vocab.get(str);
        return id.intValue();
    }
}
