package de.datexis.encoder.impl;

import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import de.datexis.common.ObjectSerializer;
import de.datexis.common.Resource;
import de.datexis.common.WordHelpers;
import de.datexis.encoder.Encoder;
import de.datexis.model.Document;
import de.datexis.model.Sentence;
import de.datexis.model.Span;
import de.datexis.model.Token;
import de.datexis.preprocess.LowercasePreprocessor;
import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import java.util.stream.Collectors;
import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable;
import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer;
import org.deeplearning4j.models.embeddings.wordvectors.WordVectors;
import org.deeplearning4j.models.word2vec.VocabWord;
import org.deeplearning4j.models.word2vec.Word2Vec;
import org.deeplearning4j.models.word2vec.wordstore.inmemory.AbstractCache;
import org.deeplearning4j.text.sentenceiterator.SentenceIterator;
import org.deeplearning4j.text.sentenceiterator.SentencePreProcessor;
import org.deeplearning4j.text.tokenization.tokenizer.TokenPreProcess;
import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.nd4j.linalg.primitives.Counter;
import org.nd4j.linalg.primitives.Pair;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

@JsonIgnoreProperties(ignoreUnknown = true)
/* loaded from: input_file:de/datexis/encoder/impl/Word2VecEncoder.class */
public class Word2VecEncoder extends Encoder {
    /** Class-wide SLF4J logger. */
    private static final Logger log = LoggerFactory.getLogger(Word2VecEncoder.class);
    /** File name suffixes that {@link #getModelType(String)} maps to {@link ModelType#TEXT}. */
    private static final Collection<String> FILENAMES_TEXT = Arrays.asList(".txt", ".txt.gz");
    /** File name suffixes that {@link #getModelType(String)} maps to {@link ModelType#BINARY}. */
    private static final Collection<String> FILENAMES_BINARY = Arrays.asList(".bin", ".bin.gz");
    /** File name suffixes that {@link #getModelType(String)} maps to {@link ModelType#DL4J}. */
    private static final Collection<String> FILENAMES_DL4J = Arrays.asList(".zip");
    /**
     * File name suffixes intended for {@link ModelType#GOOGLE}.
     * NOTE(review): identical to FILENAMES_DL4J, which getModelType() checks first, so a ".zip"
     * file name never resolves to GOOGLE.
     */
    private static final Collection<String> FILENAMES_GOOGLE = Arrays.asList(".zip");
    /** The wrapped word vector model; {@code null} until a model has been loaded or trained. */
    private WordVectors vec;
    /** Number of elements of a single word vector, as returned by getEmbeddingVectorSize(). */
    private long length;
    /**
     * Value returned by getName().
     * NOTE(review): never assigned inside this class; presumably populated from outside
     * (e.g. during deserialization) - TODO confirm.
     */
    private String modelName;
    /** Normalizer applied to every word before it is looked up in the model. */
    private TokenPreProcess preprocessor;

    /* loaded from: input_file:de/datexis/encoder/impl/Word2VecEncoder$ModelType.class */
    public enum ModelType {
        TEXT,
        BINARY,
        DL4J,
        GOOGLE
    }

    /* loaded from: input_file:de/datexis/encoder/impl/Word2VecEncoder$SentenceStringIterator.class */
    public class SentenceStringIterator implements SentenceIterator {
        private Iterator<Sentence> it;
        Iterable<Sentence> sentences;
        private SentencePreProcessor spp;

        public SentenceStringIterator(Iterable<Sentence> iterable) {
            this.sentences = iterable;
            reset();
        }

        public String nextSentence() {
            return this.it.next().getText();
        }

        public boolean hasNext() {
            return this.it.hasNext();
        }

        public void reset() {
            this.it = this.sentences.iterator();
        }

        public void finish() {
            this.it.remove();
        }

        public SentencePreProcessor getPreProcessor() {
            return this.spp;
        }

        public void setPreProcessor(SentencePreProcessor sentencePreProcessor) {
            this.spp = sentencePreProcessor;
        }
    }

    /** Creates an encoder that passes "EMB" to the super constructor. */
    public Word2VecEncoder() {
        this("EMB");
    }

    /**
     * Creates an encoder with a {@link LowercasePreprocessor} as token normalizer.
     *
     * @param str value handed to the {@link Encoder} super constructor
     */
    public Word2VecEncoder(String str) {
        super(str);
        this.preprocessor = new LowercasePreprocessor();
    }

