package de.datexis.index.encoder;

import de.datexis.common.Resource;
import de.datexis.encoder.Encoder;
import de.datexis.index.ArticleRef;
import de.datexis.index.WikiDataArticle;
import de.datexis.model.Annotation;
import de.datexis.model.Document;
import de.datexis.model.Sentence;
import de.datexis.model.Span;
import de.datexis.preprocess.MinimalLowercasePreprocessor;
import java.io.IOException;
import java.io.InputStream;
import java.util.Collection;
import java.util.Objects;
import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer;
import org.deeplearning4j.models.paragraphvectors.ParagraphVectors;
import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Encodes knowledge-base entities and textual entity mentions into vectors using a DL4J
 * {@link ParagraphVectors} model.
 *
 * <p>Entities are looked up by their ID in the model ({@link #encodeID}); mentions and free text
 * are embedded by inference ({@link #encode(String)}). Looked-up and inferred vectors are
 * normalized to unit length. Under {@link Strategy#NAME_CONTEXT} every produced vector is the
 * concatenation of a name part and a context part, i.e. twice the model's layer size.
 */
public class EntityEncoder extends Encoder {
    protected static final Logger log = LoggerFactory.getLogger(EntityEncoder.class);

    /** The paragraph-vector model; assigned by {@link #loadModel(Resource)}. */
    protected ParagraphVectors parvec;

    /** How entity and mention vectors are composed. */
    protected Strategy strategy;

    /** Composition of the vectors produced by this encoder. */
    public enum Strategy {
        /** Only the name/ID vector ({@code layerSize} dimensions). */
        NAME,
        /** Name vector concatenated with a context vector ({@code 2 * layerSize} dimensions). */
        NAME_CONTEXT
    }

    /**
     * Creates an encoder and loads its model.
     *
     * @param resource serialized ParagraphVectors model
     * @param strategy vector composition strategy; must not be {@code null}
     * @throws IOException if the model cannot be read
     * @throws NullPointerException if {@code strategy} is {@code null}
     */
    public EntityEncoder(Resource resource, Strategy strategy) throws IOException {
        // Validate first so a bad argument fails before the (expensive) model load.
        this.strategy = Objects.requireNonNull(strategy, "strategy");
        loadModel(resource);
    }

    /**
     * Reads a serialized ParagraphVectors model from {@code resource} and installs a
     * {@link DefaultTokenizerFactory} using {@link MinimalLowercasePreprocessor} as token
     * preprocessor.
     *
     * @param resource serialized model
     * @throws IOException if the model cannot be read
     */
    public void loadModel(Resource resource) throws IOException {
        log.info("loading paragraph vectors...");
        // try-with-resources releases the stream even if deserialization fails.
        try (InputStream in = resource.getInputStream()) {
            this.parvec = WordVectorSerializer.readParagraphVectors(in);
        }
        DefaultTokenizerFactory tokenizerFactory = new DefaultTokenizerFactory();
        tokenizerFactory.setTokenPreProcessor(new MinimalLowercasePreprocessor());
        this.parvec.setTokenizerFactory(tokenizerFactory);
        log.info("loaded {} paragraph labels with size {}",
                parvec.getLabelsSource().getLabels().size(), parvec.getLayerSize());
    }

    /**
     * Returns the dimensionality of the vectors produced under the current strategy.
     *
     * @return the model layer size for {@link Strategy#NAME}, twice that for
     *     {@link Strategy#NAME_CONTEXT}
     * @throws IllegalArgumentException for an unsupported strategy
     */
    public long getEmbeddingVectorSize() {
        switch (strategy) {
            case NAME:
                return parvec.getLayerSize();
            case NAME_CONTEXT:
                return 2L * parvec.getLayerSize();
            default:
                throw new IllegalArgumentException("invalid strategy");
        }
    }

    /**
     * Encodes a WikiData article as an entity vector.
     *
     * @param wikiDataArticle article providing ID, title and description
     * @return the entity vector, or {@code null} if the article's ID has no vector in the model
     */
    public INDArray encodeEntity(WikiDataArticle wikiDataArticle) {
        return encodeEntity(wikiDataArticle.getId(), wikiDataArticle.getTitle(), wikiDataArticle.getDescription());
    }

    /**
     * Encodes an article reference as an entity vector.
     *
     * @param articleRef reference providing ID, title and description
     * @return the entity vector, or {@code null} if the reference's ID has no vector in the model
     */
    public INDArray encodeEntity(ArticleRef articleRef) {
        return encodeEntity(articleRef.getId(), articleRef.getTitle(), articleRef.getDescription());
    }

    /**
     * Builds the entity vector according to {@link #strategy}.
     *
     * @return the ID vector ({@code NAME}) or ID vector + context vector ({@code NAME_CONTEXT});
     *     {@code null} under either strategy when {@link #encodeID} yields {@code null}
     */
    private INDArray encodeEntity(String id, String title, String description) {
        INDArray idVector = encodeID(id, title);
        switch (strategy) {
            case NAME:
                return idVector;
            case NAME_CONTEXT:
                return appendContext(idVector, title, description);
            default:
                throw new IllegalArgumentException("invalid strategy");
        }
    }

    /**
     * Concatenates {@code idVector} with a vector inferred from title and description. If
     * inference yields an all-zero vector, {@code idVector} is used for the context half too.
     */
    private INDArray appendContext(INDArray idVector, String title, String description) {
        if (idVector == null) {
            // Unknown ID: report "not encodable" exactly like the NAME strategy does, instead of
            // passing null into Nd4j.hstack (which would throw a NullPointerException).
            return null;
        }
        String context = title;
        if (description != null) {
            // Skip a null title so the literal text "null" never becomes part of the context.
            context = (title == null) ? description : title + " " + description;
        }
        INDArray contextVector = encode(context);
        // encode() signals failure with an all-zero vector. A zero L2 norm detects exactly that,
        // whereas maxNumber() == 0 would also fire for non-zero vectors whose largest entry is 0.
        if (contextVector.norm2Number().doubleValue() == 0.0d) {
            contextVector = idVector;
        }
        return Nd4j.hstack(new INDArray[] {idVector, contextVector});
    }

    /**
     * Looks up the stored vector for {@code id} in the model and normalizes it to unit length.
     *
     * @param id model label of the entity
     * @param title not used by this implementation
     * @return the unit vector, or {@code null} if the lookup returns {@code null} or throws
     */
    public INDArray encodeID(String id, String title) {
        try {
            return normalize(parvec.getWordVectorMatrix(id));
        } catch (Exception ignored) {
            // Best effort: any lookup failure is reported as "no vector" (null).
            return null;
        }
    }

    /**
     * Encodes a textual mention according to {@link #strategy}.
     *
     * @param mention mention text
     * @param context surrounding text; only used under {@link Strategy#NAME_CONTEXT}
     * @return the inferred mention vector, concatenated with the inferred context vector under
     *     {@code NAME_CONTEXT}
     * @throws IllegalArgumentException for an unsupported strategy
     */
    public INDArray encodeMention(String mention, String context) {
        INDArray mentionVector = encode(mention);
        switch (strategy) {
            case NAME:
                return mentionVector;
            case NAME_CONTEXT:
                return Nd4j.hstack(new INDArray[] {mentionVector, encode(context)});
            default:
                throw new IllegalArgumentException("invalid strategy");
        }
    }

    /**
     * Encodes the text of a span.
     *
     * @param span span whose text is encoded
     * @return see {@link #encode(String)}
     */
    public INDArray encode(Span span) {
        return encode(span.getText());
    }

    /**
     * Infers a vector for {@code text} with the model and normalizes it to unit length.
     *
     * @param text text to embed
     * @return the unit vector, or an all-zero vector of the model's layer size if inference throws
     */
    public INDArray encode(String text) {
        try {
            return normalize(parvec.inferVector(text));
        } catch (Exception ignored) {
            // Best effort: an all-zero vector signals "no information" to callers. NOTE(review):
            // presumably this covers e.g. text without in-vocabulary tokens — confirm for the DL4J
            // version in use.
            return Nd4j.zeros(parvec.getLayerSize());
        }
    }

    /**
     * Stores an encoded mention vector (keyed by {@code EntityEncoder.class}) on every annotation
     * of type {@code cls} from {@code source} in {@code document}. The tokenized sentence at the
     * annotation's begin position serves as context.
     *
     * <p>Assumes every annotation starts inside a sentence; the {@code get()} on the sentence
     * lookup fails otherwise.
     */
    public void encodeEach(Document document, Annotation.Source source, Class<? extends Annotation> cls) {
        document.streamAnnotations(source, cls).forEach(annotation -> {
            Sentence sentence = (Sentence) document.getSentenceAtPosition(annotation.getBegin()).get();
            INDArray vector = encodeMention(annotation.getText(), sentence.toTokenizedString());
            annotation.putVector(EntityEncoder.class, vector);
        });
    }

    /** Returns {@code vector} scaled to unit length, or {@code null} for a {@code null} input. */
    private INDArray normalize(INDArray vector) {
        return (vector == null) ? null : Transforms.unitVec(vector);
    }

    /**
     * Not supported: the model is only loaded, never trained here.
     *
     * @throws UnsupportedOperationException always
     */
    public void trainModel(Collection<Document> collection) {
        throw new UnsupportedOperationException("Not implemented yet.");
    }

    /**
     * Not supported.
     *
     * @throws UnsupportedOperationException always
     */
    public void saveModel(Resource resource, String str) {
        throw new UnsupportedOperationException("Not implemented yet.");
    }
}
