package sklearn.feature_extraction.text;

import com.google.common.collect.BiMap;
import com.google.common.collect.HashBiMap;
import com.google.common.io.CharStreams;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import numpy.core.ScalarUtil;
import org.dmg.pmml.Apply;
import org.dmg.pmml.DataType;
import org.dmg.pmml.DefineFunction;
import org.dmg.pmml.Expression;
import org.dmg.pmml.FieldRef;
import org.dmg.pmml.OpType;
import org.dmg.pmml.ParameterField;
import org.dmg.pmml.TextIndex;
import org.dmg.pmml.TextIndexNormalization;
import org.jpmml.converter.ContinuousFeature;
import org.jpmml.converter.ExpressionUtil;
import org.jpmml.converter.Feature;
import org.jpmml.converter.FieldNameUtil;
import org.jpmml.converter.ObjectFeature;
import org.jpmml.converter.PMMLUtil;
import org.jpmml.converter.StringFeature;
import org.jpmml.converter.ValueUtil;
import org.jpmml.python.ClassDictUtil;
import org.jpmml.python.TypeInfo;
import org.jpmml.sklearn.SkLearnEncoder;
import sklearn.HasSparseOutput;
import sklearn.SkLearnTransformer;
import sklearn2pmml.feature_extraction.text.Matcher;

/* loaded from: input_file:sklearn/feature_extraction/text/CountVectorizer.class */
public class CountVectorizer extends SkLearnTransformer implements HasSparseOutput {
    public static final String TOKEN_PATTERN = "(?u)\\b\\w\\w+\\b";

    public CountVectorizer(String str, String str2) {
        super(str, str2);
    }

    @Override // sklearn.Transformer, sklearn.HasNumberOfFeatures
    public int getNumberOfFeatures() {
        return 1;
    }

    @Override // sklearn.Transformer, sklearn.HasType
    public OpType getOpType() {
        return OpType.CATEGORICAL;
    }

    @Override // sklearn.Transformer, sklearn.HasType
    public DataType getDataType() {
        return DataType.STRING;
    }

    @Override // sklearn.Transformer
    public List<Feature> encodeFeatures(List<Feature> list, SkLearnEncoder skLearnEncoder) {
        Boolean lowercase = getLowercase();
        Map<String, ?> vocabulary = getVocabulary();
        ClassDictUtil.checkSize(1, new Collection[]{list});
        Feature feature = list.get(0);
        HashBiMap create = HashBiMap.create(vocabulary.size());
        for (Map.Entry<String, ?> entry : vocabulary.entrySet()) {
            create.put(entry.getKey(), ValueUtil.asInteger((Number) ScalarUtil.decode(entry.getValue())));
        }
        BiMap inverse = create.inverse();
        TypeInfo dType = getDType();
        DataType dataType = dType != null ? dType.getDataType() : DataType.DOUBLE;
        if (lowercase.booleanValue()) {
            Apply createApply = ExpressionUtil.createApply("lowercase", new Expression[]{feature.ref()});
            feature = new StringFeature(skLearnEncoder, skLearnEncoder.ensureDerivedField(FieldNameUtil.create("lowercase", new Object[]{feature}), OpType.CATEGORICAL, DataType.STRING, () -> {
                return createApply;
            }));
        }
        DefineFunction encodeDefineFunction = encodeDefineFunction(feature, skLearnEncoder);
        skLearnEncoder.addDefineFunction(encodeDefineFunction);
        ArrayList arrayList = new ArrayList();
        int size = inverse.size();
        for (int i = 0; i < size; i++) {
            String str = (String) inverse.get(Integer.valueOf(i));
            final Apply encodeApply = encodeApply(encodeDefineFunction, feature, i, str);
            arrayList.add(new ObjectFeature(skLearnEncoder, FieldNameUtil.create(functionName(), new Object[]{feature, str}), dataType) { // from class: sklearn.feature_extraction.text.CountVectorizer.1
                public ContinuousFeature toContinuousFeature() {
                    String name = getName();
                    DataType dataType2 = getDataType();
                    Apply apply = encodeApply;
                    return toContinuousFeature(name, dataType2, () -> {
                        return apply;
                    });
                }
            });
        }
        return arrayList;
    }

