package org.deeplearning4j.models.word2vec.iterator;

import java.io.Serializable;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import org.deeplearning4j.bagofwords.vectorizer.TextVectorizer;
import org.deeplearning4j.models.embeddings.WeightLookupTable;
import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable;
import org.deeplearning4j.models.word2vec.VocabWord;
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
import org.deeplearning4j.scaleout.api.statetracker.NewUpdateListener;
import org.deeplearning4j.scaleout.api.statetracker.StateTracker;
import org.deeplearning4j.scaleout.job.Job;
import org.deeplearning4j.scaleout.job.JobIterator;
import org.deeplearning4j.scaleout.perform.models.word2vec.Word2VecResult;
import org.deeplearning4j.scaleout.perform.models.word2vec.Word2VecWork;
import org.deeplearning4j.text.invertedindex.InvertedIndex;
import org.nd4j.linalg.api.ndarray.INDArray;

/* loaded from: input_file:org/deeplearning4j/models/word2vec/iterator/Word2VecJobIterator.class */
/**
 * {@link JobIterator} that turns batches of tokenized sentences into {@link Word2VecWork} jobs and
 * merges finished results back into the shared weight lookup table.
 *
 * <p>Each call to {@link #next()} / {@link #next(String)} pulls one element (a list of sentences,
 * each sentence a list of {@link VocabWord}s) from the underlying sentence iterator and wraps it
 * in a {@link Job}.
 *
 * <p>On construction, a listener is registered with the supplied {@link StateTracker}. When it
 * receives an update, the update is treated as a {@link Job} whose result is a collection of
 * {@link Word2VecResult}s. For every word in each result's syn0 change map, the listener
 * overwrites that word's row in the syn0, syn1 and, if present, syn1Neg matrices. The lookup
 * table must therefore be an {@link InMemoryLookupTable}; any other implementation causes a
 * {@link ClassCastException} when the first non-empty update arrives.
 */
public class Word2VecJobIterator implements JobIterator {

    /** Batch size used by the constructors that do not take an explicit one. */
    private static final int DEFAULT_BATCH_SIZE = 100;

    private final Iterator<List<List<VocabWord>>> sentenceIterator;
    private final WeightLookupTable table;
    private final VocabCache cache;
    // Recorded for reference; batching itself is decided by whoever produced sentenceIterator.
    private final int batchSize;

    /**
     * Primary constructor: iterate over pre-built sentence batches.
     *
     * @param it source of sentence batches; each element becomes one job
     * @param weightLookupTable table updated by the registered listener; expected to be an
     *     {@link InMemoryLookupTable}
     * @param vocabCache vocabulary used to map words to table rows
     * @param stateTracker tracker on which the result listener is registered
     * @param i batch size associated with this iterator
     */
    public Word2VecJobIterator(Iterator<List<List<VocabWord>>> it, WeightLookupTable weightLookupTable, VocabCache vocabCache, StateTracker stateTracker, int i) {
        this.sentenceIterator = it;
        this.table = weightLookupTable;
        this.cache = vocabCache;
        this.batchSize = i;
        addListener(stateTracker);
    }

    /**
     * Iterate over batches of size {@code i} drawn from the vectorizer's index.
     *
     * @param textVectorizer vectorizer whose {@code index()} supplies the sentence batches
     * @param weightLookupTable table updated by the registered listener
     * @param vocabCache vocabulary used to map words to table rows
     * @param stateTracker tracker on which the result listener is registered
     * @param i batch size passed to {@code batchIter}
     */
    public Word2VecJobIterator(TextVectorizer textVectorizer, WeightLookupTable weightLookupTable, VocabCache vocabCache, StateTracker stateTracker, int i) {
        this(textVectorizer.index().batchIter(i), weightLookupTable, vocabCache, stateTracker, i);
    }

    /**
     * Iterate over pre-built sentence batches, recording the default batch size of 100.
     *
     * @param it source of sentence batches; each element becomes one job
     * @param weightLookupTable table updated by the registered listener
     * @param vocabCache vocabulary used to map words to table rows
     * @param stateTracker tracker on which the result listener is registered
     */
    public Word2VecJobIterator(Iterator<List<List<VocabWord>>> it, WeightLookupTable weightLookupTable, VocabCache vocabCache, StateTracker stateTracker) {
        this(it, weightLookupTable, vocabCache, stateTracker, DEFAULT_BATCH_SIZE);
    }

