package org.deeplearning4j.bagofwords.vectorizer;

import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.io.Serializable;
import java.nio.charset.StandardCharsets;
import java.util.List;
import org.apache.commons.io.FileUtils;
import org.apache.commons.io.IOUtils;
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
import org.deeplearning4j.text.documentiterator.DocumentIterator;
import org.deeplearning4j.text.invertedindex.InvertedIndex;
import org.deeplearning4j.text.sentenceiterator.SentenceIterator;
import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;
import org.deeplearning4j.util.MathUtils;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.util.FeatureUtil;

/* loaded from: input_file:org/deeplearning4j/bagofwords/vectorizer/TfidfVectorizer.class */
/**
 * Text vectorizer that turns a piece of text into a vector of TF-IDF weights,
 * one slot per word known to the vocabulary cache, and pairs it with an
 * outcome vector for a given label.
 *
 * <p>All text read from streams and files is decoded as UTF-8 so that results
 * do not depend on the platform default charset.
 */
public class TfidfVectorizer extends BaseTextVectorizer implements Serializable {

    /**
     * Builder that creates a {@link TfidfVectorizer} from the settings held by
     * the parent builder.
     */
    public static class Builder extends org.deeplearning4j.bagofwords.vectorizer.Builder {
        /**
         * Creates the vectorizer from the builder's current settings.
         *
         * @return a new {@link TfidfVectorizer}
         */
        @Override
        public TextVectorizer build() {
            return new TfidfVectorizer(
                    this.cache,
                    this.tokenizerFactory,
                    this.stopWords,
                    this.layerSize,
                    this.minWordFrequency,
                    this.docIter,
                    this.sentenceIterator,
                    this.labels,
                    this.index,
                    this.batchSize,
                    this.sample,
                    this.stem);
        }
    }

    /** Creates an unconfigured vectorizer. */
    public TfidfVectorizer() {
    }

    /**
     * Creates a vectorizer; every argument is forwarded unchanged to the
     * superclass constructor. Parameter names mirror the builder fields passed
     * in by {@link Builder#build()}.
     */
    protected TfidfVectorizer(VocabCache cache, TokenizerFactory tokenizerFactory, List<String> stopWords,
            int layerSize, int minWordFrequency, DocumentIterator docIter, SentenceIterator sentenceIterator,
            List<String> labels, InvertedIndex index, int batchSize, double sample, boolean stem) {
        super(cache, tokenizerFactory, stopWords, layerSize, minWordFrequency, docIter, sentenceIterator,
                labels, index, batchSize, sample, stem);
    }

    /**
     * Combines the term-frequency and inverse-document-frequency values of a
     * word via {@code MathUtils.tfidf}.
     *
     * @param word the word to score
     * @return the TF-IDF weight for the word
     */
    private double tfidfWord(String word) {
        return MathUtils.tfidf(tfForWord(word), idfForWord(word));
    }

    /**
     * Term-frequency component, computed by {@code MathUtils.tf} from the
     * frequency the vocabulary cache reports for the word.
     */
    private double tfForWord(String word) {
        return MathUtils.tf(this.cache.wordFrequency(word));
    }

    /**
     * Inverse-document-frequency component, computed by {@code MathUtils.idf}
     * from the cache's total document count and the number of documents the
     * word appeared in.
     */
    private double idfForWord(String word) {
        return MathUtils.idf(this.cache.totalNumberOfDocs(), this.cache.docAppearedIn(word));
    }

    /**
     * Builds the outcome vector for a label from its position in
     * {@code labels}.
     *
     * <p>A label that is not in the list yields index {@code -1}, which is
     * passed through to {@code FeatureUtil.toOutcomeVector} unchanged.
     */
    private INDArray outcomeFor(String label) {
        return FeatureUtil.toOutcomeVector(this.labels.indexOf(label), this.labels.size());
    }

    /**
     * Tokenizes the text and writes a TF-IDF weight into the vocabulary slot
     * of every token the cache knows.
     *
     * <p>The weight comes from the cache's statistics for the word (see
     * {@link #tfidfWord(String)}), so a token repeated in the input simply
     * rewrites the same value into the same slot.
     *
     * @param text the text to vectorize
     * @return an array created with {@code Nd4j.create(1, cache.numWords())}
     *         holding the weights; slots of words absent from the text are
     *         never written
     */
    private INDArray tfidfForInput(String text) {
        INDArray weights = Nd4j.create(1, this.cache.numWords());
        for (String token : this.tokenizerFactory.create(text).getTokens()) {
            int slot = this.cache.indexOf(token);
            // A negative index means the token is not in the vocabulary; skip it.
            if (slot >= 0) {
                weights.putScalar(slot, tfidfWord(token));
            }
        }
        return weights;
    }

    /**
     * Reads the whole stream, decodes it as UTF-8 and vectorizes the text.
     * The stream is not closed here; that remains the caller's job.
     *
     * @param inputStream the stream to read the text from
     * @return the TF-IDF vector for the stream's content
     * @throws RuntimeException wrapping the {@link IOException} if reading fails
     */
    private INDArray tfidfForInput(InputStream inputStream) {
        final String text;
        try {
            // Explicit charset: the no-arg String(byte[]) constructor would use
            // the platform default and make results machine-dependent.
            text = new String(IOUtils.toByteArray(inputStream), StandardCharsets.UTF_8);
        } catch (IOException e) {
            throw new RuntimeException("Unable to read text from input stream", e);
        }
        return tfidfForInput(text);
    }

    /**
     * Vectorizes the UTF-8 text read from a stream and pairs it with the
     * outcome vector of the given label.
     *
     * @param inputStream the stream to read the text from (not closed by this method)
     * @param label the label whose position in {@code labels} selects the outcome
     * @return a data set of (TF-IDF features, outcome vector)
     */
    @Override
    public DataSet vectorize(InputStream inputStream, String label) {
        return new DataSet(tfidfForInput(inputStream), outcomeFor(label));
    }

    /**
     * Vectorizes the given text and pairs it with the outcome vector of the
     * given label.
     *
     * @param text the text to vectorize
     * @param label the label whose position in {@code labels} selects the outcome
     * @return a data set of (TF-IDF features, outcome vector)
     */
    @Override
    public DataSet vectorize(String text, String label) {
        return new DataSet(tfidfForInput(text), outcomeFor(label));
    }

    /**
     * Reads the file as UTF-8 text and vectorizes it with the given label.
     *
     * @param file the file to read
     * @param label the label whose position in {@code labels} selects the outcome
     * @return a data set of (TF-IDF features, outcome vector)
     * @throws RuntimeException wrapping the {@link IOException} if the file cannot be read
     */
    @Override
    public DataSet vectorize(File file, String label) {
        final String text;
        try {
            text = FileUtils.readFileToString(file, StandardCharsets.UTF_8.name());
        } catch (IOException e) {
            throw new RuntimeException("Unable to read " + file, e);
        }
        return vectorize(text, label);
    }

    /**
     * Converts text to its TF-IDF vector without attaching a label.
     *
     * @param text the text to vectorize
     * @return the TF-IDF vector for the text
     */
    @Override
    public INDArray transform(String text) {
        return tfidfForInput(text);
    }

    /**
     * No-argument vectorization is not implemented by this class.
     *
     * @return always {@code null}
     */
    public DataSet vectorize() {
        return null;
    }
}
