package de.datexis.encoder.impl;

import com.fasterxml.jackson.annotation.JsonIgnore;
import de.datexis.common.WordHelpers;
import de.datexis.encoder.LookupCacheEncoder;
import de.datexis.model.Document;
import de.datexis.model.Sentence;
import de.datexis.model.Span;
import de.datexis.model.Token;
import de.datexis.preprocess.MinimalLowercaseNewlinePreprocessor;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
import org.apache.commons.math3.util.Pair;
import org.deeplearning4j.models.word2vec.wordstore.VocabularyHolder;
import org.deeplearning4j.text.tokenization.tokenizer.TokenPreProcess;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:de/datexis/encoder/impl/BagOfWordsEncoder.class */
/**
 * Encoder that represents text as a binary bag-of-words column vector: one row per vocabulary
 * word, set to 1 when that word occurs in the input.
 *
 * <p>Every word is normalized by the configured {@link TokenPreProcess} before it is counted
 * during training or looked up in the vocabulary. Stop words of the configured language are
 * excluded from the vocabulary during training.
 */
public class BagOfWordsEncoder extends LookupCacheEncoder {

    /** Threshold constant shared by both sampling-rate formulas below. */
    private static final double SAMPLING_THRESHOLD = 0.001d;

    protected TokenPreProcess preprocessor;
    protected WordHelpers wordHelpers;
    protected WordHelpers.Language language;

    /** Creates an encoder with the default id {@code "BOW"}. */
    public BagOfWordsEncoder() {
        this("BOW");
    }

    /**
     * Creates an encoder with the given id, a {@link MinimalLowercaseNewlinePreprocessor} and an
     * empty vocabulary.
     *
     * @param str the id passed on to the superclass constructor
     */
    public BagOfWordsEncoder(String str) {
        super(str);
        this.preprocessor = new MinimalLowercaseNewlinePreprocessor();
        this.log = LoggerFactory.getLogger(BagOfWordsEncoder.class);
        this.vocab = new VocabularyHolder.Builder().build();
    }

    /**
     * Returns the runtime class of the current preprocessor.
     *
     * @return the class of the object held in {@link #preprocessor}
     */
    public Class<? extends TokenPreProcess> getPreprocessorClass() {
        return this.preprocessor.getClass();
    }

    /**
     * Returns the preprocessor applied to every word before counting or lookup.
     *
     * @return the current preprocessor instance
     */
    @JsonIgnore
    public TokenPreProcess getPreprocessor() {
        return this.preprocessor;
    }

    /**
     * Replaces the preprocessor applied to every word before counting or lookup.
     *
     * @param tokenPreProcess the new preprocessor
     */
    public void setPreprocessor(TokenPreProcess tokenPreProcess) {
        this.preprocessor = tokenPreProcess;
    }

    /** {@inheritDoc} */
    @Override
    public String getName() {
        return "Bag-of-words Encoder";
    }

    /**
     * Trains the vocabulary from the tokens of the given documents, using a minimum word
     * frequency of 1 and English stop words.
     *
     * @param collection the documents whose tokens are counted
     */
    @Override
    public void trainModel(Collection<Document> collection) {
        trainModel(collection, 1, WordHelpers.Language.EN);
    }

    /**
     * Trains the vocabulary from the tokens of the given documents.
     *
     * @param collection the documents whose tokens are counted
     * @param i the minimum frequency passed to the vocabulary truncation after counting
     * @param language the language whose stop words are excluded
     */
    public void trainModel(Collection<Document> collection, int i, WordHelpers.Language language) {
        beginTraining(language);
        for (Document document : collection) {
            // Explicit iterator: only iterator() is relied upon for the token container.
            Iterator<Token> tokens = document.getTokens().iterator();
            while (tokens.hasNext()) {
                countWord(tokens.next().getText(), 0);
            }
        }
        finishTraining(i);
    }

