package org.deeplearning4j.models.embeddings.loader;

import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.EOFException;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;
import java.util.UUID;
import org.apache.commons.io.IOUtils;
import org.apache.commons.io.LineIterator;
import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.models.embeddings.WeightLookupTable;
import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable;
import org.deeplearning4j.models.embeddings.wordvectors.WordVectors;
import org.deeplearning4j.models.embeddings.wordvectors.WordVectorsImpl;
import org.deeplearning4j.models.glove.Glove;
import org.deeplearning4j.models.word2vec.VocabWord;
import org.deeplearning4j.models.word2vec.Word2Vec;
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
import org.deeplearning4j.models.word2vec.wordstore.inmemory.InMemoryLookupCache;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/models/embeddings/loader/WordVectorSerializer.class */
public class WordVectorSerializer {
    /** Buffer size hint, in bytes, for reading tokens from binary models. */
    private static final int MAX_SIZE = 50;

    /** Class-wide logger; final so it cannot be reassigned after class initialisation. */
    private static final Logger log = LoggerFactory.getLogger(WordVectorSerializer.class);

    /* JADX WARN: Finally extract failed */
    /**
     * Loads a model stored in the Google word2vec layout, in either its text or binary form.
     *
     * <p>Both forms start with a header holding the declared word count followed by the vector
     * length. While the vocabulary is being read, each vector is written to its own file inside
     * a scratch directory created under the current working directory; the vectors are read
     * back into the lookup table once every word is known. The scratch files and directory are
     * removed before this method returns, whether it succeeds or throws.
     *
     * @param str path of the model file
     * @param z {@code true} to read the binary form, {@code false} to read the text form
     * @return a {@link Word2Vec} whose vocabulary and lookup table hold the loaded words
     * @throws IOException if the file cannot be read, has no usable header, or a text line
     *     carries more values than the header's vector length
     * @throws IllegalStateException if the scratch directory cannot be created
     */
    public static Word2Vec loadGoogleModel(String str, boolean z) throws IOException {
        File scratchDir = new File("." + UUID.randomUUID().toString());
        if (!scratchDir.mkdirs()) {
            throw new IllegalStateException("Unable to create directory for word vectors");
        }
        List<File> vectorFiles = new ArrayList<>();
        try {
            return z
                    ? loadGoogleBinary(str, scratchDir, vectorFiles)
                    : loadGoogleText(str, scratchDir, vectorFiles);
        } finally {
            // Also runs on failure, so a malformed model does not leave scratch files behind.
            for (File vectorFile : vectorFiles) {
                vectorFile.delete();
            }
            scratchDir.delete();
        }
    }

    /** Reads the text form: a "count size" header, then one "word v1 v2 ..." line per word. */
    private static Word2Vec loadGoogleText(String path, File scratchDir, List<File> vectorFiles)
            throws IOException {
        try (BufferedReader reader = new BufferedReader(
                new InputStreamReader(new FileInputStream(path), StandardCharsets.UTF_8))) {
            String header = reader.readLine();
            if (header == null) {
                throw new IOException("Word vector file has no header line: " + path);
            }
            String[] headerFields = header.split(" ");
            if (headerFields.length < 2) {
                throw new IOException("Malformed word vector header '" + header + "' in " + path);
            }
            int declaredWords = Integer.parseInt(headerFields[0]);
            int layerSize = Integer.parseInt(headerFields[1]);

            InMemoryLookupCache cache = new InMemoryLookupCache();
            String line;
            while ((line = reader.readLine()) != null) {
                String[] fields = line.split(" ");
                String word = fields[0];
                if (word.isEmpty()) {
                    continue;
                }
                if (fields.length - 1 > layerSize) {
                    throw new IOException("Word '" + word + "' has " + (fields.length - 1)
                            + " values but the header declares a vector length of " + layerSize);
                }
                float[] vector = new float[layerSize];
                for (int i = 1; i < fields.length; i++) {
                    vector[i - 1] = Float.parseFloat(fields[i]);
                }
                // Register the file before writing so the caller's cleanup sees it even if
                // the write fails half-way.
                File vectorFile = new File(scratchDir, String.valueOf(cache.numWords()));
                vectorFiles.add(vectorFile);
                writeVector(vector, vectorFile);
                cache.addWordToIndex(cache.numWords(), word);
                cache.addToken(new VocabWord(1.0d, word));
                cache.putVocabWord(word);
            }
            if (vectorFiles.size() != declaredWords) {
                log.warn("Header declares {} words but {} were loaded", declaredWords, vectorFiles.size());
            }

            WeightLookupTable table =
                    new InMemoryLookupTable.Builder().cache(cache).vectorLength(layerSize).build();
            return assembleModel(table, cache, layerSize, vectorFiles);
        }
    }

