package de.datexis.encoder.impl;

import de.datexis.common.Resource;
import de.datexis.encoder.LookupCacheEncoder;
import de.datexis.model.Document;
import de.datexis.model.Sentence;
import de.datexis.model.Span;
import de.datexis.model.Token;
import de.datexis.preprocess.DocumentFactory;
import de.datexis.preprocess.MinimalLowercasePreprocessor;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Set;
import org.deeplearning4j.text.tokenization.tokenizer.TokenPreProcess;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dimensionalityreduction.PCA;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:de/datexis/encoder/impl/SentenceEmbeddingEncoder.class */
/**
 * Sentence encoder that builds a sentence vector from word embeddings.
 *
 * <p>Each token vector obtained from {@link Word2VecEncoder} is scaled by
 * {@code alpha / (alpha + p(w))}, where {@code p(w)} is the token probability from this
 * encoder's own vocabulary (built in {@link #trainModel}). The scaled vectors are averaged.
 * The component along {@link #principal} (computed by PCA over the training sentences) is then
 * subtracted. NOTE(review): this resembles the SIF ("smooth inverse frequency") scheme; confirm
 * that was the intended model.
 */
public class SentenceEmbeddingEncoder extends LookupCacheEncoder {

    /** Normalizes token text before every vocabulary lookup (used by training and encoding alike). */
    private static final TokenPreProcess preprocessor = new MinimalLowercasePreprocessor();

    /** Word-level embeddings that are combined into sentence vectors. */
    protected Word2VecEncoder vec;

    /** First principal component of the training sentence vectors, assigned in {@link #trainModel}. */
    protected INDArray principal;

    /** Smoothing constant in the token weight {@code alpha / (alpha + p(w))}. */
    protected final double alpha = 1.0E-4d;

    /** Creates an encoder with the identifier {@code "EMB"}. */
    public SentenceEmbeddingEncoder() {
        this("EMB");
    }

    /**
     * Creates an encoder with a custom identifier.
     *
     * @param str identifier passed to the {@link LookupCacheEncoder} constructor
     */
    public SentenceEmbeddingEncoder(String str) {
        super(str);
        // alpha is final and initialized at its declaration; it must not be reassigned here.
        this.log = LoggerFactory.getLogger(SentenceEmbeddingEncoder.class);
    }

    /**
     * Creates an encoder backed by word embeddings loaded from {@code resource}.
     * The principal component is not set here; {@link #trainModel} computes it.
     *
     * @param resource location of the word embeddings, passed to {@link Word2VecEncoder#load}
     * @return a new encoder wrapping the loaded embeddings
     */
    public static SentenceEmbeddingEncoder create(Resource resource) {
        SentenceEmbeddingEncoder encoder = new SentenceEmbeddingEncoder();
        encoder.vec = Word2VecEncoder.load(resource);
        return encoder;
    }

    @Override // de.datexis.annotator.AnnotatorComponent, de.datexis.annotator.IComponent
    public String getName() {
        return "Simple Sentence Embedding Encoder";
    }

    /** Returns the dimensionality of the underlying word embeddings (and of the sentence vectors). */
    @Override // de.datexis.encoder.LookupCacheEncoder, de.datexis.encoder.IEncoder
    public long getEmbeddingVectorSize() {
        return this.vec.getEmbeddingVectorSize();
    }

    /**
     * Builds the token vocabulary from {@code collection} and computes the principal component of
     * the resulting sentence vectors.
     *
     * @param collection training documents; iterated twice (vocabulary pass, then sentence pass)
     */
    @Override // de.datexis.encoder.Encoder
    public void trainModel(Collection<Document> collection) {
        appendTrainLog("Training " + getName() + " model...");
        setModel(null);
        this.timer.start();

        // Pass 1: count sentences and build word frequencies on preprocessed token text.
        int sentenceCount = 0;
        this.totalWords = 0;
        for (Document document : collection) {
            sentenceCount += document.countSentences();
            Iterator<Token> tokens = document.getTokens().iterator();
            while (tokens.hasNext()) {
                String word = preprocessor.preProcess(tokens.next().getText());
                this.totalWords++;
                if (word.isEmpty()) {
                    continue;
                }
                if (this.vocab.containsWord(word)) {
                    this.vocab.incrementWordCounter(word);
                } else {
                    this.vocab.addWord(word);
                }
            }
        }
        this.vocab.updateHuffmanCodes();
        this.timer.stop();

        // Pass 2: one weighted-average row per sentence, in document order. The frequencies from
        // pass 1 feed getProbability() inside weightedSum().
        INDArray sentenceVectors = Nd4j.zeros(new long[]{sentenceCount, getEmbeddingVectorSize()});
        int row = 0;
        for (Document document : collection) {
            Iterator<Sentence> sentences = document.getSentences().iterator();
            while (sentences.hasNext()) {
                sentenceVectors.getRow(row++).assign(weightedSum(sentences.next().getTokens(), this.alpha));
            }
        }
        this.principal = PCA.pca_factor(sentenceVectors, 1, false);

        // Report distinct vocabulary size and the total number of tokens seen.
        appendTrainLog("trained " + this.vocab.numWords() + " words (" + this.totalWords + " total)", this.timer.getLong());
        setModelAvailable(true);
    }

