package org.tribuo.interop.onnx.extractors;

import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtException;
import java.nio.file.Path;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import org.tribuo.interop.onnx.extractors.BERTFeatureExtractor;
import org.tribuo.util.tokens.impl.wordpiece.Wordpiece;
import org.tribuo.util.tokens.impl.wordpiece.WordpieceBasicTokenizer;
import org.tribuo.util.tokens.impl.wordpiece.WordpieceTokenizer;

/* loaded from: input_file:org/tribuo/interop/onnx/extractors/Tokenizer.class */
/**
 * Splits text into wordpiece tokens and converts those tokens into an ONNX
 * tensor of vocabulary IDs.
 *
 * <p>The vocabulary, the special tokens and the wordpiece settings are all
 * taken from the configuration returned by
 * {@code BERTFeatureExtractor.loadTokenizer(Path)}.
 */
public class Tokenizer {
    private final WordpieceTokenizer tokenizer;
    /** Upper bound on the sequence length, counting the two special tokens added by {@link #convertTokens}. */
    private final int maxLength = 512;
    private final Map<String, Integer> tokenIDs;
    private final String classificationToken;
    private final String separatorToken;
    private final String unknownToken;

    /**
     * Builds a tokenizer from the configuration found at {@code path}.
     *
     * @param path location handed to {@code BERTFeatureExtractor.loadTokenizer}
     * @throws Exception whatever {@code BERTFeatureExtractor.loadTokenizer} throws
     */
    public Tokenizer(Path path) throws Exception {
        BERTFeatureExtractor.TokenizerConfig config = BERTFeatureExtractor.loadTokenizer(path);
        Wordpiece wordpiece = new Wordpiece(config.tokenIDs.keySet(), config.unknownToken, config.maxInputCharsPerWord);
        this.tokenizer = new WordpieceTokenizer(wordpiece, new WordpieceBasicTokenizer(), config.lowercase, config.stripAccents, Collections.emptySet());
        this.tokenIDs = config.tokenIDs;
        this.unknownToken = config.unknownToken;
        this.classificationToken = config.classificationToken;
        this.separatorToken = config.separatorToken;
    }

    /**
     * Splits {@code str} into tokens, keeping at most {@code maxLength - 2} of them
     * so that the two special tokens added by {@link #convertTokens} still fit.
     *
     * @param str the text to tokenize
     * @return the tokens; when truncated this is a {@code subList} view of the full split
     */
    public List<String> tokenize(String str) {
        List<String> tokens = this.tokenizer.split(str);
        int limit = this.maxLength - 2;
        if (tokens.size() > limit) {
            tokens = tokens.subList(0, limit);
        }
        return tokens;
    }

    /**
     * Maps tokens to vocabulary IDs and wraps them in a tensor holding a single
     * sequence: classification token ID, one ID per input token, separator token ID.
     * Tokens that are missing from the vocabulary are replaced by the unknown
     * token's ID.
     *
     * @param ortEnvironment the ONNX Runtime environment used to create the tensor
     * @param list the tokens to convert, in order
     * @return a tensor built from a {@code long[1][list.size() + 2]} array
     * @throws OrtException if the tensor cannot be created
     * @throws NullPointerException if a required special token has no vocabulary entry
     */
    public OnnxTensor convertTokens(OrtEnvironment ortEnvironment, List<String> list) throws OrtException {
        // Two extra slots: classification token in front, separator token at the end.
        long[] ids = new long[list.size() + 2];
        ids[0] = this.tokenIDs.get(this.classificationToken);
        int position = 1;
        for (String token : list) {
            // Keep the lookup result; the ID for a known token is this value.
            Integer id = this.tokenIDs.get(token);
            if (id == null) {
                id = this.tokenIDs.get(this.unknownToken);
            }
            ids[position] = id;
            position++;
        }
        ids[position] = this.tokenIDs.get(this.separatorToken);
        // Wrap in an outer array so the tensor carries a leading dimension of size one.
        return OnnxTensor.createTensor(ortEnvironment, new long[][]{ids});
    }
}