    /**
     * Creates a default encoder and loads the model stored at the given resource into it.
     *
     * @param resource location of the model file
     * @return the new encoder; loading errors are logged by {@link #loadModel(Resource)}
     */
    public static Word2VecEncoder load(Resource resource) {
        final Word2VecEncoder encoder = new Word2VecEncoder();
        encoder.loadModel(resource);
        return encoder;
    }

    /**
     * Creates an encoder backed by the model bundled at "encoder/word2vec.txt" inside the JAR.
     *
     * @return a new encoder loaded from the bundled resource
     */
    public static Word2VecEncoder loadDummyEncoder() {
        return load(Resource.fromJAR("encoder/word2vec.txt"));
    }

    /**
     * Loads word vectors from the given resource. The format is chosen from the file name via
     * {@link #getModelType(String)}; unrecognized names are read as text vectors.
     * On success the vector size is recorded, the resource is registered through
     * {@code setModel} and the model is flagged as available. An {@link IOException} is logged
     * (including its stack trace) and leaves the encoder without an available model.
     *
     * @param resource location of the model file
     */
    @Override // de.datexis.annotator.IComponent
    public void loadModel(Resource resource) {
        log.info("Loading Word2Vec model: {} with preprocessor {}", resource.getFileName(), getPreprocessorClass());
        try {
            switch (getModelType(resource.getFileName())) {
                case BINARY:
                    // Close the stream here as well; a second close() is a no-op.
                    try (InputStream in = resource.getInputStream()) {
                        this.vec = loadBinaryModel(in);
                    }
                    break;
                case DL4J:
                case GOOGLE:
                    this.vec = WordVectorSerializer.loadStaticModel(resource.toFile());
                    break;
                case TEXT:
                default:
                    // Guarantees the stream is released whether or not the serializer closes it.
                    try (InputStream in = resource.getInputStream()) {
                        this.vec = WordVectorSerializer.loadTxtVectors(in, false);
                    }
                    break;
            }
            int numWords = this.vec.vocab().numWords();
            // The vector of the first vocabulary entry determines the embedding size.
            this.length = this.vec.getWordVectorMatrix(this.vec.vocab().wordAtIndex(0)).length();
            setModel(resource);
            setModelAvailable(true);
            log.info("Loaded Word2Vec model '{}' with {} vectors of size {}", resource.getFileName(), numWords, this.length);
        } catch (IOException e) {
            // Pass the exception itself so the stack trace is not lost.
            log.error("could not load model " + e.toString(), e);
        }
    }

    /**
     * Saves the model in the {@link ModelType#BINARY} format by delegating to
     * {@link #saveModel(Resource, String, ModelType)}.
     *
     * @param resource directory-like resource that the output files are resolved against
     * @param str base name of the model file (the format suffix is appended by the delegate)
     */
    @Override // de.datexis.annotator.IComponent
    public void saveModel(Resource resource, String str) {
        saveModel(resource, str, ModelType.BINARY);
    }

    /**
     * Writes this encoder's configuration to "config.json" and the word vectors to a file named
     * {@code str} plus a format-specific suffix (".txt.gz", ".bin" or ".zip"), both resolved
     * against {@code resource}. On success the written file is registered through
     * {@code setModel}. {@link ModelType#GOOGLE} cannot be written: an error is logged and the
     * current model reference is left untouched. An {@link IOException} is logged with its
     * stack trace and not rethrown.
     *
     * @param resource directory-like resource that the output files are resolved against
     * @param str base name of the model file, without suffix
     * @param modelType output format; unknown values fall back to the binary format
     */
    public void saveModel(Resource resource, String str, ModelType modelType) {
        try {
            ObjectSerializer.writeJSON(this, resource.resolve("config.json"));
            Resource modelFile;
            switch (modelType) {
                case TEXT:
                    modelFile = resource.resolve(str + ".txt.gz");
                    // try-with-resources guarantees the gzip stream is finished and closed.
                    try (OutputStream out = modelFile.getGZIPOutputStream()) {
                        WordVectorSerializer.writeWordVectors(this.vec, out);
                    }
                    break;
                case DL4J:
                    modelFile = resource.resolve(str + ".zip");
                    try (OutputStream out = modelFile.getOutputStream()) {
                        WordVectorSerializer.writeWord2VecModel(this.vec, out);
                    }
                    break;
                case GOOGLE:
                    log.error("Cannot write Google Model");
                    // Nothing was written, so do not overwrite the model reference with null.
                    return;
                case BINARY:
                default:
                    modelFile = resource.resolve(str + ".bin");
                    try (OutputStream out = modelFile.getOutputStream()) {
                        writeBinaryModel(this.vec, out);
                    }
                    break;
            }
            setModel(modelFile);
        } catch (IOException e) {
            // Log through SLF4J with the throwable instead of printing to stderr.
            log.error("Could not save model: " + e.toString(), e);
        }
    }

