package de.datexis.encoder.impl;

import de.datexis.encoder.LookupCacheEncoder;
import de.datexis.model.Document;
import de.datexis.model.Span;
import de.datexis.model.Token;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;
import java.util.regex.Pattern;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:de/datexis/encoder/impl/LetterNGramEncoder.class */
/**
 * Encoder that represents a text by the set of letter n-grams it contains.
 *
 * <p>A text is normalized (see {@link #keepOnlyPrintableChars(String)}), wrapped in
 * {@code '#'} boundary markers and cut into overlapping substrings of length {@code n}.
 * Each n-gram is looked up in the vocabulary inherited from {@link LookupCacheEncoder}.
 *
 * <p>NOTE(review): the no-arg and id constructors leave {@code n} at its default of 0;
 * it looks like callers are expected to invoke {@link #setN(int)} before use — confirm.
 */
public class LetterNGramEncoder extends LookupCacheEncoder {

    /**
     * Matches every character that is not a letter, number, punctuation, math symbol or
     * currency symbol. Compiled once because normalization runs for every encoded token.
     */
    private static final Pattern NON_PRINTABLE =
            Pattern.compile("[^\\p{L}\\p{N}\\p{P}\\p{Sm}\\p{Sc}]");

    /** Length of the generated n-grams. */
    protected int n;

    /** Creates an encoder with the id {@code "TRI"}; {@code n} is left unset (0). */
    public LetterNGramEncoder() {
        super("TRI");
        this.log = LoggerFactory.getLogger(LetterNGramEncoder.class);
    }

    /**
     * Creates an encoder with the given id; {@code n} is left unset (0).
     *
     * @param str id passed to the superclass constructor
     */
    public LetterNGramEncoder(String str) {
        super(str);
        this.log = LoggerFactory.getLogger(LetterNGramEncoder.class);
    }

    /**
     * Creates an encoder with the id {@code "TRI"} and the given n-gram length.
     *
     * @param i n-gram length
     */
    public LetterNGramEncoder(int i) {
        super("TRI");
        this.log = LoggerFactory.getLogger(LetterNGramEncoder.class);
        this.n = i;
    }

    /**
     * @return a display name of the form {@code "<n>-gram Encoder"}
     */
    @Override // de.datexis.annotator.AnnotatorComponent, de.datexis.annotator.IComponent
    public String getName() {
        return this.n + "-gram Encoder";
    }

    /**
     * @return the configured n-gram length
     */
    public int getN() {
        return this.n;
    }

    /**
     * Sets the n-gram length.
     *
     * @param i n-gram length
     * @return this encoder, to allow call chaining
     */
    public LetterNGramEncoder setN(int i) {
        this.n = i;
        return this;
    }

    /**
     * Encodes the text of the given span; delegates to {@link #encode(String)}.
     *
     * @param span span whose text is encoded
     * @return the encoding of {@code span.getText()}
     */
    @Override // de.datexis.encoder.IEncoder
    public INDArray encode(Span span) {
        return encode(span.getText());
    }

    /**
     * Encodes a text as a vector over the n-gram vocabulary.
     *
     * <p>The result is created with {@code getEmbeddingVectorSize()} rows and one column,
     * initialized to zero. For every n-gram of the text whose {@code getIndex} result is
     * non-negative, the entry at that row is set to 1.0. N-grams with a negative index
     * are skipped (presumably those missing from the vocabulary — confirm against
     * {@code getIndex}).
     *
     * @param str text to encode
     * @return the n-gram indicator vector
     */
    @Override // de.datexis.encoder.IEncoder
    public INDArray encode(String str) {
        INDArray vector = Nd4j.zeros(getEmbeddingVectorSize(), 1L);
        for (String ngram : generateNGrams(str)) {
            int index = getIndex(ngram);
            if (index >= 0) {
                vector.put(index, 0, Double.valueOf(1.0d));
            }
        }
        return vector;
    }

