package org.datavec.nlp.reader;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import org.datavec.api.conf.Configuration;
import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.records.reader.impl.FileRecordReader;
import org.datavec.api.split.InputSplit;
import org.datavec.api.vector.Vectorizer;
import org.datavec.api.writable.IntWritable;
import org.datavec.api.writable.Writable;
import org.datavec.common.RecordConverter;
import org.datavec.nlp.vectorizer.TfidfVectorizer;
import org.nd4j.linalg.api.ndarray.INDArray;

/* loaded from: input_file:org/datavec/nlp/reader/TfidfRecordReader.class */
/**
 * A {@link FileRecordReader} that serves TF-IDF vectors instead of raw file contents.
 *
 * <p>{@link #initialize(Configuration, InputSplit)} eagerly reads every record of the split
 * through the superclass, converts it with a {@link TfidfVectorizer} and caches the result.
 * After that, {@link #hasNext()} and {@link #next()} iterate over the cache.
 *
 * <p>Two modes are supported:
 * <ul>
 *   <li>No vectorizer set: a new one is created and fitted on the split.</li>
 *   <li>A vectorizer was supplied via {@link #setTfidfVectorizer(TfidfVectorizer)}: it is
 *       reused as-is and only {@code transform} is applied to each record.</li>
 * </ul>
 */
public class TfidfRecordReader extends FileRecordReader {
    /** Vectorizer used to build the cache; created on demand when none was supplied. */
    private TfidfVectorizer tfidfVectorizer;
    /** Cached vectorized records, one entry per record read from the split. */
    private Collection<Collection<Writable>> records = new ArrayList<Collection<Writable>>();
    /** Labels collected while building {@link #records}, in the same order. */
    private List<Integer> recordLabels = new ArrayList<Integer>();
    private Iterator<Integer> labelIter;
    /** Cursor over {@link #records}; {@code null} means "delegate to the superclass". */
    private Iterator<Collection<Writable>> recordIter;
    private int numFeatures;

    /**
     * Initializes the reader with a default {@link Configuration}.
     *
     * @param inputSplit the split to read
     * @throws IOException          propagated from the underlying initialization
     * @throws InterruptedException propagated from the underlying initialization
     */
    public void initialize(InputSplit inputSplit) throws IOException, InterruptedException {
        initialize(new Configuration(), inputSplit);
    }

    /**
     * Initializes the underlying file reader and builds the cache of TF-IDF records.
     *
     * <p>Any cache left over from an earlier call is discarded first, so the reader can be
     * initialized more than once.
     *
     * @param configuration configuration passed to the superclass and, when a vectorizer has
     *                      to be created here, to that vectorizer
     * @param inputSplit    the split to read
     * @throws IOException          propagated from the superclass initialization
     * @throws InterruptedException propagated from the superclass initialization
     */
    public void initialize(Configuration configuration, InputSplit inputSplit) throws IOException, InterruptedException {
        // While recordIter is non-null, hasNext()/next() serve the cache rather than the files.
        // Reset it (and the cached data) so the reads below go to the freshly initialized
        // superclass and labels from a previous run cannot shift the label/record pairing.
        this.recordIter = null;
        this.labelIter = null;
        this.records = new ArrayList<Collection<Writable>>();
        this.recordLabels = new ArrayList<Integer>();

        super.initialize(configuration, inputSplit);

        if (this.tfidfVectorizer != null) {
            // A vectorizer was supplied: skip fitting and only transform each record.
            while (hasNext()) {
                Collection<Writable> fileContents = next();
                if (this.appendLabel) {
                    this.recordLabels.add(Integer.valueOf(new IntWritable(getCurrentLabel()).toInt()));
                }
                this.records.add(RecordConverter.toRecord(this.tfidfVectorizer.transform(fileContents)));
            }
        } else {
            this.tfidfVectorizer = new TfidfVectorizer();
            this.tfidfVectorizer.initialize(configuration);
            // NOTE(review): "m6fitTransform" looks like a decompiler-renamed fitTransform;
            // confirm against TfidfVectorizer before renaming the call.
            INDArray tfidfMatrix = this.tfidfVectorizer.m6fitTransform((RecordReader) this, new Vectorizer.RecordCallBack() { // from class: org.datavec.nlp.reader.TfidfRecordReader.1
                public void onRecord(Collection<Writable> collection) {
                    Iterator<Writable> fields = collection.iterator();
                    fields.next(); // skip the first element
                    // the second element is stored as the record's label
                    TfidfRecordReader.this.recordLabels.add(Integer.valueOf(fields.next().toInt()));
                }
            });
            // One cached record per row of the fitted matrix.
            int rowCount = tfidfMatrix.rows();
            for (int row = 0; row < rowCount; row++) {
                this.records.add(RecordConverter.toRecord(tfidfMatrix.getRow(row)));
            }
            this.numFeatures = tfidfMatrix.columns();
        }

        this.labelIter = this.recordLabels.iterator();
        this.recordIter = this.records.iterator();
    }

    /**
     * Returns the next record.
     *
     * <p>Before the cache exists this delegates to the superclass. Afterwards it returns the
     * next cached vector, followed by its label when {@code appendLabel} is set.
     *
     * @return the next record; when served from the cache, a fresh list owned by the caller
     */
    public Collection<Writable> next() {
        if (this.recordIter == null) {
            return super.next();
        }
        // Copy so that appending the label (or any caller-side change) leaves the cache intact.
        List<Writable> record = new ArrayList<Writable>(this.recordIter.next());
        if (this.appendLabel) {
            record.add(new IntWritable(this.labelIter.next().intValue()));
        }
        return record;
    }

    /**
     * @return whether another record is available, from the cache once it exists and from
     *         the superclass otherwise
     */
    public boolean hasNext() {
        if (this.recordIter == null) {
            return super.hasNext();
        }
        return this.recordIter.hasNext();
    }

    /** Does nothing. */
    public void close() throws IOException {
    }

    /** Stores the given configuration on this reader. */
    public void setConf(Configuration configuration) {
        this.conf = configuration;
    }

    /** @return the configuration currently stored on this reader */
    public Configuration getConf() {
        return this.conf;
    }

    /** @return the vectorizer in use, or {@code null} if none has been set or created yet */
    public TfidfVectorizer getTfidfVectorizer() {
        return this.tfidfVectorizer;
    }

    /**
     * Supplies a vectorizer to reuse. When one is set before {@code initialize}, no fitting
     * takes place and records are only transformed with it.
     *
     * @param tfidfVectorizer the vectorizer to reuse
     */
    public void setTfidfVectorizer(TfidfVectorizer tfidfVectorizer) {
        this.tfidfVectorizer = tfidfVectorizer;
    }

    /**
     * @return the column count of the matrix produced the last time this reader fitted a
     *         vectorizer in {@code initialize}; {@code 0} if it has never done so
     */
    public int getNumFeatures() {
        return this.numFeatures;
    }
}