    /**
     * Replaces the token normalizer that is applied to words before model lookup and that is
     * handed to the tokenizer factory during training.
     *
     * @param tokenPreProcess the new preprocessor; other methods dereference it without a null
     *     check, so passing {@code null} makes them fail
     */
    public void setPreprocessor(TokenPreProcess tokenPreProcess) {
        this.preprocessor = tokenPreProcess;
    }

    /**
     * Trains a model on all sentences of the given documents using fixed default
     * hyper-parameters and no stop words.
     *
     * @param collection the documents whose sentences form the training corpus
     */
    @Override // de.datexis.encoder.Encoder
    public void trainModel(Collection<Document> collection) {
        final int batchSize = 16;
        final int windowSize = 10;
        final int minWordFrequency = 3;
        final int layerSize = 256;
        final int iterations = 5;
        final int epochs = 1;
        Iterable<Sentence> sentences = (Iterable<Sentence>) collection.stream()
                .flatMap(Document::streamSentences)
                .collect(Collectors.toList());
        trainModel(sentences, batchSize, windowSize, minWordFrequency, layerSize, iterations, epochs, new ArrayList<String>());
    }

    /**
     * Trains a model on the given sentences by wrapping them in a
     * {@link SentenceStringIterator} and delegating to the iterator-based overload.
     *
     * @param iterable the training sentences
     * @param i batch size
     * @param i2 window size
     * @param i3 minimum word frequency
     * @param i4 layer size
     * @param i5 number of iterations
     * @param i6 number of epochs
     * @param list stop words
     */
    public void trainModel(Iterable<Sentence> iterable, int i, int i2, int i3, int i4, int i5, int i6, List<String> list) {
        SentenceIterator sentenceTexts = new SentenceStringIterator(iterable);
        trainModel(sentenceTexts, i, i2, i3, i4, i5, i6, list);
    }

    /**
     * Builds and fits a Word2Vec model on the sentences supplied by the iterator, tokenizing
     * with a {@link DefaultTokenizerFactory} that applies this encoder's preprocessor.
     * After a successful fit the model replaces the current one, the embedding size is set to
     * the layer size and the model is flagged as available, mirroring what
     * {@link #loadModel(Resource)} does after loading.
     *
     * @param sentenceIterator source of training sentences
     * @param i batch size
     * @param i2 window size
     * @param i3 minimum word frequency
     * @param i4 layer size, i.e. the length of the trained word vectors
     * @param i5 number of iterations
     * @param i6 number of epochs
     * @param list stop words
     */
    public void trainModel(SentenceIterator sentenceIterator, int i, int i2, int i3, int i4, int i5, int i6, List<String> list) {
        // NOTE(review): result is discarded; presumably forces ND4J backend initialization - TODO confirm.
        Nd4j.create(1);
        DefaultTokenizerFactory tokenizerFactory = new DefaultTokenizerFactory();
        tokenizerFactory.setTokenPreProcessor(this.preprocessor);
        log.info("Building model....");
        Word2Vec model = new Word2Vec.Builder()
                .batchSize(i)
                .windowSize(i2)
                .minWordFrequency(i3)
                .useAdaGrad(false)
                .layerSize(i4)
                .seed(42L)
                .iterations(i5)
                .epochs(i6)
                .stopWords(list)
                .learningRate(0.025d)
                .minLearningRate(0.001d)
                .negativeSample(10.0d)
                .iterate(sentenceIterator)
                .tokenizerFactory(tokenizerFactory)
                .build();
        log.info("Fitting Word2Vec model....");
        model.fit();
        // Publish the model only after fitting succeeded, so a failed run keeps the old state.
        this.vec = model;
        // Without this, getEmbeddingVectorSize() stays 0 and encode() builds zero-sized vectors.
        this.length = i4;
        setModelAvailable(true);
    }

    /**
     * Derives the model format from a file name by its suffix, ignoring case.
     * Suffixes are checked in the order text, binary, DL4J, Google; names matching none of them
     * are treated as {@link ModelType#TEXT}.
     * NOTE(review): the Google suffix list equals the DL4J list, so ".zip" always yields DL4J.
     *
     * @param str the file name to classify; must not be {@code null}
     * @return the detected model type, never {@code null}
     */
    public static ModelType getModelType(String str) {
        // Locale.ROOT keeps the comparison stable: with e.g. a Turkish default locale "I" would
        // lowercase to a dotless "i" and ".BIN" / ".ZIP" would no longer match.
        String lowerCase = str.toLowerCase(java.util.Locale.ROOT);
        if (FILENAMES_TEXT.stream().anyMatch(lowerCase::endsWith)) {
            return ModelType.TEXT;
        }
        if (FILENAMES_BINARY.stream().anyMatch(lowerCase::endsWith)) {
            return ModelType.BINARY;
        }
        if (FILENAMES_DL4J.stream().anyMatch(lowerCase::endsWith)) {
            return ModelType.DL4J;
        }
        if (FILENAMES_GOOGLE.stream().anyMatch(lowerCase::endsWith)) {
            return ModelType.GOOGLE;
        }
        return ModelType.TEXT;
    }

