package com.omega.example.transformer.utils;

import com.omega.common.data.Tensor;
import com.omega.common.utils.JsonUtils;
import com.omega.common.utils.MathUtils;
import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.UncheckedIOException;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/* loaded from: input_file:com/omega/example/transformer/utils/CNTokenizer.class */
/**
 * Character-level tokenizer that builds its vocabulary from a plain-text corpus.
 *
 * <p>Each distinct {@code char} read from the corpus gets an integer id, assigned in order of first
 * appearance. A sample is a window of {@link #time} consecutive characters. The label for each
 * position is the following character, one-hot encoded over {@link #characters} classes.
 */
public class CNTokenizer extends BaseTokenizer {
    /** Path of the text corpus (read as UTF-8). */
    private String dataPath;
    /** Number of characters in the whole corpus. */
    private int dataSize = 0;
    /** Vocabulary size (number of distinct characters). */
    public int characters = 1;
    /** Character -> id mapping; ids are 0..size-1 in order of first appearance. */
    public Map<Character, Integer> dictionary = new HashMap<>();
    /** Whole corpus as one character sequence (line terminators are not included). */
    private Character[] data;
    /** Id -> character lookup (inverse of {@link #dictionary}). */
    public Character[] dictionaryData;
    /** Id -> single-character string lookup (inverse of {@link #dictionary}). */
    public String[] vocab;
    /** Input-encoding selector. NOTE(review): it does not currently change what format() writes. */
    public int inputType = 0;
    /** Window length (positions per sample). */
    public int time;
    /** Number of valid window start positions in the corpus: {@code dataSize - time}. */
    public int number;
    /** Samples per batch. */
    public int batchSize;
    /** Reusable one-hot input tensor filled by {@link #loadByTxt(String)}. */
    public Tensor testInput;
    /** Leading part of the corpus ({@code number * (1 - vailRatio)} characters) used for training. */
    public Character[] trainData;
    /** Characters from the end of {@link #trainData} up to index {@link #number}, for validation. */
    public Character[] vailData;
    /** Fraction of {@link #number} reserved for validation. */
    private float vailRatio = 0.1f;

    /**
     * Loads the corpus, builds the vocabulary and splits train/validation data with
     * {@code inputType = 0}.
     *
     * @param dataPath path of the text corpus (read as UTF-8)
     * @param time window length per sample
     * @param batchSize samples per batch
     * @throws UncheckedIOException if the corpus cannot be read
     * @throws IllegalStateException if the corpus has fewer than {@code time} characters
     */
    public CNTokenizer(String dataPath, int time, int batchSize) {
        this(dataPath, time, batchSize, 0);
    }

    /**
     * Loads the corpus, builds the vocabulary and splits train/validation data.
     *
     * @param dataPath path of the text corpus (read as UTF-8)
     * @param time window length per sample
     * @param batchSize samples per batch
     * @param inputType value stored in {@link #inputType}
     * @throws UncheckedIOException if the corpus cannot be read
     * @throws IllegalStateException if the corpus has fewer than {@code time} characters
     */
    public CNTokenizer(String dataPath, int time, int batchSize, int inputType) {
        this.dataPath = dataPath;
        this.time = time;
        this.inputType = inputType;
        loadDataForTXT();
        this.dataSize = this.data.length;
        this.number = this.dataSize - time;
        this.characters = this.dictionary.size();
        System.out.println("dataSize[" + this.dataSize + "] characters[" + this.characters + "]");
        this.batchSize = batchSize;
        buildData();
    }

    /**
     * Splits the first {@link #number} corpus characters into {@link #trainData} (leading
     * {@code 1 - vailRatio} fraction) and {@link #vailData} (the rest).
     *
     * @throws IllegalStateException if the corpus is shorter than {@link #time}
     */
    public void buildData() {
        if (this.number < 0) {
            throw new IllegalStateException(
                    "corpus has " + this.dataSize + " characters, fewer than time=" + this.time);
        }
        int trainCount = (int) (this.number * (1.0f - this.vailRatio));
        this.trainData = Arrays.copyOfRange(this.data, 0, trainCount);
        this.vailData = Arrays.copyOfRange(this.data, trainCount, this.number);
    }

