package org.datavec.nlp.transforms;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import org.datavec.api.transform.metadata.ColumnMetaData;
import org.datavec.api.transform.metadata.NDArrayMetaData;
import org.datavec.api.transform.schema.Schema;
import org.datavec.api.transform.transform.BaseColumnTransform;
import org.datavec.api.writable.NDArrayWritable;
import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable;
import org.datavec.nlp.tokenization.tokenizer.TokenPreProcess;
import org.datavec.nlp.tokenization.tokenizerfactory.DefaultTokenizerFactory;
import org.datavec.nlp.tokenization.tokenizerfactory.TokenizerFactory;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Counter;
import org.nd4j.linalg.util.MathUtils;
import org.nd4j.shade.jackson.annotation.JsonCreator;
import org.nd4j.shade.jackson.annotation.JsonInclude;
import org.nd4j.shade.jackson.annotation.JsonProperty;

@JsonInclude(JsonInclude.Include.NON_NULL)
/* loaded from: input_file:org/datavec/nlp/transforms/TokenizerBagOfWordsTermSequenceIndexTransform.class */
public class TokenizerBagOfWordsTermSequenceIndexTransform extends BaseColumnTransform {
    private String newColumName;
    private Map<String, Integer> wordIndexMap;
    private Map<String, Double> weightMap;
    private boolean exceptionOnUnknown;
    private String tokenizerFactoryClass;
    private String preprocessorClass;
    private TokenizerFactory tokenizerFactory;

    @JsonCreator
    public TokenizerBagOfWordsTermSequenceIndexTransform(@JsonProperty("columnName") String str, @JsonProperty("newColumnName") String str2, @JsonProperty("wordIndexMap") Map<String, Integer> map, @JsonProperty("idfMap") Map<String, Double> map2, @JsonProperty("exceptionOnUnknown") boolean z, @JsonProperty("tokenizerFactoryClass") String str3, @JsonProperty("preprocessorClass") String str4) {
        super(str);
        this.newColumName = str2;
        this.wordIndexMap = map;
        this.exceptionOnUnknown = z;
        this.weightMap = map2;
        this.tokenizerFactoryClass = str3;
        this.preprocessorClass = str4;
        if (this.tokenizerFactoryClass == null) {
            this.tokenizerFactoryClass = DefaultTokenizerFactory.class.getName();
        }
        try {
            this.tokenizerFactory = (TokenizerFactory) Class.forName(this.tokenizerFactoryClass).newInstance();
            if (str4 != null) {
                try {
                    this.tokenizerFactory.setTokenPreProcessor((TokenPreProcess) Class.forName(this.preprocessorClass).newInstance());
                } catch (Exception e) {
                    throw new IllegalStateException("Unable to instantiate preprocessor factory with empty constructor. Does the tokenizer factory class contain a default empty constructor?");
                }
            }
        } catch (Exception e2) {
            throw new IllegalStateException("Unable to instantiate tokenizer factory with empty constructor. Does the tokenizer factory class contain a default empty constructor?");
        }
    }

    public List<Writable> map(List<Writable> list) {
        Text text = list.get(this.inputSchema.getIndexOfColumn(this.columnName));
        ArrayList arrayList = new ArrayList(list);
        arrayList.set(this.inputSchema.getIndexOfColumn(this.columnName), new NDArrayWritable(convert(text.toString())));
        return arrayList;
    }

    public Object map(Object obj) {
        return convert(obj.toString());
    }

    public Object mapSequence(Object obj) {
        return convert(obj.toString());
    }

    public Schema transform(Schema schema) {
        Schema.Builder builder = new Schema.Builder();
        for (int i = 0; i < schema.numColumns(); i++) {
            if (schema.getName(i).equals(this.columnName)) {
                builder.addColumnNDArray(this.newColumName, new long[]{1, this.wordIndexMap.size()});
            } else {
                builder.addColumn(schema.getMetaData(i));
            }
        }
        return builder.build();
    }

    public INDArray convert(String str) {
        List<String> tokens = this.tokenizerFactory.create(str).getTokens();
        INDArray create = Nd4j.create(1, this.wordIndexMap.size());
        Counter counter = new Counter();
        for (int i = 0; i < tokens.size(); i++) {
            counter.incrementCount(tokens.get(i), 1.0d);
        }
        for (int i2 = 0; i2 < tokens.size(); i2++) {
            if (this.wordIndexMap.containsKey(tokens.get(i2))) {
                create.putScalar(this.wordIndexMap.get(tokens.get(i2)).intValue(), tfidfWord(tokens.get(i2), (int) counter.getCount(tokens.get(i2)), tokens.size()));
            }
        }
        return create;
    }

    public double tfidfWord(String str, long j, long j2) {
        return MathUtils.tfidf(tfForWord(j, j2), idfForWord(str));
    }

    private double tfForWord(long j, long j2) {
        return j;
    }

    private double idfForWord(String str) {
        if (this.weightMap.containsKey(str)) {
            return this.weightMap.get(str).doubleValue();
        }
        return 0.0d;
    }

    public ColumnMetaData getNewColumnMetaData(String str, ColumnMetaData columnMetaData) {
        return new NDArrayMetaData(outputColumnName(), new long[]{1, this.wordIndexMap.size()});
    }

    public String outputColumnName() {
        return this.newColumName;
    }

    public String[] outputColumnNames() {
        return new String[]{this.newColumName};
    }

    public String[] columnNames() {
        return new String[]{columnName()};
    }

    public String columnName() {
        return this.columnName;
    }

