package de.datexis.encoder;

import de.datexis.model.Span;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Objects;
import java.util.function.Function;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:de/datexis/encoder/EncoderSet.class */
/**
 * A composite encoder that concatenates the vectors produced by several {@link Encoder}s.
 *
 * <p>Encoders are applied in insertion order. The output of encoder {@code k} occupies the slice
 * of the result vector directly after the slices of encoders {@code 0..k-1}.
 *
 * <p>The total vector size is cached in {@link #size} whenever an encoder is added. If an
 * encoder's {@code getEmbeddingVectorSize()} changes after it was added (for example, because it
 * was added while still uninitialized with size 0), call {@link #updateVectorSize()} before
 * encoding. Otherwise the output vector is allocated with a stale size.
 */
public class EncoderSet implements Iterable<Encoder>, IEncoder {
    protected static final Logger log = LoggerFactory.getLogger(EncoderSet.class);

    /** The encoders, in the order their outputs are concatenated. */
    protected List<Encoder> encoders;

    /** Cached sum of all encoders' vector sizes; refreshed by {@link #updateVectorSize()}. */
    protected int size = 0;

    /**
     * Creates a set containing the given encoders, in the given order.
     *
     * @param encoderArr the encoders to concatenate; must not contain {@code null}
     * @throws NullPointerException if any encoder is {@code null}
     * @throws ArithmeticException if the combined vector size does not fit in an {@code int}
     */
    public EncoderSet(Encoder... encoderArr) {
        this.encoders = new ArrayList<>(encoderArr.length);
        for (Encoder encoder : encoderArr) {
            addEncoder(encoder);
        }
    }

    /**
     * Appends an encoder and adds its vector size to the cached total.
     *
     * <p>A warning is logged if the encoder currently reports a vector size of 0.
     *
     * @param encoder the encoder to append; must not be {@code null}
     * @throws NullPointerException if {@code encoder} is {@code null}
     * @throws ArithmeticException if the combined vector size does not fit in an {@code int}
     */
    public final void addEncoder(Encoder encoder) {
        Objects.requireNonNull(encoder, "encoder");
        long encoderSize = encoder.getEmbeddingVectorSize();
        if (encoderSize == 0) {
            log.warn("Adding uninitialized Encoder {}", encoder.getName());
        }
        // Validate the new total before mutating, so a failure leaves list and size consistent.
        int newSize = Math.toIntExact(this.size + encoderSize);
        this.encoders.add(encoder);
        this.size = newSize;
    }

    /**
     * Recomputes the cached total vector size from the current sizes of all encoders.
     *
     * @throws ArithmeticException if the combined vector size does not fit in an {@code int};
     *     the cached size is left unchanged in that case
     */
    public void updateVectorSize() {
        long total = 0;
        for (Encoder encoder : this.encoders) {
            total += encoder.getEmbeddingVectorSize();
        }
        this.size = Math.toIntExact(total);
    }

    /**
     * Returns the cached length of the concatenated vector.
     *
     * @return the sum of the encoders' vector sizes as of the last add/update
     */
    @Override
    public long getEmbeddingVectorSize() {
        return this.size;
    }

    /**
     * Returns the backing encoder list as an {@link Iterable}.
     *
     * <p>NOTE(review): this is the live internal list. Removing through its iterator does not update
     * {@link #size}; call {@link #updateVectorSize()} afterwards.
     *
     * @return the encoders, in concatenation order
     */
    public Iterable<Encoder> iterable() {
        return this.encoders;
    }

    /**
     * Iterates over the encoders in concatenation order.
     *
     * <p>NOTE(review): this is the backing list's iterator. {@code remove()} does not update
     * {@link #size}; call {@link #updateVectorSize()} afterwards.
     */
    @Override
    public Iterator<Encoder> iterator() {
        return this.encoders.iterator();
    }

    /**
     * Encodes a string with every encoder and concatenates the results.
     *
     * @param str the text to encode
     * @return a rank-1 vector of length {@link #getEmbeddingVectorSize()}
     */
    @Override
    public INDArray encode(String str) {
        return concatenate(encoder -> encoder.encode(str));
    }

    /**
     * Encodes a sequence of spans with every encoder and concatenates the results.
     *
     * @param iterable the spans to encode
     * @return a rank-1 vector of length {@link #getEmbeddingVectorSize()}
     */
    @Override
    public INDArray encode(Iterable<? extends Span> iterable) {
        return concatenate(encoder -> encoder.encode(iterable));
    }

    /**
     * Encodes a single span by delegating to {@link #encode(String)} with {@code span.toString()}.
     *
     * @param span the span to encode
     * @return a rank-1 vector of length {@link #getEmbeddingVectorSize()}
     */
    @Override
    public INDArray encode(Span span) {
        return encode(span.toString());
    }

    /**
     * Applies {@code encodeFn} to each encoder and writes each result into that encoder's slice.
     *
     * @param encodeFn produces one encoder's output vector
     * @return a new rank-1 vector of the cached total size
     */
    private INDArray concatenate(Function<Encoder, INDArray> encodeFn) {
        INDArray result = Nd4j.create(new long[]{getEmbeddingVectorSize()});
        long offset = 0;
        for (Encoder encoder : this.encoders) {
            long end = offset + encoder.getEmbeddingVectorSize();
            result.get(NDArrayIndex.interval(offset, end)).assign(encodeFn.apply(encoder));
            offset = end;
        }
        return result;
    }
}