    /**
     * Averages the word embeddings of {@code spans}, scaling each by
     * {@code smoothing / (smoothing + p(text))}.
     *
     * @param spans     spans whose text is looked up in {@link #vec}
     * @param smoothing the smoothing constant (callers pass {@link #alpha})
     * @return a new column vector of shape (embedding size, 1); all zeros when {@code spans} is empty
     */
    private INDArray weightedSum(Iterable<? extends Span> spans, double smoothing) {
        INDArray sum = Nd4j.zeros(new long[]{getEmbeddingVectorSize(), 1});
        int count = 0;
        for (Span span : spans) {
            String text = span.getText();
            double weight = smoothing / (smoothing + getProbability(text));
            // mul (not muli): NOTE(review) the array returned by vec.encode may be a view into the
            // shared embedding table; scaling a copy is safe either way.
            sum.addi(this.vec.encode(text).mul(weight));
            count++;
        }
        // Avoid 0/0 = NaN for spans without tokens.
        return count == 0 ? sum : sum.divi(count);
    }

    /** Normalizes {@code str} with the shared preprocessor, then delegates to the superclass. */
    @Override // de.datexis.encoder.LookupCacheEncoder
    public boolean isUnknown(String str) {
        return super.isUnknown(preprocessor.preProcess(str));
    }

    /** Normalizes {@code str} with the shared preprocessor, then delegates to the superclass. */
    @Override // de.datexis.encoder.LookupCacheEncoder
    public int getIndex(String str) {
        return super.getIndex(preprocessor.preProcess(str));
    }

    /** Normalizes {@code str} with the shared preprocessor, then delegates to the superclass. */
    @Override // de.datexis.encoder.LookupCacheEncoder
    public int getFrequency(String str) {
        return super.getFrequency(preprocessor.preProcess(str));
    }

    /** Normalizes {@code str} with the shared preprocessor, then delegates to the superclass. */
    @Override // de.datexis.encoder.LookupCacheEncoder
    public double getProbability(String str) {
        return super.getProbability(preprocessor.preProcess(str));
    }

    /**
     * Encodes a span sequence as the weighted word-vector average with its projection onto
     * {@link #principal} removed: {@code v - u (u^T v)}.
     *
     * @throws IllegalStateException if {@link #principal} has not been computed yet
     */
    @Override // de.datexis.encoder.Encoder, de.datexis.encoder.IEncoder
    public INDArray encode(Iterable<? extends Span> iterable) {
        if (this.principal == null) {
            throw new IllegalStateException("principal component is not available; train the model first");
        }
        INDArray v = weightedSum(iterable, this.alpha);
        INDArray u = this.principal;
        // u (u^T v) equals (u u^T) v but avoids materializing a d x d matrix per call.
        return v.subi(u.mmul(u.transpose().mmul(v)));
    }

    /** Encodes the text of {@code span} after re-tokenizing it. */
    @Override // de.datexis.encoder.IEncoder
    public INDArray encode(Span span) {
        return encode(span.getText());
    }

    /** Tokenizes {@code str} with {@link DocumentFactory#createTokensFromText} and encodes the tokens. */
    @Override // de.datexis.encoder.IEncoder
    public INDArray encode(String str) {
        return encode(DocumentFactory.createTokensFromText(str));
    }

    /**
     * Collects the preprocessed text of every known token.
     *
     * @param iterable tokens to inspect
     * @return a mutable set of preprocessed words for which {@link #isUnknown} is false
     */
    public Set<String> asString(Iterable<Token> iterable) {
        Set<String> words = new HashSet<>();
        for (Token token : iterable) {
            String text = token.getText();
            if (!isUnknown(text)) {
                words.add(preprocessor.preProcess(text));
            }
        }
        return words;
    }

    /**
     * Returns the words for the {@code count} largest entries of {@code scores}, best first.
     * Each element index is passed to {@link #getWord(int)}. NOTE(review): this assumes
     * {@code scores} is indexed like the vocabulary; confirm against callers.
     *
     * @param scores one score per vocabulary index
     * @param count  number of words requested; clamped to {@code [0, scores.length()]}
     * @return up to {@code count} distinct words in descending score order
     */
    @Override // de.datexis.encoder.LookupCacheEncoder
    public Collection<String> getNearestNeighbours(INDArray scores, int count) {
        final int length = (int) scores.length();
        final double[] values = new double[length];
        for (int idx = 0; idx < length; idx++) {
            values[idx] = scores.getDouble(idx);
        }
        final int k = Math.max(0, Math.min(count, length));
        // Track selected indices explicitly instead of overwriting scores with a sentinel.
        final boolean[] taken = new boolean[length];
        final Collection<String> result = new ArrayList<>(k);
        for (int n = 0; n < k; n++) {
            int best = -1;
            for (int idx = 0; idx < length; idx++) {
                if (!taken[idx] && (best < 0 || values[idx] > values[best])) {
                    best = idx;
                }
            }
            taken[best] = true;
            result.add(getWord(best));
        }
        return result;
    }
}
