package org.deeplearning4j.models.word2vec.iterator;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.CopyOnWriteArrayList;
import org.deeplearning4j.datasets.iterator.DataSetIterator;
import org.deeplearning4j.datasets.iterator.DataSetPreProcessor;
import org.deeplearning4j.models.word2vec.Word2Vec;
import org.deeplearning4j.text.inputsanitation.InputHomogenization;
import org.deeplearning4j.text.movingwindow.Window;
import org.deeplearning4j.text.movingwindow.WindowConverter;
import org.deeplearning4j.text.movingwindow.Windows;
import org.deeplearning4j.text.sentenceiterator.SentencePreProcessor;
import org.deeplearning4j.text.sentenceiterator.labelaware.LabelAwareSentenceIterator;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.util.FeatureUtil;

/* loaded from: input_file:org/deeplearning4j/models/word2vec/iterator/Word2VecDataSetIterator.class */
/**
 * {@link DataSetIterator} that turns labelled sentences into mini-batches of moving windows.
 *
 * <p>Sentences are pulled from a {@link LabelAwareSentenceIterator}, split into {@link Window}s
 * with the tokenizer factory and window size of the supplied {@link Word2Vec} model, and buffered
 * in an internal cache. Each emitted {@link DataSet} has one row per window: the feature row comes
 * from {@link WindowConverter#asExampleMatrix} and the label row from
 * {@link FeatureUtil#toOutcomeVector}, indexed by the position of the window's label in the
 * configured label list.
 */
public class Word2VecDataSetIterator implements DataSetIterator {
    private Word2Vec vec;
    private LabelAwareSentenceIterator iter;
    /** Windows already extracted from sentences but not yet handed out in a batch. */
    private List<Window> cachedWindow;
    private List<String> labels;
    private boolean homogenization;
    private boolean addLabels;
    private int batch;
    private DataSet curr;
    private DataSetPreProcessor preProcessor;

    /**
     * Creates the iterator and installs a sentence pre-processor on {@code iter} that matches the
     * two flags. When both flags are {@code false} the pre-processor of {@code iter} is left alone.
     *
     * @param vec            model supplying the tokenizer factory, window size and layer size
     * @param iter           source of sentences and of the label of the current sentence
     * @param labels         all possible labels; a label's index in this list selects its outcome column
     * @param batch          number of windows returned by the no-argument next call
     * @param homogenization whether each sentence is passed through {@link InputHomogenization}
     * @param addLabels      whether each sentence is wrapped in {@code <label>} ... {@code </label>} markers
     */
    public Word2VecDataSetIterator(Word2Vec vec, LabelAwareSentenceIterator iter, List<String> labels, int batch, boolean homogenization, boolean addLabels) {
        this.vec = vec;
        this.iter = iter;
        this.labels = labels;
        this.batch = batch;
        this.cachedWindow = new CopyOnWriteArrayList<Window>();
        this.addLabels = addLabels;
        this.homogenization = homogenization;
        if (addLabels && homogenization) {
            // Homogenize the text and surround it with space-separated label markers.
            iter.setPreProcessor(new SentencePreProcessor() {
                @Override
                public String preProcess(String sentence) {
                    String label = Word2VecDataSetIterator.this.iter.currentLabel();
                    return "<" + label + "> " + new InputHomogenization(sentence).transform() + " </" + label + ">";
                }
            });
        } else if (addLabels) {
            // Raw text with label markers attached directly (no separating spaces).
            iter.setPreProcessor(new SentencePreProcessor() {
                @Override
                public String preProcess(String sentence) {
                    String label = Word2VecDataSetIterator.this.iter.currentLabel();
                    return "<" + label + ">" + sentence + "</" + label + ">";
                }
            });
        } else if (homogenization) {
            iter.setPreProcessor(new SentencePreProcessor() {
                @Override
                public String preProcess(String sentence) {
                    return new InputHomogenization(sentence).transform();
                }
            });
        }
    }

    /**
     * Creates the iterator with a batch size of 10, homogenization and label markers enabled.
     *
     * @param vec    model supplying the tokenizer factory, window size and layer size
     * @param iter   source of sentences and of the label of the current sentence
     * @param labels all possible labels
     */
    public Word2VecDataSetIterator(Word2Vec vec, LabelAwareSentenceIterator iter, List<String> labels) {
        this(vec, iter, labels, 10);
    }

    /**
     * Creates the iterator with the given batch size, homogenization and label markers enabled.
     *
     * @param vec    model supplying the tokenizer factory, window size and layer size
     * @param iter   source of sentences and of the label of the current sentence
     * @param labels all possible labels
     * @param batch  number of windows returned by the no-argument next call
     */
    public Word2VecDataSetIterator(Word2Vec vec, LabelAwareSentenceIterator iter, List<String> labels, int batch) {
        this(vec, iter, labels, batch, true, true);
    }