    /**
     * Iterate over batches of the default size (100) drawn from the vectorizer's index.
     *
     * @param textVectorizer vectorizer whose {@code index()} supplies the sentence batches
     * @param weightLookupTable table updated by the registered listener
     * @param vocabCache vocabulary used to map words to table rows
     * @param stateTracker tracker on which the result listener is registered
     */
    public Word2VecJobIterator(TextVectorizer textVectorizer, WeightLookupTable weightLookupTable, VocabCache vocabCache, StateTracker stateTracker) {
        this(textVectorizer, weightLookupTable, vocabCache, stateTracker, DEFAULT_BATCH_SIZE);
    }

    /**
     * Iterate over batches of size {@code i} drawn directly from an inverted index.
     *
     * @param invertedIndex index supplying the sentence batches via {@code batchIter}
     * @param weightLookupTable table updated by the registered listener
     * @param vocabCache vocabulary used to map words to table rows
     * @param stateTracker tracker on which the result listener is registered
     * @param i batch size passed to {@code batchIter}
     */
    public Word2VecJobIterator(InvertedIndex invertedIndex, WeightLookupTable weightLookupTable, VocabCache vocabCache, StateTracker stateTracker, int i) {
        this(invertedIndex.batchIter(i), weightLookupTable, vocabCache, stateTracker, i);
    }

    /**
     * Registers a listener that merges finished {@link Word2VecResult}s into the lookup table.
     *
     * @param stateTracker tracker to register with
     */
    private void addListener(StateTracker stateTracker) {
        stateTracker.addUpdateListener(new NewUpdateListener() {
            @Override
            public void onUpdate(Serializable serializable) {
                // The tracker hands back the completed Job; its result is the per-batch output.
                @SuppressWarnings("unchecked")
                Collection<Word2VecResult> results = (Collection<Word2VecResult>) ((Job) serializable).getResult();
                if (results == null || results.isEmpty()) {
                    return;
                }
                applyResults(results);
            }
        });
    }

    /**
     * Overwrites, for every word in each result's syn0 change map, that word's rows in
     * syn0/syn1 (and syn1Neg when the table has one) with the values carried by the result.
     *
     * @param results non-empty collection of worker results
     * @throws ClassCastException if the lookup table is not an {@link InMemoryLookupTable}
     */
    private void applyResults(Collection<Word2VecResult> results) {
        InMemoryLookupTable lookupTable = (InMemoryLookupTable) table;
        for (Word2VecResult result : results) {
            for (String word : result.getSyn0Change().keySet()) {
                // One vocabulary lookup per word instead of one per matrix.
                int row = cache.indexOf(word);
                lookupTable.getSyn0().putRow(row, (INDArray) result.getSyn0Change().get(word));
                lookupTable.getSyn1().putRow(row, (INDArray) result.getSyn1Change().get(word));
                if (lookupTable.getSyn1Neg() != null) {
                    lookupTable.getSyn1Neg().putRow(row, (INDArray) result.getNegativeChange().get(word));
                }
            }
        }
    }

    /**
     * Wraps one batch of sentences in a {@link Word2VecWork}.
     *
     * @param list batch of sentences
     * @return work item over this iterator's table and vocabulary
     * @throws IllegalStateException if no vocab cache or lookup table is configured
     * @throws IllegalArgumentException if {@code list} is null
     */
    private Word2VecWork create(List<List<VocabWord>> list) {
        if (this.cache == null) {
            throw new IllegalStateException("Unable to create work; no vocab found");
        }
        if (this.table == null) {
            throw new IllegalStateException("Unable to create work; no table found");
        }
        if (list == null) {
            throw new IllegalArgumentException("Unable to create work from null sentence");
        }
        return new Word2VecWork(this.table, this.cache, list);
    }

    /**
     * Builds the next job, tagged with the given worker id.
     *
     * @param str worker id stored on the job
     * @return job wrapping the next sentence batch
     */
    @Override
    public Job next(String str) {
        return new Job(create(this.sentenceIterator.next()), str);
    }

    /**
     * Builds the next job with an empty worker id.
     *
     * @return job wrapping the next sentence batch
     */
    @Override
    public Job next() {
        return new Job(create(this.sentenceIterator.next()), "");
    }

    /** @return whether the underlying sentence iterator has another batch */
    @Override
    public boolean hasNext() {
        return this.sentenceIterator.hasNext();
    }

    /**
     * Intentionally a no-op: the wrapped {@link Iterator} cannot be rewound, so resetting is not
     * supported and iteration continues from the current position.
     */
    @Override
    public void reset() {
    }
}