    /**
     * Reads {@link #dataPath} as UTF-8 and populates the corpus and vocabulary.
     *
     * <p>Characters are appended line by line. {@code readLine()} strips line terminators, so
     * consecutive lines are concatenated without a separator. Characters are UTF-16 code units,
     * so a supplementary character becomes two vocabulary entries (its surrogates).
     *
     * @throws UncheckedIOException if the file cannot be opened or read
     */
    public void loadDataForTXT() {
        List<Character> chars = new ArrayList<>();
        try (FileInputStream in = new FileInputStream(this.dataPath);
                InputStreamReader streamReader = new InputStreamReader(in, StandardCharsets.UTF_8);
                BufferedReader reader = new BufferedReader(streamReader)) {
            String line;
            while ((line = reader.readLine()) != null) {
                for (char c : line.toCharArray()) {
                    chars.add(c);
                    if (!this.dictionary.containsKey(c)) {
                        // Next free id; stays unique even if this method is called more than once.
                        this.dictionary.put(c, this.dictionary.size());
                    }
                }
            }
        } catch (IOException e) {
            throw new UncheckedIOException("failed to read corpus: " + this.dataPath, e);
        }

        int vocabSize = this.dictionary.size();
        this.dictionaryData = new Character[vocabSize];
        this.vocab = new String[vocabSize];
        for (Map.Entry<Character, Integer> entry : this.dictionary.entrySet()) {
            int id = entry.getValue();
            this.dictionaryData[id] = entry.getKey();
            this.vocab[id] = entry.getKey().toString();
        }
        this.data = chars.toArray(new Character[0]);
    }

    /**
     * Returns batches of window start indices from {@code MathUtils.randomInts(number, batchSize)}.
     *
     * <p>NOTE(review): presumably a shuffled split of [0, number) into batches — confirm. If so,
     * starts near {@code number} exceed what {@link #trainData} can serve: {@link #format} reads
     * {@code source[start + time]}, and trainData holds only about 90% of {@code number} characters.
     */
    public int[][] shuffle() {
        return MathUtils.randomInts(this.number, this.batchSize);
    }

    /**
     * Fills {@code tensor}/{@code tensor2} from {@link #trainData} for the given window starts and
     * uploads both to the device.
     */
    public void loadData(int[] iArr, Tensor tensor, Tensor tensor2) {
        fillBatch(iArr, this.trainData, tensor, tensor2);
    }

    /**
     * Fills {@code tensor}/{@code tensor2} from {@link #vailData} for the given window starts and
     * uploads both to the device.
     */
    public void loadDataVail(int[] iArr, Tensor tensor, Tensor tensor2) {
        fillBatch(iArr, this.vailData, tensor, tensor2);
    }

    /** Clears both tensors, encodes every window listed in {@code starts}, then uploads both. */
    private void fillBatch(int[] starts, Character[] source, Tensor input, Tensor label) {
        input.clear();
        label.clear();
        for (int sample = 0; sample < starts.length; sample++) {
            for (int t = 0; t < this.time; t++) {
                format(sample, starts[sample], t, source, input, label);
            }
        }
        input.hostToDevice();
        label.hostToDevice();
    }

    /**
     * Encodes one position of one sample.
     *
     * <p>Writes the id of {@code source[start + offset]} into {@code input.data[sample * time +
     * offset]}. Sets the one-hot slot of the next character, {@code source[start + offset + 1]}, in
     * row {@code sample * time + offset} of {@code label} (row width {@link #characters}).
     *
     * @param sample sample index within the batch
     * @param start window start index in {@code source}
     * @param offset position within the window, in [0, time)
     * @param source character sequence to read from
     * @param input tensor receiving token ids
     * @param label tensor receiving one-hot next-character targets
     */
    public void format(int sample, int start, int offset, Character[] source, Tensor input, Tensor label) {
        int position = sample * this.time + offset;
        int inputId = this.dictionary.get(source[start + offset]);
        int labelId = this.dictionary.get(source[start + offset + 1]);
        input.data[position] = inputId;
        label.data[position * this.characters + labelId] = 1.0f;
    }