    /**
     * Returns the runtime class of the configured token preprocessor.
     *
     * @return the preprocessor's class; fails with a NullPointerException if the preprocessor
     *     has been set to {@code null}
     */
    public Class getPreprocessorClass() {
        return this.preprocessor.getClass();
    }

    /**
     * Returns the stored model name.
     * NOTE(review): {@code modelName} is never assigned within this class, so this returns
     * {@code null} unless the field is populated from outside - TODO confirm.
     *
     * @return the value of the {@code modelName} field
     */
    @Override // de.datexis.annotator.AnnotatorComponent, de.datexis.annotator.IComponent
    public String getName() {
        return this.modelName;
    }

    /**
     * Looks up the vector of a single word after normalizing it with the preprocessor.
     *
     * @param str the raw word
     * @return whatever the model returns for the normalized word
     */
    private INDArray getWordVector(String str) {
        final String normalized = this.preprocessor.preProcess(str);
        return this.vec.getWordVectorMatrix(normalized);
    }

    /**
     * Tells whether a word is missing from the model's vocabulary.
     *
     * @param str the raw word; it is normalized with the preprocessor before the check
     * @return {@code true} if the model does not contain the normalized word
     */
    public boolean isUnknown(String str) {
        final String normalized = this.preprocessor.preProcess(str);
        final boolean known = this.vec.hasWord(normalized);
        return !known;
    }

    /**
     * Encodes the text of a span. For a {@link Token} the text is run through the preprocessor
     * first; for any other span the raw text is used. Either way the result is then encoded by
     * {@link #encode(String)}.
     *
     * @param span the span to encode
     * @return the vector produced by {@link #encode(String)}
     */
    @Override // de.datexis.encoder.IEncoder
    public INDArray encode(Span span) {
        String text = span.getText();
        if (span instanceof Token) {
            text = this.preprocessor.preProcess(text);
        }
        return encode(text);
    }

    /**
     * Returns the stored embedding size.
     *
     * @return the value of the {@code length} field; 0 as long as it has not been assigned
     */
    @Override // de.datexis.encoder.IEncoder
    public long getEmbeddingVectorSize() {
        return this.length;
    }

    /**
     * Encodes a phrase as the average of its word vectors. The phrase is split with
     * {@code WordHelpers.splitSpaces}; blank parts are skipped. Words without a vector add
     * nothing to the sum but still count towards the divisor.
     *
     * @param str the phrase to encode
     * @return a vector created as zeros of shape (embedding size, 1) holding the average, or
     *     all zeros if the phrase contains no non-blank words
     */
    @Override // de.datexis.encoder.IEncoder
    public INDArray encode(String str) {
        INDArray sum = Nd4j.zeros(getEmbeddingVectorSize(), 1L);
        int wordCount = 0;
        for (String word : WordHelpers.splitSpaces(str)) {
            if (word.trim().isEmpty()) {
                continue;
            }
            INDArray vector = this.vec.getWordVectorMatrix(this.preprocessor.preProcess(word));
            if (vector != null) {
                // the looked-up vector is transposed before being added to the sum
                sum.addi(vector.transpose());
            }
            wordCount++;
        }
        if (wordCount == 0) {
            return sum;
        }
        return sum.div(Integer.valueOf(wordCount));
    }

    /**
     * Asks the model for the words nearest to the given word.
     *
     * @param str the query word; normalized with the preprocessor before lookup
     * @param i number of neighbours requested from the model
     * @return the words returned by the model's {@code wordsNearest}
     */
    public Collection<String> getNearestNeighbours(String str, int i) {
        final String query = this.preprocessor.preProcess(str);
        return this.vec.wordsNearest(query, i);
    }

