package ai.djl.mxnet.zoo.nlp.qa;

import com.google.gson.Gson;
import com.google.gson.GsonBuilder;
import com.google.gson.annotations.SerializedName;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

/* loaded from: input_file:ai/djl/mxnet/zoo/nlp/qa/BertDataParser.class */
/**
 * Vocabulary holder and text pre-processing helpers for a BERT question-answering model.
 *
 * <p>Instances are created by {@link #parse(InputStream)} from a JSON document whose
 * {@code token_to_idx} and {@code idx_to_token} entries populate the two lookup tables used by
 * {@link #token2idx(List)} and {@link #idx2token(List)}. The static helpers tokenize raw text and
 * assemble the padded token and token-type sequences.
 */
public class BertDataParser {
    private static final Gson GSON = new GsonBuilder().create();

    /**
     * One token per match: a lazily matched run of non-whitespace (group 1), an optional single
     * trailing punctuation character out of {@code . , ? !} (group 2), then whitespace or end of
     * input.
     */
    private static final Pattern PATTERN = Pattern.compile("(\\S+?)([.,?!])?(\\s+|$)");

    /** Maps a token string to its vocabulary index. */
    @SerializedName("token_to_idx")
    private Map<String, Integer> token2idx;

    /** Maps a vocabulary index (list position) back to its token string. */
    @SerializedName("idx_to_token")
    private List<String> idx2token;

    /**
     * Reads a JSON vocabulary from the stream, decoded as UTF-8, into a new parser instance.
     *
     * <p>The reader wrapping {@code inputStream} is always closed, whether parsing succeeds or
     * fails, which also closes the supplied stream.
     *
     * @param inputStream source of the JSON document
     * @return the deserialized parser, as produced by Gson
     * @throws IllegalStateException if closing the reader raises an {@link IOException}; unchecked
     *     exceptions thrown by Gson itself propagate unchanged
     */
    public static BertDataParser parse(InputStream inputStream) {
        try (InputStreamReader reader =
                new InputStreamReader(inputStream, StandardCharsets.UTF_8)) {
            return GSON.fromJson(reader, BertDataParser.class);
        } catch (IOException e) {
            throw new IllegalStateException(e);
        }
    }

    /**
     * Splits text into whitespace-separated tokens, emitting one trailing punctuation character
     * ({@code . , ? !}) of each word as a separate token.
     *
     * <p>For example {@code "Hello, world!"} yields {@code [Hello, ",", world, "!"]}. Punctuation
     * inside a word is left attached.
     *
     * @param str the text to tokenize
     * @return a new mutable list of tokens in order of appearance; empty if {@code str} contains
     *     no non-whitespace characters
     */
    public static List<String> tokenizer(String str) {
        List<String> tokens = new ArrayList<>();
        Matcher matcher = PATTERN.matcher(str);
        while (matcher.find()) {
            tokens.add(matcher.group(1));
            String punctuation = matcher.group(2);
            if (punctuation != null) {
                tokens.add(punctuation);
            }
        }
        return tokens;
    }

    /**
     * Right-pads a list with a filler element up to a target length.
     *
     * @param list the list to pad; never modified
     * @param e the filler element appended for each missing position
     * @param i the desired minimum length
     * @param <E> the element type
     * @return {@code list} itself when it already holds at least {@code i} elements (it is not
     *     truncated or copied); otherwise a new list of exactly {@code i} elements
     */
    public static <E> List<E> pad(List<E> list, E e, int i) {
        if (list.size() >= i) {
            return list;
        }
        List<E> padded = new ArrayList<>(i);
        padded.addAll(list);
        for (int size = list.size(); size < i; size++) {
            padded.add(e);
        }
        return padded;
    }

    /**
     * Builds the token-type sequence for a question/paragraph pair.
     *
     * <p>The result is {@code list.size() + 2} entries of {@code 0}, followed by
     * {@code list2.size()} entries of {@code 1}, then right-padded with {@code 0} up to {@code i}.
     *
     * @param list the question tokens (only the size is used)
     * @param list2 the paragraph tokens (only the size is used)
     * @param i the minimum length of the returned sequence
     * @return the token-type values; longer than {@code i} if the two segments already exceed it
     */
    public static List<Float> getTokenTypes(List<String> list, List<String> list2, int i) {
        // The "+ 2" reserves two extra type-0 slots; formTokens adds [CLS] and [SEP] around the
        // question.
        List<Float> types = pad(new ArrayList<Float>(), 0.0f, list.size() + 2);
        types.addAll(pad(new ArrayList<Float>(), 1.0f, list2.size()));
        return pad(types, 0.0f, i);
    }

    /**
     * Assembles the token sequence {@code [CLS] question [SEP] paragraph [SEP] [SEP]} and
     * right-pads it with {@code [PAD]} up to {@code i}.
     *
     * <p>Neither input list is modified, so unmodifiable lists are accepted.
     *
     * @param list the question tokens
     * @param list2 the paragraph tokens
     * @param i the minimum length of the returned sequence
     * @return a new list of tokens; longer than {@code i} if the unpadded sequence exceeds it
     */
    public static List<String> formTokens(List<String> list, List<String> list2, int i) {
        List<String> tokens = new ArrayList<>(list.size() + list2.size() + 4);
        tokens.add("[CLS]");
        tokens.addAll(list);
        tokens.add("[SEP]");
        tokens.addAll(list2);
        // NOTE(review): two consecutive [SEP] tokens follow the paragraph. This reproduces the
        // sequence the previous implementation returned; confirm whether the model expects one.
        tokens.add("[SEP]");
        tokens.add("[SEP]");
        return pad(tokens, "[PAD]", i);
    }

    /**
     * Converts tokens to vocabulary indices.
     *
     * @param list the tokens to look up
     * @return one index per token, in order; a token absent from the vocabulary receives the
     *     index registered for {@code [UNK]} (which is {@code null} if that entry is missing too)
     */
    public List<Integer> token2idx(List<String> list) {
        List<Integer> indices = new ArrayList<>(list.size());
        for (String token : list) {
            String key = this.token2idx.containsKey(token) ? token : "[UNK]";
            indices.add(this.token2idx.get(key));
        }
        return indices;
    }

    /**
     * Converts vocabulary indices back to tokens.
     *
     * @param list the indices to look up
     * @return one token per index, in order
     * @throws IndexOutOfBoundsException if an index lies outside the vocabulary
     */
    public List<String> idx2token(List<Integer> list) {
        List<String> tokens = new ArrayList<>(list.size());
        for (int index : list) {
            tokens.add(this.idx2token.get(index));
        }
        return tokens;
    }
}
