package ai.djl.pytorch.zoo.nlp.textgeneration;

import ai.djl.Model;
import ai.djl.modality.nlp.generate.CausalLMOutput;
import ai.djl.ndarray.NDList;
import ai.djl.translate.ArgumentsUtil;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorFactory;
import ai.djl.util.Pair;
import java.lang.reflect.Type;
import java.util.Collections;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;

/* loaded from: input_file:ai/djl/pytorch/zoo/nlp/textgeneration/PtGptTranslatorFactory.class */
public class PtGptTranslatorFactory implements TranslatorFactory {
    private static final Set<Pair<Type, Type>> SUPPORTED_TYPES = new HashSet();

    public Set<Pair<Type, Type>> getSupportedTypes() {
        return SUPPORTED_TYPES;
    }

    public <I, O> Translator<I, O> newInstance(Class<I> cls, Class<O> cls2, Model model, Map<String, ?> map) {
        if (isSupported(cls, cls2)) {
            return new PtGptTranslator(ArgumentsUtil.longValue(map, "kvDim", 64L).longValue(), ArgumentsUtil.intValue(map, "numAttentionHeads", 12), ArgumentsUtil.intValue(map, "numLayers", 12));
        }
        throw new IllegalArgumentException("Unsupported input/output types.");
    }

    static {
        SUPPORTED_TYPES.add(new Pair<>(NDList.class, CausalLMOutput.class));
    }
}