    /** Reads the binary form: header tokens, then per word a token followed by raw floats. */
    private static Word2Vec loadGoogleBinary(String path, File scratchDir, List<File> vectorFiles)
            throws IOException {
        try (DataInputStream in =
                new DataInputStream(new BufferedInputStream(new FileInputStream(path)))) {
            int declaredWords = Integer.parseInt(readString(in));
            int layerSize = Integer.parseInt(readString(in));

            InMemoryLookupCache cache = new InMemoryLookupCache();
            WeightLookupTable table =
                    new InMemoryLookupTable.Builder().cache(cache).vectorLength(layerSize).build();

            for (int wordIndex = 0; wordIndex < declaredWords; wordIndex++) {
                String word = readString(in);
                if (word.isEmpty()) {
                    continue;
                }
                float[] vector = new float[layerSize];
                double sumOfSquares = 0.0d;
                for (int j = 0; j < layerSize; j++) {
                    float value = readFloat(in);
                    vector[j] = value;
                    sumOfSquares += (double) value * value;
                }
                // Scale to unit length; an all-zero vector is left untouched rather than
                // being turned into NaNs by a division by zero.
                double norm = Math.sqrt(sumOfSquares);
                if (norm > 0.0d) {
                    for (int j = 0; j < layerSize; j++) {
                        vector[j] = (float) (vector[j] / norm);
                    }
                }
                File vectorFile = new File(scratchDir, String.valueOf(wordIndex));
                vectorFiles.add(vectorFile);
                writeVector(vector, vectorFile);
                cache.addWordToIndex(cache.numWords(), word);
                cache.addToken(new VocabWord(1.0d, word));
                cache.putVocabWord(word);
                // Skip the single separator byte that follows each vector.
                in.read();
            }
            return assembleModel(table, cache, layerSize, vectorFiles);
        }
    }

    /** Resets the table, copies every spilled vector into it and wraps the result in a model. */
    private static Word2Vec assembleModel(WeightLookupTable table, InMemoryLookupCache cache,
            int layerSize, List<File> vectorFiles) throws IOException {
        table.resetWeights();
        // Iterate over what was actually read, not over the count announced in the header.
        for (int i = 0; i < vectorFiles.size(); i++) {
            File vectorFile = vectorFiles.get(i);
            table.putVector(cache.wordAtIndex(i), Nd4j.create(readVec(vectorFile, layerSize)));
            vectorFile.delete();
        }
        Word2Vec model = new Word2Vec();
        model.setVocab(cache);
        model.setLookupTable(table);
        return model;
    }

    /**
     * Reads {@code i} floats from a file, as written by {@link DataOutputStream#writeFloat}.
     *
     * @param file file to read from
     * @param i number of floats to read
     * @return an array holding the {@code i} values in file order
     * @throws IOException if the file cannot be opened or ends before {@code i} floats are read
     */
    private static float[] readVec(File file, int i) throws IOException {
        float[] values = new float[i];
        // try-with-resources closes the file even when a read fails part-way through.
        try (DataInputStream in =
                new DataInputStream(new BufferedInputStream(new FileInputStream(file)))) {
            for (int k = 0; k < i; k++) {
                values[k] = in.readFloat();
            }
        }
        return values;
    }

