package de.datexis.encoder;

import com.google.common.collect.Lists;
import de.datexis.annotator.AnnotatorComponent;
import de.datexis.annotator.IComponent;
import de.datexis.model.Document;
import de.datexis.model.Sentence;
import de.datexis.model.Span;
import de.datexis.model.Token;
import java.util.Collection;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

/* loaded from: input_file:de/datexis/encoder/Encoder.class */
/**
 * Abstract base class for encoders that turn text spans into vectors.
 *
 * <p>Provides helpers built on top of the single-span {@code encode} methods that concrete
 * subclasses supply: averaging over several spans, attaching vectors to every Token or Sentence
 * of a Document, and packing a batch of Documents into one 3-dimensional array.
 */
public abstract class Encoder extends AnnotatorComponent implements IEncoder, IComponent {

    /** Creates an encoder with an empty id. */
    public Encoder() {
        this("");
    }

    /**
     * Creates an encoder with the given id.
     *
     * @param str the id stored in this component's {@code id} field
     */
    public Encoder(String str) {
        // NOTE(review): the meaning of the 'false' flag is defined by AnnotatorComponent - confirm there.
        super(false);
        this.id = str;
    }

    /**
     * Encodes the text of every given span and returns the element-wise mean of the resulting
     * vectors. Spans whose text encodes to {@code null} are skipped and do not count towards the
     * mean.
     *
     * @param iterable the spans to encode
     * @return an array of shape {@code [getEmbeddingVectorSize(), 1]} holding the mean vector; if
     *         no span could be encoded the freshly created accumulator is returned unchanged
     *         instead of dividing by zero
     */
    @Override // de.datexis.encoder.IEncoder
    public INDArray encode(Iterable<? extends Span> iterable) {
        INDArray sum = Nd4j.create(new long[]{getEmbeddingVectorSize(), 1});
        int encodedCount = 0;
        for (Span span : iterable) {
            INDArray vector = encode(span.getText());
            if (vector != null) {
                sum.addi(vector);
                encodedCount++;
            }
        }
        if (encodedCount == 0) {
            // Nothing was added: dividing by zero would turn every entry into NaN.
            return sum;
        }
        return sum.divi(encodedCount);
    }

    /**
     * Encodes every Token or every Sentence of a Document and stores each vector on the encoded
     * element, keyed by this encoder's runtime class.
     *
     * @param document the Document whose elements are encoded
     * @param cls either {@code Token.class} or {@code Sentence.class}
     * @throws IllegalArgumentException if {@code cls} is neither of the two supported classes
     */
    public void encodeEach(Document document, Class<? extends Span> cls) {
        final Class<? extends IEncoder> encoderClass = getClass();
        if (cls == Token.class) {
            document.streamTokens().forEach(token -> token.putVector(encoderClass, encode(token)));
            return;
        }
        if (cls == Sentence.class) {
            document.streamSentences().forEach(sentence -> sentence.putVector(encoderClass, encode(sentence)));
            return;
        }
        throw new IllegalArgumentException("Cannot encode class " + cls + " from Document");
    }

    /**
     * Encodes a batch of Documents into a single zero-padded array.
     *
     * <p>For each Document, the first {@code i} Tokens or Sentences are encoded and written to
     * the matching position; positions beyond the Document's length stay zero. If {@code cls} is
     * neither {@code Token.class} nor {@code Sentence.class}, nothing is encoded and the array
     * stays all-zero.
     *
     * @param list the Documents forming the batch
     * @param i the maximum number of elements encoded per Document (size of the last dimension)
     * @param cls {@code Token.class} or {@code Sentence.class}
     * @return an array of shape {@code [list.size(), getEmbeddingVectorSize(), i]}
     */
    public INDArray encodeMatrix(List<Document> list, int i, Class<? extends Span> cls) {
        INDArray encoding = Nd4j.zeros(new long[]{list.size(), getEmbeddingVectorSize(), i});
        for (int batchIndex = 0; batchIndex < list.size(); batchIndex++) {
            Document document = list.get(batchIndex);
            List<? extends Span> spans = Collections.emptyList();
            if (cls == Token.class) {
                spans = Lists.newArrayList(document.getTokens());
            } else if (cls == Sentence.class) {
                spans = Lists.newArrayList(document.getSentences());
            }
            // Truncate to at most i elements per Document.
            int steps = Math.min(spans.size(), i);
            for (int t = 0; t < steps; t++) {
                encoding.getRow(batchIndex).getColumn(t).assign(encode(spans.get(t)));
            }
        }
        return encoding;
    }

    /**
     * Applies {@link #encodeEach(Document, Class)} to every Document of the collection, in
     * iteration order.
     *
     * @param collection the Documents to encode
     * @param cls either {@code Token.class} or {@code Sentence.class}
     * @throws IllegalArgumentException if {@code cls} is not supported and the collection is
     *         not empty
     */
    public void encodeEach(Collection<Document> collection, Class<? extends Span> cls) {
        for (Document document : collection) {
            encodeEach(document, cls);
        }
    }

    /**
     * Encodes every Token of a Sentence and stores each vector on its Token, keyed by this
     * encoder's runtime class.
     *
     * @param sentence the Sentence whose Tokens are encoded
     * @param cls must be {@code Token.class}
     * @throws IllegalArgumentException if {@code cls} is not {@code Token.class}
     */
    public void encodeEach(Sentence sentence, Class<? extends Span> cls) {
        if (cls != Token.class) {
            throw new IllegalArgumentException("Cannot encode class " + cls + " from Sentence");
        }
        final Class<? extends IEncoder> encoderClass = getClass();
        sentence.streamTokens().forEach(token -> token.putVector(encoderClass, encode(token)));
    }

    /**
     * Trains this encoder's model on the given Documents.
     *
     * @param collection the training Documents
     */
    public abstract void trainModel(Collection<Document> collection);

    /**
     * Collects the stream into a list and delegates to {@link #trainModel(Collection)}. The
     * stream is consumed by this call.
     *
     * @param stream the training Documents
     */
    public void trainModel(Stream<Document> stream) {
        trainModel(stream.collect(Collectors.toList()));
    }
}