    /**
     * Trains the vocabulary from plain strings, each split into words with
     * {@link WordHelpers#splitSpaces}.
     *
     * @param iterable the input strings
     * @param i the minimum frequency passed to the vocabulary truncation after counting
     * @param i2 the minimum length (after preprocessing) a word needs to enter the vocabulary
     * @param language the language whose stop words are excluded
     */
    public void trainModel(Iterable<String> iterable, int i, int i2, WordHelpers.Language language) {
        beginTraining(language);
        for (String text : iterable) {
            for (String word : WordHelpers.splitSpaces(text)) {
                countWord(word, i2);
            }
        }
        finishTraining(i);
    }

    /** Logs the start of training, resets model and word counter, starts the timer, sets the language. */
    private void beginTraining(WordHelpers.Language language) {
        appendTrainLog("Training " + getName() + " model...");
        setModel(null);
        this.totalWords = 0;
        this.timer.start();
        setLanguage(language);
    }

    /**
     * Preprocesses one raw word and counts it in the vocabulary unless it is empty, a stop word,
     * or shorter than {@code minLength}. Every non-empty word increments {@code totalWords}.
     */
    private void countWord(String rawWord, int minLength) {
        String word = this.preprocessor.preProcess(rawWord);
        if (word.isEmpty()) {
            return;
        }
        this.totalWords++;
        if (this.wordHelpers.isStopWord(word) || word.length() < minLength) {
            return;
        }
        if (this.vocab.containsWord(word)) {
            this.vocab.incrementWordCounter(word);
        } else {
            this.vocab.addWord(word);
        }
    }

    /** Truncates the vocabulary by minimum frequency, stops the timer, logs and marks the model available. */
    private void finishTraining(int minWordFrequency) {
        int wordsBeforeTruncation = this.vocab.numWords();
        this.vocab.truncateVocabulary(minWordFrequency);
        this.vocab.updateHuffmanCodes();
        this.timer.stop();
        appendTrainLog("trained " + this.vocab.numWords() + " words (" + wordsBeforeTruncation + " total)", this.timer.getLong());
        setModelAvailable(true);
    }

    /** Preprocesses {@code str} and delegates to the superclass lookup. */
    @Override
    public boolean isUnknown(String str) {
        return super.isUnknown(this.preprocessor.preProcess(str));
    }

    /** Preprocesses {@code str} and delegates to the superclass lookup. */
    @Override
    public int getIndex(String str) {
        return super.getIndex(this.preprocessor.preProcess(str));
    }

    /** Preprocesses {@code str} and delegates to the superclass lookup. */
    @Override
    public int getFrequency(String str) {
        return super.getFrequency(this.preprocessor.preProcess(str));
    }

    /** Preprocesses {@code str} and delegates to the superclass lookup. */
    @Override
    public double getProbability(String str) {
        return super.getProbability(this.preprocessor.preProcess(str));
    }

    /**
     * Returns the language last set via {@link #setLanguage}.
     *
     * @return the configured language, or {@code null} if none has been set yet
     */
    public WordHelpers.Language getLanguage() {
        return this.language;
    }

    /**
     * Sets the language and creates a matching {@link WordHelpers} instance used for stop-word
     * checks.
     *
     * @param language the language to use
     */
    public void setLanguage(WordHelpers.Language language) {
        this.language = language;
        this.wordHelpers = new WordHelpers(language);
    }

    /**
     * Encodes the texts of the given spans into a single bag-of-words vector.
     *
     * @param iterable the spans whose texts are looked up
     * @return a column vector with 1 at the index of every span text that has a non-negative index
     */
    @Override
    public INDArray encode(Iterable<? extends Span> iterable) {
        INDArray vector = newZeroVector();
        for (Span span : iterable) {
            markWord(vector, span.getText());
        }
        return vector;
    }

    /**
     * Encodes the given words into a single bag-of-words vector.
     *
     * @param strArr the words to look up
     * @return a column vector with 1 at the index of every word that has a non-negative index
     */
    protected INDArray encode(String[] strArr) {
        INDArray vector = newZeroVector();
        for (String word : strArr) {
            markWord(vector, word);
        }
        return vector;
    }

    /** Allocates a zero column vector with one row per embedding dimension. */
    private INDArray newZeroVector() {
        return Nd4j.zeros(getEmbeddingVectorSize(), 1L);
    }