    public DefineFunction encodeDefineFunction(Feature feature, SkLearnEncoder skLearnEncoder) {
        String formatStopWordsRE;
        String analyzer = getAnalyzer();
        List<String> stopWords = getStopWords();
        Object[] nGramRange = getNGramRange();
        Boolean binary = getBinary();
        Object preprocessor = getPreprocessor();
        String stripAccents = getStripAccents();
        Tokenizer tokenizer = getTokenizer();
        boolean z = -1;
        switch (analyzer.hashCode()) {
            case 3655434:
                if (analyzer.equals("word")) {
                    z = false;
                    break;
                }
                break;
        }
        switch (z) {
            case false:
                if (preprocessor != null) {
                    throw new IllegalArgumentException();
                }
                if (stripAccents != null) {
                    throw new IllegalArgumentException(stripAccents);
                }
                if (tokenizer == null) {
                    tokenizer = new Matcher().setWordRE(getTokenPattern());
                }
                ParameterField parameterField = new ParameterField("document");
                ParameterField parameterField2 = new ParameterField("term");
                TextIndex configure = tokenizer.configure(new TextIndex(parameterField, new FieldRef(parameterField2)).setLocalTermWeights(binary.booleanValue() ? TextIndex.LocalTermWeights.BINARY : null));
                if (stopWords != null && !stopWords.isEmpty() && !Arrays.equals(nGramRange, new Integer[]{1, 1}) && (formatStopWordsRE = tokenizer.formatStopWordsRE(stopWords)) != null) {
                    LinkedHashMap linkedHashMap = new LinkedHashMap();
                    linkedHashMap.put("string", Collections.singletonList(formatStopWordsRE));
                    linkedHashMap.put("stem", Collections.singletonList(" "));
                    linkedHashMap.put("regex", Collections.singletonList("true"));
                    configure.addTextIndexNormalizations(new TextIndexNormalization[]{new TextIndexNormalization(PMMLUtil.createInlineTable(linkedHashMap)).setRecursive(Boolean.TRUE)});
                }
                return new DefineFunction(createFieldName(functionName(), feature), OpType.CONTINUOUS, DataType.INTEGER, (List) null, configure).addParameterFields(new ParameterField[]{parameterField, parameterField2});
            default:
                throw new IllegalArgumentException(analyzer);
        }
    }

    public Apply encodeApply(DefineFunction defineFunction, Feature feature, int i, String str) {
        return ExpressionUtil.createApply(defineFunction, new Expression[]{feature.ref(), ExpressionUtil.createConstant(DataType.STRING, str)});
    }

    public String functionName() {
        return "tf";
    }

    public String getAnalyzer() {
        return getString("analyzer");
    }

    public Boolean getBinary() {
        return getBoolean("binary");
    }

    public TypeInfo getDType() {
        return getDType("dtype", false);
    }

    public Boolean getLowercase() {
        return getBoolean("lowercase");
    }

    public Object[] getNGramRange() {
        return getTuple("ngram_range");
    }

    public Object getPreprocessor() {
        return getOptionalObject("preprocessor");
    }

    @Override // sklearn.HasSparseOutput
    public Boolean getSparseOutput() {
        return Boolean.TRUE;
    }

    public List<String> getStopWords() {
        Object optionalObject = getOptionalObject("stop_words");
        return optionalObject instanceof String ? loadStopWords((String) optionalObject) : (List) optionalObject;
    }

    public String getStripAccents() {
        return getOptionalString("strip_accents");
    }

    public Tokenizer getTokenizer() {
        return (Tokenizer) getOptional("tokenizer", Tokenizer.class);
    }

    public String getTokenPattern() {
        return getString("token_pattern");
    }

    public Map<String, ?> getVocabulary() {
        return getDict("vocabulary_");
    }

    private static List<String> loadStopWords(String str) {
        InputStream resourceAsStream = CountVectorizer.class.getResourceAsStream("/stop_words/" + str + ".txt");
        if (resourceAsStream == null) {
            throw new IllegalArgumentException(str);
        }
        try {
            InputStreamReader inputStreamReader = new InputStreamReader(resourceAsStream, "UTF-8");
            Throwable th = null;
            try {
                try {
                    List<String> readLines = CharStreams.readLines(inputStreamReader);
                    if (inputStreamReader != null) {
                        if (0 != 0) {
                            try {
                                inputStreamReader.close();
                            } catch (Throwable th2) {
                                th.addSuppressed(th2);
                            }
                        } else {
                            inputStreamReader.close();
                        }
                    }
                    return readLines;
                } finally {
                }
            } finally {
            }
        } catch (IOException e) {
            throw new IllegalArgumentException(str, e);
        }
    }
}
