package ai.djl.pytorch.zoo.nlp.qa;

import ai.djl.Model;
import ai.djl.modality.Input;
import ai.djl.modality.Output;
import ai.djl.modality.nlp.qa.QAInput;
import ai.djl.modality.nlp.translator.QaServingTranslator;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorFactory;
import ai.djl.util.Pair;
import java.lang.reflect.Type;
import java.util.Collections;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;

/* loaded from: input_file:ai/djl/pytorch/zoo/nlp/qa/PtBertQATranslatorFactory.class */
public class PtBertQATranslatorFactory implements TranslatorFactory {
    private static final Set<Pair<Type, Type>> SUPPORTED_TYPES = new HashSet();

    public Set<Pair<Type, Type>> getSupportedTypes() {
        return SUPPORTED_TYPES;
    }

    public <I, O> Translator<I, O> newInstance(Class<I> cls, Class<O> cls2, Model model, Map<String, ?> map) {
        if (!isSupported(cls, cls2)) {
            throw new IllegalArgumentException("Unsupported input/output types.");
        }
        PtBertQATranslator build = PtBertQATranslator.builder(map).build();
        return (cls == Input.class && cls2 == Output.class) ? new QaServingTranslator(build) : build;
    }

    static {
        SUPPORTED_TYPES.add(new Pair<>(QAInput.class, String.class));
        SUPPORTED_TYPES.add(new Pair<>(Input.class, Output.class));
    }
}
