package com.omega.example.transformer.utils;

import com.omega.common.data.Tensor;
import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/* loaded from: input_file:com/omega/example/transformer/utils/ENTokenizer.class */
/**
 * Line-oriented tokenizer for an English text corpus.
 *
 * <p>Each line of the file at {@code dataPath} is normalised (punctuation is padded with or
 * replaced by spaces, whitespace runs are collapsed, text is lower-cased), wrapped as
 * {@code "<sos>" + line + "<eos>"} and split either on single spaces (word mode) or into
 * individual characters (char mode, {@code formatType > 0}). Every distinct piece receives an
 * integer id in {@link #dictionary}; the ids 0, 1 and 2 are reserved for {@code <pad>},
 * {@code <sos>} and {@code <eos>}.
 *
 * <p>NOTE(review): in word mode the markers only become stand-alone tokens when the normalised
 * line starts/ends with a space; in char mode {@code split("")} also breaks the markers
 * themselves into single characters. Both behaviours are kept exactly as they were.
 */
public class ENTokenizer {

    /** Id of {@code <pad>}: its position in {@code specials}, which is the id the vocab builders assign. */
    private static final int PAD_ID = 0;

    /** Number of sentences collected by the load performed in the constructor. */
    public int number = 0;

    /** Batch size passed to the constructor; stored but not read anywhere in this class. */
    private int batchSize = 1;

    /** Regular expressions applied, in order, to every raw line. */
    private final String[] _patterns =
            new String[]{"\\'", "\\\"", "\\.", "<br />", "\\,", "\\(", "\\)", "\\!", "\\?", "\\;", "\\:", "\\s+"};

    /** Replacement for the pattern at the same index in {@code _patterns}. */
    private final String[] _replacements =
            new String[]{" '  ", "", " . ", " ", " , ", " ( ", " ) ", " ! ", " ? ", " ", " ", " "};

    /** Path of the corpus file. */
    private String dataPath;

    /** Token text to token id. */
    public Map<String, Integer> dictionary = new HashMap<>();

    /** Normalised sentences, each wrapped in {@code <sos>}/{@code <eos>}. */
    private List<String> org_tokens = new ArrayList<>();

    /** Sentences split into tokens; only sentences that yield more than one token are kept. */
    public List<String[]> tokens = new ArrayList<>();

    /** Reserved tokens; a token's index in this array is its id. */
    private final String[] specials = new String[]{"<pad>", "<sos>", "<eos>"};

    /** Number of positions written per sentence by the {@code load*} and {@code format*} methods. */
    public int max_len = 256;

    /** Number of entries in {@link #dictionary} after the vocabulary has been built. */
    public int vocab_size;

    /** Token id to token text (inverse of {@link #dictionary}). */
    public String[] vocab;

    /** Tensor handed out (and passed back to {@code Tensor.createTensor}) by {@link #loadByTxt(String)}. */
    public Tensor testInput;

    /** {@code 0} = word-level tokens, any value {@code > 0} = character-level tokens. */
    public int formatType = 0;

    /**
     * Creates a word-level tokenizer and immediately loads the corpus.
     *
     * @param dataPath  path of the text file to read
     * @param maxLen    value stored in {@link #max_len}
     * @param batchSize batch size (stored only)
     */
    public ENTokenizer(String dataPath, int maxLen, int batchSize) {
        this(dataPath, maxLen, batchSize, 0);
    }

    /**
     * Creates a tokenizer, loads the corpus and prints the number of collected sentences.
     *
     * @param dataPath   path of the text file to read
     * @param maxLen     value stored in {@link #max_len}
     * @param batchSize  batch size (stored only)
     * @param formatType {@code > 0} selects character-level tokens, otherwise word-level tokens
     */
    public ENTokenizer(String dataPath, int maxLen, int batchSize, int formatType) {
        this.dataPath = dataPath;
        this.max_len = maxLen;
        this.batchSize = batchSize;
        this.formatType = formatType;
        if (formatType > 0) {
            loadDataForTXTByChar();
        } else {
            loadDataForTXT();
        }
        this.number = this.org_tokens.size();
        System.out.println(this.number);
    }

    /**
     * Reads the corpus and builds the word-level vocabulary.
     *
     * <p>Sentences are appended to the existing lists, so calling this on an already loaded
     * instance duplicates the data. I/O failures are printed and otherwise ignored; in that
     * case the vocabulary is not (re)built.
     */
    public void loadDataForTXT() {
        load(false);
    }

    /**
     * Reads the corpus and builds the character-level vocabulary.
     *
     * <p>Same append and error-handling behaviour as {@link #loadDataForTXT()}.
     */
    public void loadDataForTXTByChar() {
        load(true);
    }

