package ai.djl.mxnet.zoo.nlp.qa;

import ai.djl.modality.nlp.DefaultVocabulary;
import ai.djl.modality.nlp.Vocabulary;
import ai.djl.modality.nlp.bert.BertToken;
import ai.djl.modality.nlp.bert.BertTokenizer;
import ai.djl.modality.nlp.qa.QAInput;
import ai.djl.modality.nlp.translator.QATranslator;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.Shape;
import ai.djl.translate.ArgumentsUtil;
import ai.djl.translate.Batchifier;
import ai.djl.translate.TranslatorContext;
import ai.djl.util.JsonUtils;
import ai.djl.util.Utils;
import com.google.gson.annotations.SerializedName;
import java.io.BufferedInputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.net.URL;
import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.Stream;

/* loaded from: input_file:ai/djl/mxnet/zoo/nlp/qa/MxBertQATranslator.class */
public class MxBertQATranslator extends QATranslator {
    private List<String> tokens;
    private Vocabulary vocabulary;
    private BertTokenizer tokenizer;
    private int seqLength;

    /* loaded from: input_file:ai/djl/mxnet/zoo/nlp/qa/MxBertQATranslator$Builder.class */
    public static class Builder extends QATranslator.BaseBuilder<Builder> {
        private int seqLength;

        public Builder setSeqLength(int i) {
            this.seqLength = i;
            return m3self();
        }

        /* JADX INFO: Access modifiers changed from: protected */
        /* renamed from: self, reason: merged with bridge method [inline-methods] */
        public Builder m3self() {
            return this;
        }

        /* JADX INFO: Access modifiers changed from: protected */
        public MxBertQATranslator build() {
            if (this.seqLength == 0) {
                throw new IllegalArgumentException("You must specify a seqLength with value > 0");
            }
            return new MxBertQATranslator(this);
        }
    }

    /* loaded from: input_file:ai/djl/mxnet/zoo/nlp/qa/MxBertQATranslator$VocabParser.class */
    private static final class VocabParser {

        @SerializedName("idx_to_token")
        List<String> idx2token;

        private VocabParser() {
        }

        public static List<String> parseToken(URL url) {
            try {
                BufferedInputStream bufferedInputStream = new BufferedInputStream(url.openStream());
                Throwable th = null;
                try {
                    InputStreamReader inputStreamReader = new InputStreamReader(bufferedInputStream, StandardCharsets.UTF_8);
                    Throwable th2 = null;
                    try {
                        try {
                            List<String> list = ((VocabParser) JsonUtils.GSON.fromJson(inputStreamReader, VocabParser.class)).idx2token;
                            if (inputStreamReader != null) {
                                if (0 != 0) {
                                    try {
                                        inputStreamReader.close();
                                    } catch (Throwable th3) {
                                        th2.addSuppressed(th3);
                                    }
                                } else {
                                    inputStreamReader.close();
                                }
                            }
                            return list;
                        } finally {
                        }
                    } catch (Throwable th4) {
                        if (inputStreamReader != null) {
                            if (th2 != null) {
                                try {
                                    inputStreamReader.close();
                                } catch (Throwable th5) {
                                    th2.addSuppressed(th5);
                                }
                            } else {
                                inputStreamReader.close();
                            }
                        }
                        throw th4;
                    }
                } finally {
                    if (bufferedInputStream != null) {
                        if (0 != 0) {
                            try {
                                bufferedInputStream.close();
                            } catch (Throwable th6) {
                                th.addSuppressed(th6);
                            }
                        } else {
                            bufferedInputStream.close();
                        }
                    }
                }
            } catch (IOException e) {
                throw new IllegalArgumentException("Invalid url: " + url, e);
            }
        }
    }

    MxBertQATranslator(Builder builder) {
        super(builder);
        this.seqLength = builder.seqLength;
    }

    public void prepare(TranslatorContext translatorContext) throws IOException {
        this.vocabulary = DefaultVocabulary.builder().addFromCustomizedFile(translatorContext.getModel().getArtifact("vocab.json"), VocabParser::parseToken).optUnknownToken("[UNK]").build();
        this.tokenizer = new BertTokenizer();
    }

    public Batchifier getBatchifier() {
        return null;
    }

    public NDList processInput(TranslatorContext translatorContext, QAInput qAInput) {
        BertToken encode = this.tokenizer.encode(qAInput.getQuestion().toLowerCase(), qAInput.getParagraph().toLowerCase(), this.seqLength);
        this.tokens = encode.getTokens();
        Stream stream = encode.getTokens().stream();
        Vocabulary vocabulary = this.vocabulary;
        vocabulary.getClass();
        float[] floatArray = Utils.toFloatArray((List) stream.map(vocabulary::getIndex).collect(Collectors.toList()));
        float[] floatArray2 = Utils.toFloatArray(encode.getTokenTypes());
        int validLength = encode.getValidLength();
        NDManager nDManager = translatorContext.getNDManager();
        NDArray create = nDManager.create(floatArray);
        create.setName("data0");
        NDArray create2 = nDManager.create(floatArray2);
        create2.setName("data1");
        NDArray create3 = nDManager.create(new float[]{validLength});
        create3.setName("data2");
        return new NDList(new NDArray[]{create, create2, create3});
    }

    /* renamed from: processOutput, reason: merged with bridge method [inline-methods] */
    public String m2processOutput(TranslatorContext translatorContext, NDList nDList) {
        NDList split = nDList.singletonOrThrow().split(2L, 2);
        return this.tokenizer.tokenToString(this.tokens.subList((int) ((NDArray) split.get(0)).reshape(new Shape(new long[]{1, -1})).argMax(1).getLong(new long[0]), ((int) ((NDArray) split.get(1)).reshape(new Shape(new long[]{1, -1})).argMax(1).getLong(new long[0])) + 1));
    }

    public static Builder builder() {
        return new Builder();
    }

    public static Builder builder(Map<String, ?> map) {
        Builder builder = new Builder();
        builder.configure(map);
        builder.setSeqLength(ArgumentsUtil.intValue(map, "seqLength", 384));
        return builder;
    }
}