    /**
     * Writes every element of {@code fArr} to {@code file} with
     * {@link DataOutputStream#writeFloat}, replacing any previous content.
     *
     * @param fArr values to write, in order
     * @param file destination file
     * @throws IOException if the file cannot be opened or written
     */
    private static void writeVector(float[] fArr, File file) throws IOException {
        // try-with-resources flushes and closes the file even when a write fails.
        try (DataOutputStream out =
                new DataOutputStream(new BufferedOutputStream(new FileOutputStream(file)))) {
            for (float value : fArr) {
                out.writeFloat(value);
            }
        }
    }

    /**
     * Reads one token from the stream: every byte up to, but not including, the next space
     * ({@code 0x20}) or line feed ({@code 0x0A}). The terminator itself is consumed.
     *
     * <p>The bytes are collected first and decoded once as UTF-8, so a multi-byte character is
     * never split across a buffer boundary and the result does not depend on the platform's
     * default charset.
     *
     * @param dataInputStream stream positioned at the first byte of the token
     * @return the decoded token, possibly empty when the first byte is already a terminator
     * @throws IOException if the stream ends before a terminator is found or cannot be read
     */
    private static String readString(DataInputStream dataInputStream) throws IOException {
        ByteArrayOutputStream token = new ByteArrayOutputStream();
        for (byte b = dataInputStream.readByte(); b != ' ' && b != '\n'; b = dataInputStream.readByte()) {
            token.write(b);
        }
        return new String(token.toByteArray(), StandardCharsets.UTF_8);
    }

    /**
     * Reads the next four bytes from the stream and decodes them with {@link #getFloat(byte[])}.
     *
     * <p>{@link InputStream#read(byte[], int, int)} may return fewer bytes than requested, so
     * the read is repeated until all four bytes have arrived instead of trusting a single call.
     *
     * @param inputStream stream to read from
     * @return the decoded float
     * @throws EOFException if the stream ends before four bytes have been read
     * @throws IOException if the underlying read fails
     */
    public static float readFloat(InputStream inputStream) throws IOException {
        byte[] raw = new byte[4];
        int filled = 0;
        while (filled < raw.length) {
            int count = inputStream.read(raw, filled, raw.length - filled);
            if (count < 0) {
                throw new EOFException("Stream ended after " + filled + " of 4 bytes of a float");
            }
            filled += count;
        }
        return getFloat(raw);
    }

    /**
     * Decodes the first four bytes of {@code bArr} as a float whose bit pattern is stored
     * least-significant byte first.
     *
     * @param bArr array holding at least four bytes; any further bytes are ignored
     * @return the float whose IEEE 754 bit pattern those four bytes spell out
     */
    public static float getFloat(byte[] bArr) {
        int bits = 0;
        // Fold from the most significant byte (index 3) down to the least significant (index 0).
        for (int index = 3; index >= 0; index--) {
            bits = (bits << 8) | (bArr[index] & 0xFF);
        }
        return Float.intBitsToFloat(bits);
    }

    /**
     * Writes the vectors of a lookup table as text, one word per line: the word, a space, and
     * the vector components separated by single spaces. An existing file is overwritten.
     *
     * <p>Rows are visited by index from zero up to the row count of the table's {@code syn0}
     * matrix; an index for which the cache returns no word is skipped.
     *
     * @param inMemoryLookupTable table supplying the vector for each word
     * @param inMemoryLookupCache cache mapping row indices to words
     * @param str path of the file to write
     * @throws IOException if the file cannot be opened or written
     */
    public static void writeWordVectors(InMemoryLookupTable inMemoryLookupTable, InMemoryLookupCache inMemoryLookupCache, String str) throws IOException {
        // try-with-resources flushes and closes the file even when a write fails mid-way.
        try (BufferedWriter writer = new BufferedWriter(new FileWriter(new File(str), false))) {
            for (int row = 0; row < inMemoryLookupTable.getSyn0().rows(); row++) {
                String word = inMemoryLookupCache.wordAtIndex(row);
                if (word == null) {
                    continue;
                }
                INDArray vector = inMemoryLookupTable.vector(word);
                StringBuilder line = new StringBuilder();
                line.append(word).append(' ');
                for (int col = 0; col < vector.length(); col++) {
                    if (col > 0) {
                        line.append(' ');
                    }
                    line.append(vector.getDouble(col));
                }
                line.append('\n');
                writer.write(line.toString());
            }
        }
    }