    /** Same as {@link #format(int, int, int, Character[], Tensor, Tensor)} over the whole corpus. */
    public void format(int sample, int start, int offset, Tensor input, Tensor label) {
        format(sample, start, offset, this.data, input, label);
    }

    /** Sets the one-hot slot for character {@code c} in row {@code i} of {@code tensor}. */
    public void format(int i, Tensor tensor, char c) {
        tensor.data[(i * this.characters) + this.dictionary.get(c)] = 1.0f;
    }

    /** Loads training batch {@code i} of size {@code i2} using ascending window starts. */
    public void loadData(int i, int i2, Tensor tensor, Tensor tensor2) {
        loadData(getIndexsByAsc(i, i2), tensor, tensor2);
    }

    /** Allocates a label tensor shaped ({@code time * batchSize}, 1, 1, {@code characters}). */
    public Tensor initLabelTensor() {
        return new Tensor(this.time * this.batchSize, 1, 1, this.characters, true);
    }

    /**
     * Returns {@code batchSize} consecutive window starts for batch {@code batchIndex}.
     *
     * <p>If the batch would run past {@link #number}, it is shifted back to end at {@code number},
     * so the last batch overlaps the previous one. NOTE(review): if {@code batchSize > number}, the
     * first start is negative.
     */
    public int[] getIndexsByAsc(int batchIndex, int batchSize) {
        int start = batchIndex * batchSize;
        int end = start + batchSize;
        if (end > this.number) {
            start -= end - this.number;
        }
        int[] starts = new int[batchSize];
        for (int k = 0; k < batchSize; k++) {
            starts[k] = start + k;
        }
        return starts;
    }

    /**
     * Builds {@code batch} stacked {@code len × len} identity matrices, shaped
     * ({@code batch * len}, 1, 1, {@code len}).
     */
    public static Tensor getPositions(int batch, int len) {
        float[] values = new float[batch * len * len];
        for (int b = 0; b < batch; b++) {
            for (int p = 0; p < len; p++) {
                values[(b * len * len) + (p * len) + p] = 1.0f;
            }
        }
        return new Tensor(batch * len, 1, 1, len, values, true);
    }

    /**
     * Builds an (n, c, h, w) tensor whose entries strictly above the diagonal
     * ({@code col > row}) equal {@code value}; all other entries are 0.
     *
     * <p>Rows are laid out with stride {@code w} (the column count). This makes non-square
     * {@code h × w} planes correct; the previous stride of {@code h} was only correct for square planes.
     */
    public static Tensor triu(int n, int c, int h, int w, float value) {
        float[] values = new float[n * c * h * w];
        int planeSize = h * w;
        for (int plane = 0; plane < n * c; plane++) {
            int base = plane * planeSize;
            for (int row = 0; row < h; row++) {
                for (int col = row + 1; col < w; col++) {
                    values[base + row * w + col] = value;
                }
            }
        }
        return new Tensor(n, c, h, w, values, true);
    }

    /**
     * One-hot encodes {@code str} character by character into {@link #testInput}, which is
     * reallocated via {@code Tensor.createTensor} as needed. Uploads it to the device and returns it.
     */
    public Tensor loadByTxt(String str) {
        char[] charArray = str.toCharArray();
        System.out.println(JsonUtils.toJson(charArray));
        this.testInput = Tensor.createTensor(this.testInput, charArray.length, 1, 1, this.characters, true);
        this.testInput.clear();
        for (int i = 0; i < charArray.length; i++) {
            format(i, this.testInput, charArray[i]);
        }
        this.testInput.hostToDevice();
        return this.testInput;
    }
}