    /**
     * Returns a data set of up to {@code num} windows, reading further sentences when the cache
     * does not hold enough.
     *
     * @param num requested number of windows
     * @return a data set with one row per window actually available (at most {@code num}), or
     *         {@code null} when no window is left
     * @throws IllegalStateException if a non-empty sentence yields no windows while refilling
     */
    public DataSet next(int num) {
        if (num <= cachedWindow.size()) {
            return fromCached(num);
        }
        // More was requested than is cached: with no sentences left, hand out what remains.
        if (!iter.hasNext()) {
            return fromCached(cachedWindow.size());
        }
        fillCache(num, true);
        return fromCached(num);
    }

    /**
     * Reads sentences into the window cache until it holds at least {@code wanted} windows or the
     * sentence iterator is exhausted. Empty sentences are skipped.
     *
     * @param wanted             number of cached windows to aim for
     * @param rejectEmptyWindows whether a non-empty sentence producing no windows is an error
     * @throws IllegalStateException if {@code rejectEmptyWindows} is set and a non-empty sentence
     *                               produces no windows
     */
    private void fillCache(int wanted, boolean rejectEmptyWindows) {
        while (cachedWindow.size() < wanted && iter.hasNext()) {
            String sentence = iter.nextSentence();
            if (sentence.isEmpty()) {
                continue;
            }
            List<Window> windows = Windows.windows(sentence, vec.getTokenizerFactory(), vec.getWindow());
            if (rejectEmptyWindows && windows.isEmpty()) {
                throw new IllegalStateException("Empty window on sentence");
            }
            // Every window inherits the label of the sentence it was cut from.
            for (Window window : windows) {
                window.setLabel(iter.currentLabel());
            }
            cachedWindow.addAll(windows);
        }
    }

    /**
     * Removes up to {@code num} windows from the front of the cache and converts them to a data set.
     *
     * @param num maximum number of windows to take
     * @return the data set (after the optional pre-processor has run), or {@code null} when no
     *         window could be taken
     */
    private DataSet fromCached(int num) {
        if (cachedWindow.isEmpty()) {
            fillCache(num, false);
        }
        List<Window> windows = new ArrayList<Window>(num);
        for (int taken = 0; taken < num && !cachedWindow.isEmpty(); taken++) {
            windows.add(cachedWindow.remove(0));
        }
        if (windows.isEmpty()) {
            return null;
        }
        // Size the matrices by the windows actually obtained, not by the request: the final batch
        // may be short, and indexing past the end of the window list would throw.
        int rows = windows.size();
        int outcomeCount = labels.size();
        INDArray features = Nd4j.create(rows, inputColumns());
        INDArray outcomes = Nd4j.create(rows, outcomeCount);
        for (int row = 0; row < rows; row++) {
            Window window = windows.get(row);
            features.putRow(row, WindowConverter.asExampleMatrix(window, vec));
            outcomes.putRow(row, FeatureUtil.toOutcomeVector(labels.indexOf(window.getLabel()), outcomeCount));
        }
        DataSet dataSet = new DataSet(features, outcomes);
        if (preProcessor != null) {
            preProcessor.preProcess(dataSet);
        }
        return dataSet;
    }

    /**
     * Not supported by this iterator.
     *
     * @throws UnsupportedOperationException always
     */
    public int totalExamples() {
        throw new UnsupportedOperationException();
    }

    /**
     * @return the feature width of one example: the model's layer size times its window size
     */
    public int inputColumns() {
        return vec.getLayerSize() * vec.getWindow();
    }

    /**
     * @return the number of configured labels
     */
    public int totalOutcomes() {
        return labels.size();
    }

    /** Resets the underlying sentence iterator and discards all cached windows. */
    public void reset() {
        iter.reset();
        cachedWindow.clear();
    }

    /**
     * @return the batch size used by the no-argument next call
     */
    public int batch() {
        return batch;
    }

    /**
     * @return always 0; no position is tracked
     */
    public int cursor() {
        return 0;
    }

    /**
     * @return always 0; the number of examples is not tracked
     */
    public int numExamples() {
        return 0;
    }

    /**
     * Sets the pre-processor applied to every data set before it is returned.
     *
     * @param dataSetPreProcessor the pre-processor, or {@code null} to apply none
     */
    public void setPreProcessor(DataSetPreProcessor dataSetPreProcessor) {
        this.preProcessor = dataSetPreProcessor;
    }

    /**
     * @return {@code true} while the sentence iterator has more sentences or windows remain cached
     */
    public boolean hasNext() {
        return iter.hasNext() || !cachedWindow.isEmpty();
    }

    /* renamed from: next, reason: merged with bridge method [inline-methods] */
    /**
     * Returns the next data set using the configured batch size.
     *
     * @return the result of {@link #next(int)} called with {@link #batch()}
     */
    public DataSet m9next() {
        return next(batch);
    }

    /**
     * Not supported by this iterator.
     *
     * @throws UnsupportedOperationException always
     */
    public void remove() {
        throw new UnsupportedOperationException();
    }
}
