package org.datavec.nlp.vectorizer;

import java.io.ByteArrayOutputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.util.Collection;
import org.datavec.api.berkeley.Counter;
import org.datavec.api.conf.Configuration;
import org.datavec.api.records.Record;
import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.vector.Vectorizer;
import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable;
import org.datavec.nlp.metadata.DefaultVocabCache;
import org.datavec.nlp.metadata.VocabCache;
import org.datavec.nlp.stopwords.StopWords;
import org.datavec.nlp.tokenization.tokenizer.Tokenizer;
import org.datavec.nlp.tokenization.tokenizerfactory.TokenizerFactory;

/**
 * Base class for vectorizers that tokenize the {@link Text} fields of records.
 *
 * <p>During {@link #fit(RecordReader, Vectorizer.RecordCallBack)} each record's text is tokenized
 * with the configured {@link TokenizerFactory}. The document count in the {@link VocabCache} is
 * incremented, and the tokens are handed to {@link #doWithTokens(Tokenizer)}, which subclasses
 * implement.
 *
 * <p>Configuration keys read by {@link #initialize(Configuration)}:
 * <ul>
 *   <li>{@link #MIN_WORD_FREQUENCY}: integer, defaults to 5.</li>
 *   <li>{@link #STOP_WORDS}: string collection; falls back to {@link StopWords#getStopWords()}
 *       when absent or empty.</li>
 *   <li>{@link #VOCAB_CACHE}: fully qualified {@link VocabCache} class name, defaults to
 *       {@link DefaultVocabCache}.</li>
 *   <li>{@link #TOKENIZER}: declared here; presumably consumed by subclasses'
 *       {@link #createTokenizerFactory(Configuration)} (not read in this class).</li>
 * </ul>
 *
 * @param <VECTOR_TYPE> the vector representation produced by the concrete vectorizer
 */
public abstract class TextVectorizer<VECTOR_TYPE> implements Vectorizer<VECTOR_TYPE> {
    protected TokenizerFactory tokenizerFactory;
    // Set from MIN_WORD_FREQUENCY in initialize(); not used in this class itself.
    // NOTE(review): presumably consulted by subclasses / the vocab cache — confirm.
    protected int minWordFrequency = 0;
    public static final String MIN_WORD_FREQUENCY = "org.nd4j.nlp.minwordfrequency";
    public static final String STOP_WORDS = "org.nd4j.nlp.stopwords";
    public static final String TOKENIZER = "org.datavec.nlp.tokenizerfactory";
    public static final String VOCAB_CACHE = "org.datavec.nlp.vocabcache";
    protected Collection<String> stopWords;
    protected VocabCache cache;

    /**
     * Reads tokenizer, frequency, stop-word and vocab-cache settings from {@code configuration}.
     *
     * @param configuration source of the settings listed in the class documentation
     * @throws RuntimeException if the configured {@link VocabCache} class cannot be loaded,
     *     does not implement {@link VocabCache}, cannot be instantiated through its no-arg
     *     constructor, or fails to initialize
     */
    public void initialize(Configuration configuration) {
        this.tokenizerFactory = createTokenizerFactory(configuration);
        this.minWordFrequency = configuration.getInt(MIN_WORD_FREQUENCY, 5);
        this.stopWords = configuration.getStringCollection(STOP_WORDS);
        if (this.stopWords == null || this.stopWords.isEmpty()) {
            this.stopWords = StopWords.getStopWords();
        }
        String cacheClassName = configuration.get(VOCAB_CACHE, DefaultVocabCache.class.getName());
        try {
            // getDeclaredConstructor().newInstance() replaces the deprecated Class.newInstance(),
            // which could rethrow checked constructor exceptions undeclared.
            this.cache = Class.forName(cacheClassName)
                    .asSubclass(VocabCache.class)
                    .getDeclaredConstructor()
                    .newInstance();
            this.cache.initialize(configuration);
        } catch (Exception e) {
            throw new RuntimeException("Unable to create VocabCache of type " + cacheClassName, e);
        }
    }

    /**
     * Fits this vectorizer on every record in {@code recordReader}, without a callback.
     *
     * @param recordReader the records to consume; it is read until {@code hasNext()} is false
     */
    public void fit(RecordReader recordReader) {
        fit(recordReader, null);
    }

    /**
     * Tokenizes every remaining record and passes its tokens to {@link #doWithTokens(Tokenizer)}.
     *
     * <p>For each record, the vocab cache's document count is incremented by one before
     * {@code doWithTokens} runs. {@code recordCallBack} is invoked afterwards.
     *
     * @param recordReader   the records to consume; it is read until {@code hasNext()} is false
     * @param recordCallBack optional per-record hook; may be {@code null}
     */
    public void fit(RecordReader recordReader, Vectorizer.RecordCallBack recordCallBack) {
        while (recordReader.hasNext()) {
            Record nextRecord = recordReader.nextRecord();
            Tokenizer tokenizer = this.tokenizerFactory.create(toString(nextRecord.getRecord()));
            this.cache.incrementNumDocs(1.0d);
            doWithTokens(tokenizer);
            if (recordCallBack != null) {
                recordCallBack.onRecord(nextRecord);
            }
        }
    }

    /**
     * Counts how often each token occurs in the text fields of a single record.
     *
     * <p>Kept {@code public} for compatibility with existing callers.
     *
     * @param collection the record's writables; only {@link Text} values contribute
     * @return a counter mapping each token to its number of occurrences
     */
    public Counter<String> wordFrequenciesForRecord(Collection<Writable> collection) {
        Tokenizer tokenizer = this.tokenizerFactory.create(toString(collection));
        Counter<String> counter = new Counter<>();
        while (tokenizer.hasMoreTokens()) {
            counter.incrementCount(tokenizer.nextToken(), 1.0d);
        }
        return counter;
    }

    /**
     * Joins the {@link Text} values of a record into one string for tokenization.
     *
     * <p>Each value is decoded via {@code Text.toString()}, not serialized with
     * {@code Writable.write}. The serialized form prefixes every value with a binary length,
     * which would otherwise leak into the tokenized text. Consecutive non-empty values are
     * separated by a single space, so tokens from adjacent fields are not glued together.
     * Non-{@link Text} writables are ignored.
     *
     * @param collection the record's writables
     * @return the concatenated text; empty if the record has no text fields
     */
    protected String toString(Collection<Writable> collection) {
        StringBuilder text = new StringBuilder();
        for (Writable writable : collection) {
            if (writable instanceof Text) {
                if (text.length() > 0) {
                    text.append(' ');
                }
                text.append(writable.toString());
            }
        }
        return text.toString();
    }

    /**
     * Consumes the tokens of one record during {@link #fit}.
     *
     * @param tokenizer tokenizer positioned at the start of the record's text
     */
    public abstract void doWithTokens(Tokenizer tokenizer);

    /**
     * Builds the tokenizer factory used for every record.
     *
     * @param configuration the configuration passed to {@link #initialize(Configuration)}
     * @return the tokenizer factory to use
     */
    public abstract TokenizerFactory createTokenizerFactory(Configuration configuration);
}
