package org.deeplearning4j.datasets.loader;

import java.io.File;
import java.net.URL;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import org.apache.commons.io.FileUtils;
import org.deeplearning4j.bagofwords.vectorizer.BagOfWordsVectorizer;
import org.deeplearning4j.bagofwords.vectorizer.TextVectorizer;
import org.deeplearning4j.bagofwords.vectorizer.TfidfVectorizer;
import org.deeplearning4j.datasets.fetchers.BaseDataFetcher;
import org.deeplearning4j.text.sentenceiterator.labelaware.LabelAwareFileSentenceIterator;
import org.deeplearning4j.text.tokenization.tokenizerfactory.UimaTokenizerFactory;
import org.deeplearning4j.util.ArchiveUtils;
import org.nd4j.linalg.dataset.DataSet;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/datasets/loader/ReutersNewsGroupsLoader.class */
/**
 * Data fetcher over a text corpus downloaded from {@link #NEWSGROUP_URL} (the 20 Newsgroups
 * "20news-18828" archive, despite the class name) into {@code ~/reuters}.
 *
 * <p>Each immediate subdirectory of the corpus root is used as a label, and the documents are
 * vectorized once, eagerly, in the constructor. {@link #fetch(int)} then serves consecutive
 * slices of that vectorized {@link DataSet}.
 */
public class ReutersNewsGroupsLoader extends BaseDataFetcher {
    private TextVectorizer textVectorizer;
    /** Whether the corpus was vectorized with {@link TfidfVectorizer} (true) or {@link BagOfWordsVectorizer}. */
    private boolean tfidf;
    public static final String NEWSGROUP_URL = "http://qwone.com/~jason/20Newsgroups/20news-18828.tar.gz";
    /** File name the downloaded archive is stored under, inside the corpus root. */
    private static final String ARCHIVE_NAME = "20news-18828.tar.gz";
    /** Top-level directory produced by extracting the archive; its contents are hoisted to the root. */
    private static final String EXTRACTED_DIR_NAME = "20news-18828";
    private File reutersRootDir;
    private static final Logger log = LoggerFactory.getLogger(ReutersNewsGroupsLoader.class);
    /** The fully vectorized corpus; {@link #fetch(int)} slices batches out of it. */
    private DataSet load;

    /**
     * Ensures the corpus is present locally (downloading it if needed) and vectorizes it.
     *
     * @param tfidf {@code true} to vectorize with {@link TfidfVectorizer}, {@code false} to use
     *     {@link BagOfWordsVectorizer}
     * @throws Exception if the corpus cannot be downloaded, extracted, listed, or vectorized
     */
    public ReutersNewsGroupsLoader(boolean tfidf) throws Exception {
        this.tfidf = tfidf;
        getIfNotExists();
        LabelAwareFileSentenceIterator sentenceIterator = new LabelAwareFileSentenceIterator(this.reutersRootDir);
        // Sorted so that label indices are stable regardless of filesystem listing order.
        List<String> labels = categoryNames(this.reutersRootDir);
        UimaTokenizerFactory tokenizerFactory = new UimaTokenizerFactory();
        if (tfidf) {
            this.textVectorizer = new TfidfVectorizer.Builder().iterate(sentenceIterator).labels(labels).tokenize(tokenizerFactory).build();
        } else {
            this.textVectorizer = new BagOfWordsVectorizer.Builder().iterate(sentenceIterator).labels(labels).tokenize(tokenizerFactory).build();
        }
        this.load = this.textVectorizer.vectorize();
        // Let the inherited bookkeeping (e.g. hasMore()) see how much data is available.
        this.totalExamples = this.load.numExamples();
    }