    /**
     * Checks whether the text contains at least one n-gram that is not in the vocabulary.
     *
     * @param str text to check
     * @return {@code true} if any generated n-gram is missing from the vocabulary,
     *         {@code false} if all are contained (or no n-gram is generated)
     */
    @Override // de.datexis.encoder.LookupCacheEncoder
    public boolean isUnknown(String str) {
        for (String ngram : generateNGrams(str)) {
            if (!this.vocab.containsWord(ngram)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Normalizes a text for n-gram generation: removes every character that is not a
     * letter, number, punctuation, math symbol or currency symbol, then lower-cases it.
     *
     * <p>Lower-casing uses {@link Locale#ROOT} so that the resulting n-grams do not depend
     * on the default locale of the machine (e.g. the Turkish dotless i), which would make
     * a vocabulary trained on one machine mismatch lookups on another.
     *
     * @param str text to normalize
     * @return the filtered, lower-cased text
     */
    public String keepOnlyPrintableChars(String str) {
        return NON_PRINTABLE.matcher(str).replaceAll("").toLowerCase(Locale.ROOT);
    }

    /**
     * Generates all overlapping letter n-grams of the given length.
     *
     * <p>The text is normalized with {@link #keepOnlyPrintableChars(String)} and wrapped
     * in {@code '#'} markers, e.g. {@code "Cat"} with length 3 yields
     * {@code [#ca, cat, at#]}. If the wrapped text is shorter than the n-gram length,
     * the result is empty.
     *
     * @param str text to split
     * @param i   n-gram length
     * @return the n-grams in order of occurrence, duplicates included
     */
    public List<String> generateNGrams(String str, int i) {
        String padded = "#" + keepOnlyPrintableChars(str) + "#";
        int last = padded.length() - i;
        List<String> ngrams = new ArrayList<>(Math.max(0, last + 1));
        for (int start = 0; start <= last; start++) {
            ngrams.add(padded.substring(start, start + i));
        }
        return ngrams;
    }

    /**
     * Generates letter n-grams using the configured length {@code n}.
     *
     * @param str text to split
     * @return the n-grams in order of occurrence
     * @see #generateNGrams(String, int)
     */
    public List<String> generateNGrams(String str) {
        return generateNGrams(str, this.n);
    }

    /**
     * Decodes a vector back into vocabulary entries: returns {@code getWord(i)} for every
     * position {@code i} whose value is greater than 0.5.
     *
     * @param iNDArray vector indexed like the vocabulary
     * @return the entries at all positions above the 0.5 threshold, in index order
     */
    public List<String> getTrigramsFromProbabilityVector(INDArray iNDArray) {
        List<String> ngrams = new ArrayList<>();
        for (int i = 0; i < iNDArray.length(); i++) {
            if (iNDArray.getDouble(i) > 0.5d) {
                ngrams.add(getWord(i));
            }
        }
        return ngrams;
    }

    /**
     * Trains the n-gram vocabulary, passing 1 as the truncation argument.
     *
     * @param collection documents whose tokens provide the n-grams
     * @see #trainModel(Collection, int)
     */
    @Override // de.datexis.encoder.Encoder
    public void trainModel(Collection<Document> collection) {
        trainModel(collection, 1);
    }

    /**
     * Trains the n-gram vocabulary from the token texts of the given documents.
     *
     * <p>Every n-gram of every token is counted: n-grams already in the vocabulary have
     * their counter incremented, new ones are added. Afterwards the vocabulary is
     * truncated with {@code i}, Huffman codes are updated and the model is marked as
     * available.
     *
     * @param collection documents whose tokens provide the n-grams
     * @param i          argument passed to {@code vocab.truncateVocabulary} (presumably
     *                   the minimum n-gram frequency to keep — confirm)
     */
    public void trainModel(Collection<Document> collection, int i) {
        appendTrainLog("Training " + getName() + " model...");
        setModel(null);
        this.timer.start();
        this.totalWords = 0;
        for (Document document : collection) {
            Iterator<Token> tokens = document.getTokens().iterator();
            while (tokens.hasNext()) {
                for (String ngram : generateNGrams(tokens.next().getText())) {
                    this.totalWords++;
                    if (this.vocab.containsWord(ngram)) {
                        this.vocab.incrementWordCounter(ngram);
                    } else {
                        this.vocab.addWord(ngram);
                    }
                }
            }
        }
        // Remember the vocabulary size before truncation for the training log.
        int numWords = this.vocab.numWords();
        this.vocab.truncateVocabulary(i);
        this.vocab.updateHuffmanCodes();
        this.timer.stop();
        appendTrainLog("trained " + this.vocab.numWords() + " " + this.n + "-grams (" + numWords + " total)", this.timer.getLong());
        setModelAvailable(true);
    }
}