    /** Reads and normalises every line of the corpus, then builds the requested vocabulary. */
    private void load(boolean byChar) {
        // try-with-resources closes all three streams on every path, including failures.
        try (FileInputStream fileStream = new FileInputStream(this.dataPath);
             InputStreamReader streamReader = new InputStreamReader(fileStream);
             BufferedReader reader = new BufferedReader(streamReader)) {
            String line;
            while ((line = reader.readLine()) != null) {
                String sentence = normalize(line);
                // A line that collapses to a single space carries no text.
                if (!sentence.equals(" ")) {
                    this.org_tokens.add("<sos>" + sentence + "<eos>");
                }
            }
            if (byChar) {
                buildVocabByChar();
            } else {
                buildVocab();
            }
        } catch (IOException e) {
            // Best effort, as before: report the problem and leave the tokenizer as it is.
            e.printStackTrace();
        }
    }

    /** Applies every pattern/replacement pair in order and lower-cases the result. */
    private String normalize(String line) {
        String text = line;
        for (int p = 0; p < this._patterns.length; p++) {
            text = text.replaceAll(this._patterns[p], this._replacements[p]);
        }
        return text.toLowerCase();
    }

    /** Builds {@link #tokens}, {@link #dictionary} and {@link #vocab} by splitting sentences on single spaces. */
    public void buildVocab() {
        buildVocabulary(" ");
    }

    /** Builds {@link #tokens}, {@link #dictionary} and {@link #vocab} by splitting sentences into characters. */
    public void buildVocabByChar() {
        buildVocabulary("");
    }

    /** Splits every collected sentence with {@code separatorRegex} and assigns ids to unseen tokens. */
    private void buildVocabulary(String separatorRegex) {
        for (int s = 0; s < this.specials.length; s++) {
            this.dictionary.put(this.specials[s], s);
        }
        // Equals specials.length on a fresh instance; continuing from the current size keeps
        // ids unique if the dictionary already holds entries from an earlier build.
        int nextId = this.dictionary.size();
        for (String sentence : this.org_tokens) {
            String[] pieces = sentence.split(separatorRegex);
            if (pieces.length <= 1) {
                continue;
            }
            this.tokens.add(pieces);
            for (String piece : pieces) {
                if (!piece.equals("") && !this.dictionary.containsKey(piece)) {
                    this.dictionary.put(piece, nextId);
                    nextId++;
                }
            }
        }
        this.vocab_size = this.dictionary.size();
        this.vocab = new String[this.vocab_size];
        for (Map.Entry<String, Integer> entry : this.dictionary.entrySet()) {
            this.vocab[entry.getValue()] = entry.getKey();
        }
    }

    /** Returns the id of {@code token}, failing with a descriptive message when it is unknown. */
    private int idOf(String token) {
        Integer id = this.dictionary.get(token);
        if (id == null) {
            throw new IllegalArgumentException("Token not in vocabulary: \"" + token + "\"");
        }
        return id;
    }

    /**
     * One-hot encodes the characters of {@code str} into {@link #testInput}.
     *
     * <p>The text is lower-cased and written one character per position; positions past the end
     * of the text are filled with {@code <pad>}. Characters beyond {@link #max_len} are ignored.
     * The lookup is per character, so this presumably targets a character-level vocabulary
     * (TODO confirm against callers).
     *
     * @param str text to encode
     * @return {@link #testInput}, obtained from {@code Tensor.createTensor(testInput, max_len, 1, 1, vocab_size, true)}
     * @throws IllegalArgumentException if a character is not in the vocabulary
     */
    public Tensor loadByTxt(String str) {
        char[] chars = str.toLowerCase().toCharArray();
        this.testInput = Tensor.createTensor(this.testInput, this.max_len, 1, 1, this.vocab_size, true);
        this.testInput.clear();
        for (int pos = 0; pos < this.max_len; pos++) {
            // The dictionary key is lower-case "<pad>" (id PAD_ID); the former "<PAD>" lookup
            // always returned null and failed with a NullPointerException.
            int id = pos < chars.length ? idOf(String.valueOf(chars[pos])) : PAD_ID;
            this.testInput.data[(pos * this.vocab_size) + id] = 1.0f;
        }
        this.testInput.hostToDevice();
        return this.testInput;
    }

    /**
     * Fills a batch with one-hot inputs and one-hot next-token targets.
     *
     * @param iArr    indices into {@link #tokens}, one per batch row
     * @param tensor  input tensor; cleared, written via {@link #format}, then {@code hostToDevice()} is called
     * @param tensor2 target tensor; handled the same way
     */
    public void loadData(int[] iArr, Tensor tensor, Tensor tensor2) {
        tensor.clear();
        tensor2.clear();
        for (int row = 0; row < iArr.length; row++) {
            String[] sentence = this.tokens.get(iArr[row]);
            for (int pos = 0; pos < this.max_len; pos++) {
                format(row, pos, sentence, tensor, tensor2);
            }
        }
        tensor.hostToDevice();
        tensor2.hostToDevice();
    }