    /**
     * Returns the names of the immediate subdirectories of {@code root}, sorted.
     *
     * @throws IllegalStateException if {@code root} cannot be listed (not a directory or I/O error)
     */
    private static List<String> categoryNames(File root) {
        File[] children = root.listFiles();
        if (children == null) {
            throw new IllegalStateException("Unable to list corpus directory " + root.getAbsolutePath());
        }
        List<String> names = new ArrayList<String>();
        for (File child : children) {
            if (child.isDirectory()) {
                names.add(child.getName());
            }
        }
        Collections.sort(names);
        return names;
    }

    /**
     * Sets {@link #reutersRootDir} to {@code ~/reuters} and, unless it already holds extracted
     * category directories, downloads and extracts {@link #NEWSGROUP_URL} into it.
     *
     * @throws IllegalStateException if the root is not a directory, cannot be created, or contains
     *     no category directories after extraction
     * @throws Exception if downloading or extracting the archive fails
     */
    private void getIfNotExists() throws Exception {
        this.reutersRootDir = new File(System.getProperty("user.home") + File.separator + "reuters");
        File extracted = new File(this.reutersRootDir, EXTRACTED_DIR_NAME);

        if (this.reutersRootDir.exists() && !this.reutersRootDir.isDirectory()) {
            throw new IllegalStateException(this.reutersRootDir.getAbsolutePath() + " exists but is not a directory");
        }
        // Only skip the download when a previous run completed: category dirs are present and the
        // extraction staging directory was cleaned up. An empty or half-populated root is retried.
        if (this.reutersRootDir.isDirectory()
                && !extracted.exists()
                && !categoryNames(this.reutersRootDir).isEmpty()) {
            return;
        }
        if (!this.reutersRootDir.isDirectory() && !this.reutersRootDir.mkdirs()) {
            throw new IllegalStateException("Unable to create directory " + this.reutersRootDir.getAbsolutePath());
        }
        if (extracted.exists()) {
            // Leftover from an interrupted earlier extraction.
            FileUtils.deleteDirectory(extracted);
        }

        File archive = new File(this.reutersRootDir, ARCHIVE_NAME);
        log.info("Downloading {} to {}", NEWSGROUP_URL, archive.getAbsolutePath());
        try {
            // copyURLToFile creates/overwrites the destination, so no pre-delete/createNewFile is needed.
            FileUtils.copyURLToFile(new URL(NEWSGROUP_URL), archive);
            ArchiveUtils.unzipFileTo(archive.getAbsolutePath(), this.reutersRootDir.getAbsolutePath());
        } finally {
            // Never leave a (possibly partial) archive behind.
            FileUtils.deleteQuietly(archive);
        }
        // The archive extracts into a single top-level directory; hoist its contents into the root.
        FileUtils.copyDirectory(extracted, this.reutersRootDir);
        FileUtils.deleteDirectory(extracted);

        if (categoryNames(this.reutersRootDir).isEmpty()) {
            throw new IllegalStateException("No files found!");
        }
    }

    /**
     * Loads the next batch of at most {@code numExamples} vectorized documents into {@code curr}
     * and advances {@code cursor} past them. The batch is smaller only when fewer examples remain.
     *
     * @param numExamples maximum batch size; must be positive
     * @throws IllegalArgumentException if {@code numExamples <= 0}
     * @throws IllegalStateException if every example has already been fetched
     */
    public void fetch(int numExamples) {
        if (numExamples <= 0) {
            throw new IllegalArgumentException("numExamples must be positive, got " + numExamples);
        }
        int available = this.load.numExamples();
        if (this.cursor >= available) {
            throw new IllegalStateException(
                    "No more examples to fetch (cursor=" + this.cursor + ", total=" + available + ")");
        }
        // Compare against the remaining count rather than computing cursor + numExamples (overflow-safe).
        int batchSize = Math.min(numExamples, available - this.cursor);
        List<DataSet> batch = new ArrayList<DataSet>(batchSize);
        for (int taken = 0; taken < batchSize; taken++) {
            batch.add(this.load.get(this.cursor));
            this.cursor++;
        }
        this.curr = DataSet.merge(batch);
    }
}