    public Writable map(Writable writable) {
        return new NDArrayWritable(convert(writable.toString()));
    }

    public String getNewColumName() {
        return this.newColumName;
    }

    public Map<String, Integer> getWordIndexMap() {
        return this.wordIndexMap;
    }

    public Map<String, Double> getWeightMap() {
        return this.weightMap;
    }

    public boolean isExceptionOnUnknown() {
        return this.exceptionOnUnknown;
    }

    public String getTokenizerFactoryClass() {
        return this.tokenizerFactoryClass;
    }

    public String getPreprocessorClass() {
        return this.preprocessorClass;
    }

    public TokenizerFactory getTokenizerFactory() {
        return this.tokenizerFactory;
    }

    public void setNewColumName(String str) {
        this.newColumName = str;
    }

    public void setWordIndexMap(Map<String, Integer> map) {
        this.wordIndexMap = map;
    }

    public void setWeightMap(Map<String, Double> map) {
        this.weightMap = map;
    }

    public void setExceptionOnUnknown(boolean z) {
        this.exceptionOnUnknown = z;
    }

    public void setTokenizerFactoryClass(String str) {
        this.tokenizerFactoryClass = str;
    }

    public void setPreprocessorClass(String str) {
        this.preprocessorClass = str;
    }

    public void setTokenizerFactory(TokenizerFactory tokenizerFactory) {
        this.tokenizerFactory = tokenizerFactory;
    }

    public String toString() {
        return "TokenizerBagOfWordsTermSequenceIndexTransform(newColumName=" + getNewColumName() + ", wordIndexMap=" + getWordIndexMap() + ", weightMap=" + getWeightMap() + ", exceptionOnUnknown=" + isExceptionOnUnknown() + ", tokenizerFactoryClass=" + getTokenizerFactoryClass() + ", preprocessorClass=" + getPreprocessorClass() + ", tokenizerFactory=" + getTokenizerFactory() + ")";
    }

    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof TokenizerBagOfWordsTermSequenceIndexTransform)) {
            return false;
        }
        TokenizerBagOfWordsTermSequenceIndexTransform tokenizerBagOfWordsTermSequenceIndexTransform = (TokenizerBagOfWordsTermSequenceIndexTransform) obj;
        if (!tokenizerBagOfWordsTermSequenceIndexTransform.canEqual(this) || !super.equals(obj)) {
            return false;
        }
        String newColumName = getNewColumName();
        String newColumName2 = tokenizerBagOfWordsTermSequenceIndexTransform.getNewColumName();
        if (newColumName == null) {
            if (newColumName2 != null) {
                return false;
            }
        } else if (!newColumName.equals(newColumName2)) {
            return false;
        }
        Map<String, Integer> wordIndexMap = getWordIndexMap();
        Map<String, Integer> wordIndexMap2 = tokenizerBagOfWordsTermSequenceIndexTransform.getWordIndexMap();
        if (wordIndexMap == null) {
            if (wordIndexMap2 != null) {
                return false;
            }
        } else if (!wordIndexMap.equals(wordIndexMap2)) {
            return false;
        }
        Map<String, Double> weightMap = getWeightMap();
        Map<String, Double> weightMap2 = tokenizerBagOfWordsTermSequenceIndexTransform.getWeightMap();
        if (weightMap == null) {
            if (weightMap2 != null) {
                return false;
            }
        } else if (!weightMap.equals(weightMap2)) {
            return false;
        }
        if (isExceptionOnUnknown() != tokenizerBagOfWordsTermSequenceIndexTransform.isExceptionOnUnknown()) {
            return false;
        }
        String tokenizerFactoryClass = getTokenizerFactoryClass();
        String tokenizerFactoryClass2 = tokenizerBagOfWordsTermSequenceIndexTransform.getTokenizerFactoryClass();
        if (tokenizerFactoryClass == null) {
            if (tokenizerFactoryClass2 != null) {
                return false;
            }
        } else if (!tokenizerFactoryClass.equals(tokenizerFactoryClass2)) {
            return false;
        }
        String preprocessorClass = getPreprocessorClass();
        String preprocessorClass2 = tokenizerBagOfWordsTermSequenceIndexTransform.getPreprocessorClass();
        return preprocessorClass == null ? preprocessorClass2 == null : preprocessorClass.equals(preprocessorClass2);
    }

    protected boolean canEqual(Object obj) {
        return obj instanceof TokenizerBagOfWordsTermSequenceIndexTransform;
    }

    public int hashCode() {
        int hashCode = super.hashCode();
        String newColumName = getNewColumName();
        int hashCode2 = (hashCode * 59) + (newColumName == null ? 43 : newColumName.hashCode());
        Map<String, Integer> wordIndexMap = getWordIndexMap();
        int hashCode3 = (hashCode2 * 59) + (wordIndexMap == null ? 43 : wordIndexMap.hashCode());
        Map<String, Double> weightMap = getWeightMap();
        int hashCode4 = (((hashCode3 * 59) + (weightMap == null ? 43 : weightMap.hashCode())) * 59) + (isExceptionOnUnknown() ? 79 : 97);
        String tokenizerFactoryClass = getTokenizerFactoryClass();
        int hashCode5 = (hashCode4 * 59) + (tokenizerFactoryClass == null ? 43 : tokenizerFactoryClass.hashCode());
        String preprocessorClass = getPreprocessorClass();
        return (hashCode5 * 59) + (preprocessorClass == null ? 43 : preprocessorClass.hashCode());
    }
}