    /**
     * Writes every word of a model's vocabulary as text, one word per line: the word, a space,
     * and the vector components separated by single spaces. An existing file is overwritten.
     *
     * <p>After the file has been closed, the number of words written and the table's layer size
     * are logged at info level.
     *
     * @param word2Vec model supplying the vocabulary and the vector of each word
     * @param str path of the file to write
     * @throws IOException if the file cannot be opened or written
     */
    public static void writeWordVectors(Word2Vec word2Vec, String str) throws IOException {
        int written = 0;
        // try-with-resources flushes and closes the file even when a write fails mid-way.
        try (BufferedWriter writer = new BufferedWriter(new FileWriter(new File(str), false))) {
            for (String word : word2Vec.vocab().words()) {
                if (word == null) {
                    continue;
                }
                INDArray vector = word2Vec.getWordVectorMatrix(word);
                StringBuilder line = new StringBuilder();
                line.append(word).append(' ');
                for (int col = 0; col < vector.length(); col++) {
                    if (col > 0) {
                        line.append(' ');
                    }
                    line.append(vector.getDouble(col));
                }
                line.append('\n');
                writer.write(line.toString());
                written++;
            }
        }
        // Report the real count; the message previously always said 0.
        log.info("Wrote {} with size of {}", written, word2Vec.lookupTable().layerSize());
    }

    /**
     * Loads a text vector file through {@link #loadTxt(File)} and exposes the result as a
     * {@link WordVectors} instance.
     *
     * @param file text file to read
     * @return word vectors backed by the lookup table and vocabulary read from {@code file}
     * @throws FileNotFoundException if {@code file} cannot be opened
     */
    public static WordVectors loadTxtVectors(File file) throws FileNotFoundException {
        final Pair<WeightLookupTable, VocabCache> tableAndVocab = loadTxt(file);
        final WeightLookupTable table = tableAndVocab.getFirst();
        final VocabCache vocab = tableAndVocab.getSecond();

        final WordVectorsImpl vectors = new WordVectorsImpl();
        vectors.setLookupTable(table);
        vectors.setVocab(vocab);
        return vectors;
    }

    /**
     * Reads a text vector file in which each line holds a word followed by its vector
     * components, all separated by single spaces.
     *
     * <p>Every word is added to a fresh {@link InMemoryLookupCache} in file order and its vector
     * becomes the row of the same index in the table's {@code syn0} matrix. NaN entries are
     * cleared with {@code Nd4j.clearNans} before the matrix is installed. The width of the
     * matrix is taken from the first vector. Blank lines are skipped.
     *
     * @param file text file to read
     * @return the populated lookup table paired with its vocabulary cache
     * @throws FileNotFoundException if {@code file} cannot be opened
     * @throws IllegalArgumentException if the file contains no vectors
     */
    public static Pair<WeightLookupTable, VocabCache> loadTxt(File file) throws FileNotFoundException {
        InMemoryLookupCache cache = new InMemoryLookupCache();
        List<INDArray> vectors = new ArrayList<>();

        LineIterator lines = IOUtils.lineIterator(new BufferedReader(new FileReader(file)));
        try {
            while (lines.hasNext()) {
                String line = lines.nextLine();
                if (line.trim().isEmpty()) {
                    // A blank line carries neither a word nor a vector.
                    continue;
                }
                String[] fields = line.split(" ");
                String word = fields[0];
                cache.addToken(new VocabWord(1.0d, word));
                cache.addWordToIndex(cache.numWords(), word);
                cache.putVocabWord(word);

                INDArray vector = Nd4j.create(Nd4j.createBuffer(fields.length - 1));
                for (int i = 1; i < fields.length; i++) {
                    vector.putScalar(i - 1, Float.parseFloat(fields[i]));
                }
                vectors.add(vector);
            }
        } finally {
            // Release the reader on malformed input as well as on success.
            lines.close();
        }

        if (vectors.isEmpty()) {
            throw new IllegalArgumentException("No word vectors found in " + file);
        }

        int vectorLength = vectors.get(0).columns();
        INDArray syn0 = Nd4j.create(new int[]{vectors.size(), vectorLength});
        for (int row = 0; row < vectors.size(); row++) {
            syn0.putRow(row, vectors.get(row));
        }

        InMemoryLookupTable table = (InMemoryLookupTable) new InMemoryLookupTable.Builder().vectorLength(vectorLength).useAdaGrad(false).cache(cache).build();
        Nd4j.clearNans(syn0);
        table.setSyn0(syn0);
        return new Pair<>(table, cache);
    }

