package org.tensorflow.internal.types.registry;

import java.util.HashMap;
import java.util.Map;
import org.tensorflow.proto.framework.DataType;
import org.tensorflow.types.TBfloat16;
import org.tensorflow.types.TBool;
import org.tensorflow.types.TFloat16;
import org.tensorflow.types.TFloat32;
import org.tensorflow.types.TFloat64;
import org.tensorflow.types.TInt32;
import org.tensorflow.types.TInt64;
import org.tensorflow.types.TString;
import org.tensorflow.types.TUint16;
import org.tensorflow.types.TUint8;
import org.tensorflow.types.annotation.TensorType;
import org.tensorflow.types.family.TType;

/* loaded from: input_file:org/tensorflow/internal/types/registry/TensorTypeRegistry.class */
/**
 * Registry of the tensor types known to the runtime.
 *
 * <p>Each registered type is indexed both by its Java class and by the numeric code of its
 * {@link DataType}. All types are registered from this class's static initializer. The backing
 * maps are private and are only written by {@code register}, which is called solely from that
 * initializer. After class initialization they are only read.
 */
public final class TensorTypeRegistry {

    /**
     * Offset added to a data type code to obtain a second code that maps to the same type info.
     * Presumably this matches the {@code *_REF} variants of TensorFlow's {@code DataType} enum
     * (e.g. {@code DT_FLOAT_REF} = {@code DT_FLOAT} + 100) — confirm against {@code types.proto}.
     */
    private static final int REF_CODE_OFFSET = 100;

    /** Type info keyed by {@code DataType.getNumber()} and by that number plus {@link #REF_CODE_OFFSET}. */
    private static final Map<Integer, TensorTypeInfo<?>> TYPES_BY_CODE = new HashMap<>();

    /** Type info keyed by the Java class of the tensor type. */
    private static final Map<Class<? extends TType>, TensorTypeInfo<?>> TYPES_BY_CLASS = new HashMap<>();

    /**
     * Returns the type information registered for the given data type.
     *
     * <p>The type parameter {@code T} is chosen by the caller and cannot be checked against the
     * registered type at runtime. Requesting the wrong {@code T} may surface later as a
     * {@link ClassCastException} where the result is used.
     *
     * @param dataType the data type to look up
     * @return the registered type information
     * @throws IllegalArgumentException if no tensor type is registered for {@code dataType}
     */
    public static <T extends TType> TensorTypeInfo<T> find(DataType dataType) {
        @SuppressWarnings("unchecked") // caller-chosen T cannot be verified at runtime (see Javadoc)
        TensorTypeInfo<T> typeInfo = (TensorTypeInfo<T>) TYPES_BY_CODE.get(dataType.getNumber());
        if (typeInfo == null) {
            throw new IllegalArgumentException("No tensor type has been registered for data type " + dataType);
        }
        return typeInfo;
    }

    /**
     * Returns the type information registered for the given tensor type class.
     *
     * @param cls the tensor type class to look up
     * @return the registered type information
     * @throws IllegalArgumentException if {@code cls} has not been registered
     */
    public static <T extends TType> TensorTypeInfo<T> find(Class<T> cls) {
        // Safe: register() always stores, under key `cls`, a TensorTypeInfo built from that same class.
        @SuppressWarnings("unchecked")
        TensorTypeInfo<T> typeInfo = (TensorTypeInfo<T>) TYPES_BY_CLASS.get(cls);
        if (typeInfo == null) {
            throw new IllegalArgumentException("Class \"" + cls.getName() + "\" is not registered as a tensor type");
        }
        return typeInfo;
    }

    /**
     * Registers a tensor type class using the metadata in its {@link TensorType} annotation.
     *
     * <p>The resulting type info is indexed by {@code cls}, by its data type code, and by that code
     * plus {@link #REF_CODE_OFFSET}. When a code is already present, this registration overwrites
     * the earlier one (plain {@link Map#put}).
     *
     * @param cls the tensor type class; must declare {@code @TensorType} directly (inherited
     *     annotations are not considered)
     * @throws IllegalArgumentException if {@code cls} is not annotated with {@link TensorType}, or if
     *     the annotation's mapper class cannot be instantiated via its parameter-less constructor
     */
    private static <T extends TType> void register(Class<T> cls) {
        TensorType tensorType = cls.getDeclaredAnnotation(TensorType.class);
        if (tensorType == null) {
            throw new IllegalArgumentException("Class \"" + cls.getName() + "\" must be annotated with @TensorType to be registered as a tensor type");
        }

        final TensorTypeInfo<?> typeInfo;
        try {
            // getDeclaredConstructor().newInstance() replaces the deprecated Class.newInstance().
            // It wraps constructor failures in a ReflectiveOperationException instead of rethrowing
            // them unchecked.
            typeInfo = new TensorTypeInfo<>(cls, tensorType.dataType(), tensorType.byteSize(),
                    tensorType.mapperClass().getDeclaredConstructor().newInstance());
        } catch (ReflectiveOperationException e) {
            // The constructor requirement applies to the mapper class, not to the tensor type class.
            throw new IllegalArgumentException("Class \"" + tensorType.mapperClass().getName()
                    + "\" must have a public parameter-less constructor to be used as a tensor mapper"
                    + " (required by tensor type \"" + cls.getName() + "\")", e);
        }

        int code = typeInfo.dataType().getNumber();
        TYPES_BY_CLASS.put(cls, typeInfo);
        TYPES_BY_CODE.put(code, typeInfo);
        TYPES_BY_CODE.put(code + REF_CODE_OFFSET, typeInfo);
    }

    // Built-in tensor types. This initializer appears after the map fields, so the maps exist when it runs.
    static {
        register(TBool.class);
        register(TFloat64.class);
        register(TFloat32.class);
        register(TFloat16.class);
        register(TInt32.class);
        register(TInt64.class);
        register(TString.class);
        register(TUint8.class);
        register(TUint16.class);
        register(TBfloat16.class);
    }
}
