package org.datavec.nlp.vectorizer;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.datavec.api.conf.Configuration;
import org.datavec.api.records.Record;
import org.datavec.api.records.metadata.RecordMetaDataURI;
import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.vector.Vectorizer;
import org.datavec.api.writable.NDArrayWritable;
import org.datavec.api.writable.Writable;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Counter;

/**
 * TF-IDF vectorizer that turns each text record into a dense {@link INDArray} with one entry per
 * vocabulary word.
 *
 * <p>Whether the IDF term is smoothed is controlled by the {@link #SMOOTH_IDF} configuration key,
 * which defaults to {@code true} in {@link #initialize(Configuration)}.
 */
public class TfidfVectorizer extends AbstractTfidfVectorizer<INDArray> {
    /** Configuration key toggling IDF smoothing; read in {@link #initialize(Configuration)}. */
    public static final String SMOOTH_IDF = "org.datavec.nlp.TfidfVectorizer.smooth_idf";

    /** Smoothing flag passed to the cache's TF-IDF computation; set by {@code initialize}. */
    protected boolean smooth_idf;

    /**
     * Builds the TF-IDF vector for a single document.
     *
     * @param args {@code args[0]} must be the document's word-frequency {@link Counter}
     * @return a vector whose entry {@code i} is the TF-IDF weight of the vocabulary word at index
     *     {@code i}; its length equals the vocabulary size
     */
    @Override
    public INDArray createVector(Object[] args) {
        @SuppressWarnings("unchecked") // callers pass the Counter built by wordFrequenciesForRecord
        Counter<String> wordCounts = (Counter<String>) args[0];
        // Read the vocabulary size once instead of on every loop iteration.
        int vocabSize = this.cache.vocabWords().size();
        double[] weights = new double[vocabSize];
        for (int i = 0; i < vocabSize; i++) {
            String word = this.cache.wordAt(i);
            weights[i] = this.cache.tfidf(word, wordCounts.getCount(word), this.smooth_idf);
        }
        return Nd4j.create(weights);
    }

    /**
     * Fits the vocabulary on every record of {@code reader} and returns their TF-IDF matrix.
     *
     * @param reader source of text records
     * @return a matrix with one row per record and one column per vocabulary word
     * @throws IllegalStateException if the reader yields no records
     */
    @Override
    public INDArray fitTransform(RecordReader reader) {
        return fitTransform(reader, null);
    }

    /**
     * Fits the vocabulary on every record of {@code reader}, then transforms each record.
     *
     * @param reader source of text records
     * @param callBack optional; when non-null it receives, for every record, a new record holding
     *     the TF-IDF vector followed by the last writable of the source record (presumably the
     *     label — TODO confirm), with metadata carrying the source URI and the reader class
     * @return a matrix with one row per record, in the order the records were seen during fitting,
     *     and one column per vocabulary word
     * @throws IllegalStateException if the reader yields no records
     */
    @Override
    public INDArray fitTransform(RecordReader reader, Vectorizer.RecordCallBack callBack) {
        // Collect the records seen during fitting so they can be transformed once the vocabulary
        // has been built.
        final List<Record> records = new ArrayList<>();
        fit(reader, new Vectorizer.RecordCallBack() {
            @Override
            public void onRecord(Record record) {
                records.add(record);
            }
        });
        if (records.isEmpty()) {
            throw new IllegalStateException("No records found!");
        }

        INDArray result = Nd4j.create(records.size(), this.cache.vocabWords().size());
        int row = 0;
        for (Record record : records) {
            INDArray vector = transform(record);
            result.putRow(row++, vector);
            // Only build the wrapped record when someone will actually receive it.
            if (callBack != null) {
                callBack.onRecord(toTransformedRecord(record, vector, reader));
            }
        }
        return result;
    }

    /**
     * Alias kept for source compatibility with the decompiler-generated name.
     *
     * @deprecated use {@link #fitTransform(RecordReader, Vectorizer.RecordCallBack)}
     */
    @Deprecated
    public INDArray m6fitTransform(RecordReader recordReader, Vectorizer.RecordCallBack recordCallBack) {
        return fitTransform(recordReader, recordCallBack);
    }

    /**
     * Transforms a single record into its TF-IDF vector using the current vocabulary.
     *
     * @param record the text record to vectorize
     * @return the record's TF-IDF vector (see {@link #createVector(Object[])})
     */
    @Override
    public INDArray transform(Record record) {
        return createVector(new Object[] {wordFrequenciesForRecord(record.getRecord())});
    }

    /**
     * Initializes the base vectorizer and reads the {@link #SMOOTH_IDF} flag (default {@code true}).
     *
     * @param configuration configuration to read settings from
     */
    @Override
    public void initialize(Configuration configuration) {
        super.initialize(configuration);
        this.smooth_idf = configuration.getBoolean(SMOOTH_IDF, true);
    }

    /**
     * Wraps a TF-IDF vector as the record handed to a fit-transform callback.
     *
     * @param source the original record; its last writable is carried over
     * @param vector the TF-IDF vector computed for {@code source}
     * @param reader the reader the record came from; its class is stored in the metadata
     * @return a record of {@code [vector, last writable of source]}
     */
    private static org.datavec.api.records.impl.Record toTransformedRecord(
            Record source, INDArray vector, RecordReader reader) {
        List<Writable> values = source.getRecord();
        Writable last = values.get(values.size() - 1);
        return new org.datavec.api.records.impl.Record(
                Arrays.<Writable>asList(new NDArrayWritable(vector), last),
                new RecordMetaDataURI(source.getMetaData().getURI(), reader.getClass()));
    }
}