    /**
     * Fills a batch with token-id inputs and one-hot next-token targets.
     *
     * @param iArr    indices into {@link #tokens}, one per batch row
     * @param tensor  input tensor; cleared, written via {@link #formatIdx}, then {@code hostToDevice()} is called
     * @param tensor2 target tensor; handled the same way
     */
    public void loadDataIdx(int[] iArr, Tensor tensor, Tensor tensor2) {
        tensor.clear();
        tensor2.clear();
        for (int row = 0; row < iArr.length; row++) {
            String[] sentence = this.tokens.get(iArr[row]);
            for (int pos = 0; pos < this.max_len; pos++) {
                formatIdx(row, pos, sentence, tensor, tensor2);
            }
        }
        tensor.hostToDevice();
        tensor2.hostToDevice();
    }

    /**
     * Writes one position of one batch row as one-hot input and one-hot target.
     *
     * <p>Both tensors are addressed at {@code ((i * max_len) + i2) * vocab_size + id}. The input
     * is token {@code i2}, the target is token {@code i2 + 1}; when no next token exists both
     * are set to {@code <pad>}.
     *
     * @param i       batch row
     * @param i2      position within the sentence
     * @param strArr  tokens of the sentence
     * @param tensor  input tensor
     * @param tensor2 target tensor
     * @throws IllegalArgumentException if a token is not in the vocabulary
     */
    public void format(int i, int i2, String[] strArr, Tensor tensor, Tensor tensor2) {
        int base = ((i * this.max_len) + i2) * this.vocab_size;
        if (i2 + 1 >= strArr.length) {
            tensor.data[base + PAD_ID] = 1.0f;
            tensor2.data[base + PAD_ID] = 1.0f;
            return;
        }
        tensor.data[base + idOf(strArr[i2])] = 1.0f;
        tensor2.data[base + idOf(strArr[i2 + 1])] = 1.0f;
    }

    /**
     * Writes one position of one batch row as a token id (input) and one-hot target.
     *
     * <p>The input tensor is addressed at {@code (i * max_len) + i2}, the target tensor at
     * {@code ((i * max_len) + i2) * vocab_size + id}. When no next token exists both are set to
     * {@code <pad>}.
     *
     * @param i       batch row
     * @param i2      position within the sentence
     * @param strArr  tokens of the sentence
     * @param tensor  input tensor receiving token ids
     * @param tensor2 target tensor receiving one-hot rows
     * @throws IllegalArgumentException if a token is not in the vocabulary
     */
    public void formatIdx(int i, int i2, String[] strArr, Tensor tensor, Tensor tensor2) {
        int slot = (i * this.max_len) + i2;
        if (i2 + 1 >= strArr.length) {
            tensor.data[slot] = PAD_ID;
            tensor2.data[(slot * this.vocab_size) + PAD_ID] = 1.0f;
            return;
        }
        tensor.data[slot] = idOf(strArr[i2]);
        tensor2.data[(slot * this.vocab_size) + idOf(strArr[i2 + 1])] = 1.0f;
    }

    /**
     * Builds {@code i} stacked {@code i2 x i2} identity matrices (a one-hot row per position).
     *
     * @param i  number of stacked matrices
     * @param i2 side length of each matrix
     * @return {@code new Tensor(i * i2, 1, 1, i2, data, true)}
     */
    public static Tensor getPositions(int i, int i2) {
        float[] data = new float[i * i2 * i2];
        for (int b = 0; b < i; b++) {
            int blockStart = b * i2 * i2;
            for (int t = 0; t < i2; t++) {
                data[blockStart + (t * i2) + t] = 1.0f;
            }
        }
        return new Tensor(i * i2, 1, 1, i2, data, true);
    }

    /**
     * Builds a strictly-upper-triangular mask: element {@code (row, col)} of every
     * {@code i3 x i4} matrix is {@code f} when {@code row < col} and {@code 0} otherwise.
     *
     * @param i  size of the first dimension
     * @param i2 size of the second dimension
     * @param i3 rows per matrix
     * @param i4 columns per matrix
     * @param f  value written above the diagonal
     * @return {@code new Tensor(i, i2, i3, i4, data, true)}
     */
    public static Tensor triu(int i, int i2, int i3, int i4, float f) {
        float[] data = new float[i * i2 * i3 * i4];
        for (int n = 0; n < i; n++) {
            for (int c = 0; c < i2; c++) {
                int matrixStart = (n * i2 * i3 * i4) + (c * i3 * i4);
                for (int row = 0; row < i3; row++) {
                    // Column indices greater than row lie above the diagonal.
                    for (int col = row + 1; col < i4; col++) {
                        // A row is i4 elements long; the former stride of i3 was only correct
                        // for square matrices.
                        data[matrixStart + (row * i4) + col] = f;
                    }
                }
            }
        }
        return new Tensor(i, i2, i3, i4, data, true);
    }

    /** Manual smoke test: loads a local corpus and prints a sample mask and position tensor. */
    public static void main(String[] strArr) {
        // The constructor already loads the file; loading again would append every sentence twice.
        new ENTokenizer("H:\\transformer_dataset\\gpt\\wikitext-2-v1\\wikitext-2\\wiki.train.tokens", 256, 64);
        triu(2, 4, 5, 5, 1.0f).showDM();
        getPositions(2, 4).showDM();
    }
}