    /**
     * Finds the vocabulary words whose encodings are most similar to the given vector.
     * Every vocabulary word is encoded with {@link #encode(String)} and scored by cosine
     * similarity; only the top {@code i} entries are kept.
     *
     * @param iNDArray the query vector
     * @param i number of words to keep
     * @return the kept words as returned by the counter's {@code keySetSorted()}
     */
    public Collection<String> getNearestNeighbours(INDArray iNDArray, int i) {
        Counter<String> similarities = new Counter<>();
        for (String word : this.vec.vocab().words()) {
            double similarity = Transforms.cosineSim(iNDArray, encode(word));
            similarities.incrementCount(word, similarity);
        }
        similarities.keepTopNElements(i);
        return similarities.keySetSorted();
    }

    /**
     * Returns the single vocabulary word most similar to the given vector.
     *
     * @param iNDArray the query vector
     * @return the best match, or "_" if no candidate was found
     */
    public String getNearestNeighbour(INDArray iNDArray) {
        Collection<String> candidates = getNearestNeighbours(iNDArray, 1);
        if (candidates.isEmpty()) {
            return "_";
        }
        return candidates.iterator().next();
    }

    /**
     * Writes all vocabulary words and their vectors to the stream in this class's binary
     * format: for each non-null word a UTF string (see {@link DataOutputStream#writeUTF})
     * followed by the vector as serialized by {@code Nd4j.write}. The stream is buffered,
     * flushed and closed by this method, also when writing fails.
     *
     * @param wordVectors the model to serialize
     * @param outputStream destination; closed on return
     * @throws IOException if writing fails; a failure while closing is attached as a
     *     suppressed exception instead of replacing the original error
     */
    private static void writeBinaryModel(WordVectors wordVectors, OutputStream outputStream) throws IOException {
        int written = 0;
        // Closing the DataOutputStream also closes the buffered and the underlying stream.
        try (DataOutputStream out = new DataOutputStream(new BufferedOutputStream(outputStream))) {
            for (Object word : wordVectors.vocab().words()) {
                if (word == null) {
                    continue;
                }
                INDArray vector = wordVectors.getWordVectorMatrix((String) word);
                // Parameterized logging avoids building the message when TRACE is disabled.
                log.trace("Write: {} (size {})", word, vector.length());
                out.writeUTF((String) word);
                Nd4j.write(vector, out);
                written++;
            }
            out.flush();
        }
        log.info("Wrote " + written + " words with size " + wordVectors.lookupTable().layerSize());
    }

    /**
     * Reads a model in the format produced by {@code writeBinaryModel}: a sequence of records,
     * each a UTF string (the word) followed by a vector readable by {@code Nd4j.read}.
     * Words are added to a fresh vocabulary cache in file order and the vectors are stacked
     * into the lookup table's weight matrix after clearing NaNs. The stream is closed by this
     * method, also when reading fails.
     *
     * @param inputStream source of the binary model; closed on return
     * @return the word vectors assembled from the lookup table and the vocabulary
     * @throws IOException if reading fails or the stream contains no vectors at all
     */
    private static WordVectors loadBinaryModel(InputStream inputStream) throws IOException {
        AbstractCache cache = new AbstractCache.Builder().build();
        List<INDArray> vectors = new ArrayList<>();
        try (BufferedInputStream buffered = new BufferedInputStream(inputStream);
                DataInputStream in = new DataInputStream(buffered)) {
            while (hasMoreData(buffered)) {
                String word = in.readUTF();
                INDArray vector = Nd4j.read(in);
                VocabWord vocabWord = new VocabWord(1.0d, word);
                // The index of a word is its position in the file.
                vocabWord.setIndex(cache.numWords());
                cache.addToken(vocabWord);
                cache.addWordToIndex(vocabWord.getIndex(), word);
                cache.putVocabWord(word);
                vectors.add(vector);
            }
        }
        if (vectors.isEmpty()) {
            // Report an empty file as I/O failure instead of an unchecked IndexOutOfBoundsException.
            throw new IOException("Binary Word2Vec model does not contain any word vectors");
        }
        InMemoryLookupTable lookupTable = new InMemoryLookupTable.Builder().vectorLength(vectors.get(0).columns()).cache(cache).build();
        INDArray syn0 = Nd4j.vstack(vectors);
        Nd4j.clearNans(syn0);
        lookupTable.setSyn0(syn0);
        return WordVectorSerializer.fromPair(Pair.makePair(lookupTable, cache));
    }

    /**
     * Checks for end of stream by peeking one byte and rewinding. Unlike
     * {@link InputStream#available()}, which is only an estimate and may be 0 before the end
     * (or positive at the end) for some stream types, this is exact.
     */
    private static boolean hasMoreData(BufferedInputStream in) throws IOException {
        in.mark(1);
        int next = in.read();
        in.reset();
        return next >= 0;
    }
}