    /** Sets the entry for {@code word} to 1 if {@link #getIndex} yields a non-negative index. */
    private void markWord(INDArray vector, String word) {
        int index = getIndex(word);
        if (index >= 0) {
            vector.put(index, 0, Double.valueOf(1.0d));
        }
    }

    /**
     * Encodes a span: a {@link Token} is encoded on its own, a {@link Sentence} through its
     * tokens, and any other span through its text split on spaces.
     *
     * @param span the span to encode
     * @return the bag-of-words vector of the span
     */
    @Override
    public INDArray encode(Span span) {
        if (span instanceof Token) {
            return encode(Arrays.asList(span));
        }
        if (span instanceof Sentence) {
            return encode(((Sentence) span).getTokens());
        }
        return encode(span.getText());
    }

    /**
     * Encodes a string by splitting it with {@link WordHelpers#splitSpaces}.
     *
     * @param str the text to encode
     * @return the bag-of-words vector of the text
     */
    @Override
    public INDArray encode(String str) {
        return encode(WordHelpers.splitSpaces(str));
    }

    /**
     * Encodes a text as a vector with at most one entry set, chosen at random among the words of
     * the text with probability proportional to their sampling rate.
     *
     * <p>A text consisting of a single word is encoded directly. Empty words, stop words and
     * words whose sampling rate is exactly 1 (i.e. whose probability is 0) are not candidates.
     *
     * @param str the text to encode
     * @return a column vector with at most one entry set to 1
     */
    public INDArray encodeSubsampled(String str) {
        String[] words = WordHelpers.splitSpaces(str);
        if (words.length == 1) {
            return encode(words[0]);
        }
        List<Pair<String, Double>> candidates = new ArrayList<>(5);
        double totalRate = 0.0d;
        for (String rawWord : words) {
            String word = this.preprocessor.preProcess(rawWord);
            if (word.isEmpty() || this.wordHelpers.isStopWord(word)) {
                continue;
            }
            double rate = samplingRate(super.getProbability(word));
            if (rate != 1.0d) {
                totalRate += rate;
                candidates.add(new Pair<>(word, rate));
            }
        }
        INDArray vector = newZeroVector();
        // Roulette-wheel selection: pick the first candidate whose cumulative rate reaches the draw.
        double draw = Math.random() * totalRate;
        double cumulativeRate = 0.0d;
        for (Pair<String, Double> candidate : candidates) {
            cumulativeRate += candidate.getValue();
            if (cumulativeRate >= draw) {
                markWord(vector, candidate.getKey());
                return vector;
            }
        }
        return vector;
    }

    /**
     * Returns the entry of the vector at the given index.
     *
     * @param iNDArray the vector to read
     * @param i the index of the entry
     * @return the value stored at index {@code i}
     */
    @Override
    public double getConfidence(INDArray iNDArray, int i) {
        return iNDArray.getDouble(i);
    }

    /**
     * Returns the sum of the maxima taken along dimension 0 of the given array.
     *
     * @param iNDArray the vector to inspect
     * @return the summed maximum along dimension 0
     */
    @Override
    public double getMaxConfidence(INDArray iNDArray) {
        return iNDArray.max(new int[]{0}).sumNumber().doubleValue();
    }

    /**
     * Collects the preprocessed texts of all tokens that are known to this encoder.
     *
     * @param iterable the tokens to inspect
     * @return the set of preprocessed texts of tokens for which {@link #isUnknown} is false
     */
    public Set<String> asString(Iterable<Token> iterable) {
        Set<String> knownWords = new HashSet<>();
        for (Token token : iterable) {
            if (!isUnknown(token.getText())) {
                knownWords.add(this.preprocessor.preProcess(token.getText()));
            }
        }
        return knownWords;
    }

    /**
     * Returns the word with the highest positive entry in the given vector.
     *
     * @param iNDArray the vector to inspect
     * @return the top-ranked word, or {@code null} if no entry is positive
     */
    @Override
    public String getNearestNeighbour(INDArray iNDArray) {
        Collection<String> nearestNeighbours = getNearestNeighbours(iNDArray, 1);
        if (nearestNeighbours.isEmpty()) {
            return null;
        }
        return nearestNeighbours.iterator().next();
    }

