/*
 * Decompiled with CFR 0.152.
 */
package org.datavec.nlp.vectorizer;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.datavec.api.conf.Configuration;
import org.datavec.api.records.impl.Record;
import org.datavec.api.records.metadata.RecordMetaData;
import org.datavec.api.records.metadata.RecordMetaDataURI;
import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.vector.Vectorizer;
import org.datavec.api.writable.NDArrayWritable;
import org.datavec.api.writable.Writable;
import org.datavec.nlp.vectorizer.AbstractTfidfVectorizer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Counter;

/**
 * TF-IDF vectorizer that produces one dense {@link INDArray} row per document.
 *
 * <p>Whether the cache's TF-IDF computation uses smoothing is controlled by the
 * {@link #SMOOTH_IDF} configuration key, read in {@link #initialize(Configuration)}
 * (defaults to {@code true}).
 */
public class TfidfVectorizer
extends AbstractTfidfVectorizer<INDArray> {
    /** Configuration key for the IDF-smoothing flag; read by {@link #initialize(Configuration)}. */
    public static final String SMOOTH_IDF = "org.datavec.nlp.TfidfVectorizer.smooth_idf";
    /** Smoothing flag passed to {@code cache.tfidf(...)}; populated from {@link #SMOOTH_IDF}. */
    protected boolean smooth_idf;

    /**
     * Builds a TF-IDF vector covering the whole vocabulary.
     *
     * @param args {@code args[0]} is expected to be a {@code Counter<String>} of word
     *     frequencies for a single document (as produced by {@link #transform})
     * @return a row vector with one entry per vocabulary word, in vocabulary index order
     */
    @Override
    public INDArray createVector(Object[] args) {
        @SuppressWarnings("unchecked") // args[0] is the Counter<String> built in transform()
        Counter<String> docFrequencies = (Counter<String>) args[0];
        // Vocabulary size is loop-invariant; evaluate it once.
        int vocabSize = this.cache.vocabWords().size();
        double[] vector = new double[vocabSize];
        for (int i = 0; i < vocabSize; ++i) {
            String word = this.cache.wordAt(i);
            double freq = docFrequencies.getCount(word);
            vector[i] = this.cache.tfidf(word, freq, this.smooth_idf);
        }
        return Nd4j.create(vector);
    }

    /**
     * Fits the vocabulary on {@code reader} and returns the TF-IDF matrix, without a callback.
     *
     * @param reader source of the documents
     * @return a matrix with one row per record and one column per vocabulary word
     * @throws IllegalStateException if the reader yields no records
     */
    @Override
    public INDArray fitTransform(RecordReader reader) {
        return this.fitTransform(reader, null);
    }

    /**
     * Fits the vocabulary on {@code reader}, then transforms every collected record into a
     * TF-IDF row.
     *
     * <p>If {@code callBack} is non-null, it receives, for each record, a new record holding the
     * TF-IDF vector followed by the original record's last writable (presumably the label —
     * confirm with the reader's schema).
     *
     * @param reader source of the documents
     * @param callBack optional per-record listener; may be {@code null}
     * @return a matrix with one row per record and one column per vocabulary word
     * @throws IllegalStateException if the reader yields no records
     */
    public INDArray fitTransform(RecordReader reader, Vectorizer.RecordCallBack callBack) {
        // Records are buffered during fit because the vocabulary must be complete
        // before any document can be vectorized.
        final List<org.datavec.api.records.Record> records = new ArrayList<>();
        this.fit(reader, new Vectorizer.RecordCallBack() {
            @Override
            public void onRecord(org.datavec.api.records.Record record) {
                records.add(record);
            }
        });
        if (records.isEmpty()) {
            throw new IllegalStateException("No records found!");
        }
        INDArray ret = Nd4j.create(records.size(), this.cache.vocabWords().size());
        int row = 0;
        for (org.datavec.api.records.Record record : records) {
            INDArray transformed = this.transform(record);
            ret.putRow(row++, transformed);
            // Only build the wrapper record when someone will consume it.
            if (callBack != null) {
                callBack.onRecord(toTransformedRecord(record, transformed, reader));
            }
        }
        return ret;
    }

    /**
     * Builds the record handed to a {@code fitTransform} callback.
     *
     * @param source the original record
     * @param transformed TF-IDF vector computed for {@code source}
     * @param reader reader the record came from, recorded as the metadata's reader class
     * @return a record of [TF-IDF vector, last writable of {@code source}]; its metadata is
     *     {@code null} when {@code source} carries no metadata
     */
    private static Record toTransformedRecord(org.datavec.api.records.Record source,
                                              INDArray transformed, RecordReader reader) {
        List<Writable> values = source.getRecord();
        Writable last = values.get(values.size() - 1);
        List<Writable> out = Arrays.asList(new NDArrayWritable(transformed), last);
        RecordMetaData sourceMeta = source.getMetaData();
        RecordMetaData meta = sourceMeta == null
                ? null
                : new RecordMetaDataURI(sourceMeta.getURI(), reader.getClass());
        return new Record(out, meta);
    }

    /**
     * Computes the TF-IDF vector for a single record.
     *
     * @param record record whose writables are tokenized via {@code wordFrequenciesForRecord}
     * @return the record's TF-IDF row vector
     */
    @Override
    public INDArray transform(org.datavec.api.records.Record record) {
        Counter<String> wordFrequencies = this.wordFrequenciesForRecord(record.getRecord());
        return this.createVector(new Object[]{wordFrequencies});
    }

    /**
     * Initializes the parent vectorizer and reads the {@link #SMOOTH_IDF} flag
     * (default {@code true}).
     *
     * @param conf configuration to read from
     */
    @Override
    public void initialize(Configuration conf) {
        super.initialize(conf);
        this.smooth_idf = conf.getBoolean(SMOOTH_IDF, true);
    }
}