    /**
     * Writes one comma-separated line per vocabulary word of a GloVe model: the components of
     * the word's row in {@code iNDArray}, then the word itself, then a single trailing space.
     *
     * <p>The row is selected by the index the vocabulary cache reports for the word. After the
     * file has been closed, the number of words written and the table's vector length are
     * logged at info level.
     *
     * @param glove model supplying the vocabulary; its vocab must be an
     *     {@link InMemoryLookupCache}
     * @param iNDArray matrix whose rows are written, indexed by word index
     * @param file destination file, overwritten if it exists
     * @throws Exception if the file cannot be opened or written
     */
    public static void writeTsneFormat(Glove glove, INDArray iNDArray, File file) throws Exception {
        InMemoryLookupCache cache = (InMemoryLookupCache) glove.vocab();
        int written = 0;
        // try-with-resources flushes and closes the file even when a write fails mid-way.
        try (BufferedWriter writer = new BufferedWriter(new FileWriter(file))) {
            for (String word : glove.vocab().words()) {
                if (word == null) {
                    continue;
                }
                INDArray row = iNDArray.getRow(cache.wordFor(word).getIndex());
                StringBuilder line = new StringBuilder();
                for (int col = 0; col < row.length(); col++) {
                    if (col > 0) {
                        line.append(',');
                    }
                    line.append(row.getDouble(col));
                }
                line.append(',').append(word).append(' ').append('\n');
                writer.write(line.toString());
                written++;
            }
        }
        // Report the real count; the message previously always said 0.
        log.info("Wrote {} with size of {}", written, glove.lookupTable().getVectorLength());
    }

    /**
     * Writes one comma-separated line per vocabulary word of a word2vec model: the components
     * of the word's row in {@code iNDArray}, then the word itself, then a single trailing space.
     *
     * <p>The row is selected by the index the vocabulary cache reports for the word. After the
     * file has been closed, the number of words written and the table's layer size are logged
     * at info level.
     *
     * @param word2Vec model supplying the vocabulary; its vocab must be an
     *     {@link InMemoryLookupCache}
     * @param iNDArray matrix whose rows are written, indexed by word index
     * @param file destination file, overwritten if it exists
     * @throws Exception if the file cannot be opened or written
     */
    public static void writeTsneFormat(Word2Vec word2Vec, INDArray iNDArray, File file) throws Exception {
        InMemoryLookupCache cache = (InMemoryLookupCache) word2Vec.vocab();
        int written = 0;
        // try-with-resources flushes and closes the file even when a write fails mid-way.
        try (BufferedWriter writer = new BufferedWriter(new FileWriter(file))) {
            for (String word : word2Vec.vocab().words()) {
                if (word == null) {
                    continue;
                }
                INDArray row = iNDArray.getRow(cache.wordFor(word).getIndex());
                StringBuilder line = new StringBuilder();
                for (int col = 0; col < row.length(); col++) {
                    if (col > 0) {
                        line.append(',');
                    }
                    line.append(row.getDouble(col));
                }
                line.append(',').append(word).append(' ').append('\n');
                writer.write(line.toString());
                written++;
            }
        }
        // Report the real count; the message previously always said 0.
        log.info("Wrote {} with size of {}", written, word2Vec.lookupTable().layerSize());
    }
}