    /**
     * Returns up to {@code i} words whose entries in the given vector are highest and positive,
     * in descending order of their entry.
     *
     * <p>The input vector is not modified; a copy is sorted. A warning is logged when the sorted
     * values sum to zero.
     *
     * @param iNDArray the vector to inspect
     * @param i the maximum number of words to return
     * @return the words of the top-ranked positive entries; fewer than {@code i} if the vector
     *     has fewer positive entries or fewer entries overall
     */
    @Override
    public Collection<String> getNearestNeighbours(INDArray iNDArray, int i) {
        INDArray[] sorted = Nd4j.sortWithIndices(iNDArray.dup(), 0, false);
        INDArray rankedIndices = sorted[0];
        INDArray rankedValues = sorted[1];
        // The zero-vector check must look at the values; the index array sums to non-zero
        // for any vector with more than one entry.
        if (rankedValues.sumNumber().doubleValue() == 0.0d) {
            this.log.warn("NearestNeighbour on zero vector - please check vector alignment!");
        }
        // Never read past the end of the vector when more neighbours are requested than exist.
        int limit = (int) Math.min((long) i, rankedValues.length());
        List<String> neighbours = new ArrayList<>(limit);
        for (int rank = 0; rank < limit; rank++) {
            if (rankedValues.getDouble(rank) > 0.0d) {
                neighbours.add(getWord(rankedIndices.getInt(rank)));
            }
        }
        return neighbours;
    }

    /**
     * Randomly decides whether to keep a word, with probability given by
     * {@link #samplingRate(String)} (rates of 1 or above always keep the word).
     *
     * @param str the word to decide on
     * @return {@code true} if the word should be kept
     */
    public boolean keepWord(String str) {
        return Math.random() < samplingRate(str);
    }

    /**
     * Returns a copy of the vector in which each positive entry is randomly reset to 0 according
     * to {@link #keepWord} for the word at that index.
     *
     * @param iNDArray the vector to subsample; it is not modified
     * @return the subsampled copy
     */
    public INDArray subsample(INDArray iNDArray) {
        INDArray result = iNDArray.dup();
        for (int index = 0; index < iNDArray.length(); index++) {
            if (iNDArray.getDouble(index) > 0.0d && !keepWord(getWord(index))) {
                result.putScalar(index, 0.0d);
            }
        }
        return result;
    }

    /**
     * Computes the keep-rate of a word as {@code (sqrt(p / t) + 1) * (t / p)}, where {@code p} is
     * the word's probability and {@code t} is {@value #SAMPLING_THRESHOLD}.
     *
     * <p>NOTE(review): this looks like the word2vec frequent-word subsampling formula - confirm.
     *
     * @param str the word
     * @return the keep-rate; may exceed 1 for rare words
     */
    protected double samplingRate(String str) {
        double probability = getProbability(str);
        return (Math.sqrt(probability / SAMPLING_THRESHOLD) + 1.0d) * (SAMPLING_THRESHOLD / probability);
    }

    /**
     * Computes {@code t / (t + d)} with {@code t} = {@value #SAMPLING_THRESHOLD}; the result is
     * exactly 1 for a probability of 0 and decreases as the probability grows.
     *
     * @param d the word probability
     * @return the sampling rate for that probability
     */
    protected double samplingRate(double d) {
        return SAMPLING_THRESHOLD / (SAMPLING_THRESHOLD + d);
    }

    /**
     * Builds a vector holding, for every word index, the sampling rate derived from that word's
     * probability via {@link #samplingRate(double)}.
     *
     * @return the transposed weight vector
     */
    @JsonIgnore
    public INDArray subsampleWeights() {
        final long size = getEmbeddingVectorSize();
        INDArray weights = Nd4j.zeros(size, 1L);
        for (int index = 0; index < size; index++) {
            weights.put(index, 0, Double.valueOf(samplingRate(getProbability(getWord(index)))));
        }
        return weights.transpose();
    }
}
