/*
 * Decompiled with CFR 0.152.
 */
package org.datavec.nlp.transforms;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import org.datavec.api.transform.metadata.ColumnMetaData;
import org.datavec.api.transform.metadata.NDArrayMetaData;
import org.datavec.api.transform.schema.Schema;
import org.datavec.api.transform.transform.BaseColumnTransform;
import org.datavec.api.writable.NDArrayWritable;
import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable;
import org.datavec.nlp.tokenization.tokenizer.TokenPreProcess;
import org.datavec.nlp.tokenization.tokenizer.Tokenizer;
import org.datavec.nlp.tokenization.tokenizerfactory.DefaultTokenizerFactory;
import org.datavec.nlp.tokenization.tokenizerfactory.TokenizerFactory;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Counter;
import org.nd4j.linalg.util.MathUtils;
import org.nd4j.shade.jackson.annotation.JsonCreator;
import org.nd4j.shade.jackson.annotation.JsonInclude;
import org.nd4j.shade.jackson.annotation.JsonProperty;

@JsonInclude(value=JsonInclude.Include.NON_NULL)
public class TokenizerBagOfWordsTermSequenceIndexTransform
extends BaseColumnTransform {
    private String newColumName;
    private Map<String, Integer> wordIndexMap;
    private Map<String, Double> weightMap;
    private boolean exceptionOnUnknown;
    private String tokenizerFactoryClass;
    private String preprocessorClass;
    private TokenizerFactory tokenizerFactory;

    @JsonCreator
    public TokenizerBagOfWordsTermSequenceIndexTransform(@JsonProperty(value="columnName") String columnName, @JsonProperty(value="newColumnName") String newColumnName, @JsonProperty(value="wordIndexMap") Map<String, Integer> wordIndexMap, @JsonProperty(value="idfMap") Map<String, Double> idfMap, @JsonProperty(value="exceptionOnUnknown") boolean exceptionOnUnknown, @JsonProperty(value="tokenizerFactoryClass") String tokenizerFactoryClass, @JsonProperty(value="preprocessorClass") String preprocessorClass) {
        super(columnName);
        this.newColumName = newColumnName;
        this.wordIndexMap = wordIndexMap;
        this.exceptionOnUnknown = exceptionOnUnknown;
        this.weightMap = idfMap;
        this.tokenizerFactoryClass = tokenizerFactoryClass;
        this.preprocessorClass = preprocessorClass;
        if (this.tokenizerFactoryClass == null) {
            this.tokenizerFactoryClass = DefaultTokenizerFactory.class.getName();
        }
        try {
            this.tokenizerFactory = (TokenizerFactory)Class.forName(this.tokenizerFactoryClass).newInstance();
        }
        catch (Exception e) {
            throw new IllegalStateException("Unable to instantiate tokenizer factory with empty constructor. Does the tokenizer factory class contain a default empty constructor?");
        }
        if (preprocessorClass != null) {
            try {
                TokenPreProcess tpp = (TokenPreProcess)Class.forName(this.preprocessorClass).newInstance();
                this.tokenizerFactory.setTokenPreProcessor(tpp);
            }
            catch (Exception e) {
                throw new IllegalStateException("Unable to instantiate preprocessor factory with empty constructor. Does the tokenizer factory class contain a default empty constructor?");
            }
        }
    }

    public List<Writable> map(List<Writable> writables) {
        Text text = (Text)writables.get(this.inputSchema.getIndexOfColumn(this.columnName));
        ArrayList<Writable> ret = new ArrayList<Writable>(writables);
        ret.set(this.inputSchema.getIndexOfColumn(this.columnName), (Writable)new NDArrayWritable(this.convert(text.toString())));
        return ret;
    }

    public Object map(Object input) {
        return this.convert(input.toString());
    }

    public Object mapSequence(Object sequence) {
        return this.convert(sequence.toString());
    }

    public Schema transform(Schema inputSchema) {
        Schema.Builder newSchema = new Schema.Builder();
        for (int i = 0; i < inputSchema.numColumns(); ++i) {
            if (inputSchema.getName(i).equals(this.columnName)) {
                newSchema.addColumnNDArray(this.newColumName, new long[]{1L, this.wordIndexMap.size()});
                continue;
            }
            newSchema.addColumn(inputSchema.getMetaData(i));
        }
        return newSchema.build();
    }

    public INDArray convert(String text) {
        int i;
        Tokenizer tokenizer = this.tokenizerFactory.create(text);
        List<String> tokens = tokenizer.getTokens();
        INDArray create = Nd4j.create((int)1, (int)this.wordIndexMap.size());
        Counter tokenizedCounter = new Counter();
        for (i = 0; i < tokens.size(); ++i) {
            tokenizedCounter.incrementCount((Object)tokens.get(i), 1.0);
        }
        for (i = 0; i < tokens.size(); ++i) {
            if (!this.wordIndexMap.containsKey(tokens.get(i))) continue;
            int idx = this.wordIndexMap.get(tokens.get(i));
            int count = (int)tokenizedCounter.getCount((Object)tokens.get(i));
            double weight = this.tfidfWord(tokens.get(i), count, tokens.size());
            create.putScalar((long)idx, weight);
        }
        return create;
    }

    public double tfidfWord(String word, long wordCount, long documentLength) {
        double tf = this.tfForWord(wordCount, documentLength);
        double idf = this.idfForWord(word);
        return MathUtils.tfidf((double)tf, (double)idf);
    }

    private double tfForWord(long wordCount, long documentLength) {
        return wordCount;
    }

    private double idfForWord(String word) {
        if (this.weightMap.containsKey(word)) {
            return this.weightMap.get(word);
        }
        return 0.0;
    }

    public ColumnMetaData getNewColumnMetaData(String newName, ColumnMetaData oldColumnType) {
        return new NDArrayMetaData(this.outputColumnName(), new long[]{1L, this.wordIndexMap.size()});
    }

    public String outputColumnName() {
        return this.newColumName;
    }

    public String[] outputColumnNames() {
        return new String[]{this.newColumName};
    }

    public String[] columnNames() {
        return new String[]{this.columnName()};
    }

    public String columnName() {
        return this.columnName;
    }

    public Writable map(Writable columnWritable) {
        return new NDArrayWritable(this.convert(columnWritable.toString()));
    }

    public String getNewColumName() {
        return this.newColumName;
    }

    public Map<String, Integer> getWordIndexMap() {
        return this.wordIndexMap;
    }

    public Map<String, Double> getWeightMap() {
        return this.weightMap;
    }

    public boolean isExceptionOnUnknown() {
        return this.exceptionOnUnknown;
    }

    public String getTokenizerFactoryClass() {
        return this.tokenizerFactoryClass;
    }

    public String getPreprocessorClass() {
        return this.preprocessorClass;
    }

    public TokenizerFactory getTokenizerFactory() {
        return this.tokenizerFactory;
    }

    public void setNewColumName(String newColumName) {
        this.newColumName = newColumName;
    }

    public void setWordIndexMap(Map<String, Integer> wordIndexMap) {
        this.wordIndexMap = wordIndexMap;
    }

    public void setWeightMap(Map<String, Double> weightMap) {
        this.weightMap = weightMap;
    }

    public void setExceptionOnUnknown(boolean exceptionOnUnknown) {
        this.exceptionOnUnknown = exceptionOnUnknown;
    }

    public void setTokenizerFactoryClass(String tokenizerFactoryClass) {
        this.tokenizerFactoryClass = tokenizerFactoryClass;
    }

    public void setPreprocessorClass(String preprocessorClass) {
        this.preprocessorClass = preprocessorClass;
    }

    public void setTokenizerFactory(TokenizerFactory tokenizerFactory) {
        this.tokenizerFactory = tokenizerFactory;
    }

    public String toString() {
        return "TokenizerBagOfWordsTermSequenceIndexTransform(newColumName=" + this.getNewColumName() + ", wordIndexMap=" + this.getWordIndexMap() + ", weightMap=" + this.getWeightMap() + ", exceptionOnUnknown=" + this.isExceptionOnUnknown() + ", tokenizerFactoryClass=" + this.getTokenizerFactoryClass() + ", preprocessorClass=" + this.getPreprocessorClass() + ", tokenizerFactory=" + this.getTokenizerFactory() + ")";
    }

    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (!(o instanceof TokenizerBagOfWordsTermSequenceIndexTransform)) {
            return false;
        }
        TokenizerBagOfWordsTermSequenceIndexTransform other = (TokenizerBagOfWordsTermSequenceIndexTransform)((Object)o);
        if (!other.canEqual((Object)this)) {
            return false;
        }
        if (!super.equals(o)) {
            return false;
        }
        String this$newColumName = this.getNewColumName();
        String other$newColumName = other.getNewColumName();
        if (this$newColumName == null ? other$newColumName != null : !this$newColumName.equals(other$newColumName)) {
            return false;
        }
        Map<String, Integer> this$wordIndexMap = this.getWordIndexMap();
        Map<String, Integer> other$wordIndexMap = other.getWordIndexMap();
        if (this$wordIndexMap == null ? other$wordIndexMap != null : !((Object)this$wordIndexMap).equals(other$wordIndexMap)) {
            return false;
        }
        Map<String, Double> this$weightMap = this.getWeightMap();
        Map<String, Double> other$weightMap = other.getWeightMap();
        if (this$weightMap == null ? other$weightMap != null : !((Object)this$weightMap).equals(other$weightMap)) {
            return false;
        }
        if (this.isExceptionOnUnknown() != other.isExceptionOnUnknown()) {
            return false;
        }
        String this$tokenizerFactoryClass = this.getTokenizerFactoryClass();
        String other$tokenizerFactoryClass = other.getTokenizerFactoryClass();
        if (this$tokenizerFactoryClass == null ? other$tokenizerFactoryClass != null : !this$tokenizerFactoryClass.equals(other$tokenizerFactoryClass)) {
            return false;
        }
        String this$preprocessorClass = this.getPreprocessorClass();
        String other$preprocessorClass = other.getPreprocessorClass();
        return !(this$preprocessorClass == null ? other$preprocessorClass != null : !this$preprocessorClass.equals(other$preprocessorClass));
    }

    protected boolean canEqual(Object other) {
        return other instanceof TokenizerBagOfWordsTermSequenceIndexTransform;
    }

    public int hashCode() {
        int PRIME = 59;
        int result = super.hashCode();
        String $newColumName = this.getNewColumName();
        result = result * 59 + ($newColumName == null ? 43 : $newColumName.hashCode());
        Map<String, Integer> $wordIndexMap = this.getWordIndexMap();
        result = result * 59 + ($wordIndexMap == null ? 43 : ((Object)$wordIndexMap).hashCode());
        Map<String, Double> $weightMap = this.getWeightMap();
        result = result * 59 + ($weightMap == null ? 43 : ((Object)$weightMap).hashCode());
        result = result * 59 + (this.isExceptionOnUnknown() ? 79 : 97);
        String $tokenizerFactoryClass = this.getTokenizerFactoryClass();
        result = result * 59 + ($tokenizerFactoryClass == null ? 43 : $tokenizerFactoryClass.hashCode());
        String $preprocessorClass = this.getPreprocessorClass();
        result = result * 59 + ($preprocessorClass == null ? 43 : $preprocessorClass.hashCode());
        return result;
    }
}

