package ai.konduit.serving.pipeline.registry;

import ai.konduit.serving.pipeline.api.data.NDArray;
import ai.konduit.serving.pipeline.api.format.NDArrayConverter;
import ai.konduit.serving.pipeline.api.format.NDArrayFormat;
import ai.konduit.serving.pipeline.impl.data.ndarray.SerializedNDArray;
import java.util.Collections;
import java.util.List;
import java.util.Set;
import lombok.NonNull;
import org.nd4j.common.primitives.Pair;

/**
 * Registry of {@link NDArrayConverter} implementations, accessed through static methods that delegate
 * to a single shared instance.
 *
 * <p>Class-based lookups ({@link #getConverterFor(NDArray, Class)}) first look for a converter that
 * can convert directly. If none is found, they fall back to a two-step conversion that goes through
 * {@link SerializedNDArray} as an intermediate representation.
 */
public class NDArrayConverterRegistry extends AbstractRegistry<NDArrayConverter> {
    private static final NDArrayConverterRegistry INSTANCE = new NDArrayConverterRegistry();

    /**
     * Converter that chains two converters. {@code c1} converts the input array to
     * {@link SerializedNDArray}, then {@code c2} converts that intermediate to {@code cTo}.
     *
     * <p>Only class-based conversion is supported. Format-based {@code canConvert} always returns
     * {@code false}, and format-based {@code convert} throws {@link UnsupportedOperationException}.
     */
    public static class TwoStepNDArrayConverter implements NDArrayConverter {
        private final Class<?> cFrom;
        private final Class<?> cTo;
        private final NDArrayConverter c1;
        private final NDArrayConverter c2;

        /**
         * @param cls               type of underlying array object accepted (checked with isAssignableFrom)
         * @param cls2              type produced by the second conversion stage
         * @param nDArrayConverter  first stage: input array to {@link SerializedNDArray}
         * @param nDArrayConverter2 second stage: {@link SerializedNDArray} to {@code cls2}
         */
        public TwoStepNDArrayConverter(Class<?> cls, Class<?> cls2, NDArrayConverter nDArrayConverter, NDArrayConverter nDArrayConverter2) {
            this.cFrom = cls;
            this.cTo = cls2;
            this.c1 = nDArrayConverter;
            this.c2 = nDArrayConverter2;
        }

        /** Always {@code false}: this chained converter does not support format-based conversion. */
        @Override
        public boolean canConvert(NDArray nDArray, NDArrayFormat<?> nDArrayFormat) {
            return false;
        }

        /**
         * Returns {@code true} when the array's underlying object is a {@code cFrom}, and the requested
         * type is {@code cTo} or a supertype of it.
         */
        @Override
        public boolean canConvert(NDArray nDArray, Class<?> cls) {
            return this.cFrom.isAssignableFrom(nDArray.get().getClass()) && cls.isAssignableFrom(this.cTo);
        }

        /** @throws UnsupportedOperationException always */
        @Override
        public <T> T convert(NDArray nDArray, NDArrayFormat<T> nDArrayFormat) {
            throw new UnsupportedOperationException("Not supported");
        }

        /**
         * Converts the array to {@link SerializedNDArray} with the first converter, then to
         * {@code cTo} with the second converter.
         */
        @Override
        @SuppressWarnings("unchecked") // cTo is assignable to the requested type, per canConvert(NDArray, Class)
        public <T> T convert(NDArray nDArray, Class<T> cls) {
            SerializedNDArray intermediate = this.c1.convert(nDArray, SerializedNDArray.class);
            return (T) this.c2.convert(NDArray.create(intermediate), this.cTo);
        }
    }

    protected NDArrayConverterRegistry() {
        super(NDArrayConverter.class);
    }

    /** @return the number of registered converters in the shared registry */
    public static int numFactories() {
        return INSTANCE.registryNumFactories();
    }

    /** @return the converters registered in the shared registry */
    public static List<NDArrayConverter> getFactories() {
        return INSTANCE.registryGetFactories();
    }

    /**
     * Looks up a converter in the shared registry for the given query object.
     * NOTE(review): this class's {@link #acceptFactory} expects {@code obj} to be a
     * {@code Pair<NDArray, NDArrayFormat>} — confirm this is how AbstractRegistry uses it.
     *
     * @throws NullPointerException if {@code obj} is null
     */
    public static NDArrayConverter getFactoryFor(@NonNull Object obj) {
        if (obj == null) {
            throw new NullPointerException("o is marked non-null but is null");
        }
        return INSTANCE.registryGetFactoryFor(obj);
    }

    /**
     * Accepts a converter when it can convert the pair's {@link NDArray} (first element) to the
     * pair's {@link NDArrayFormat} (second element).
     */
    @Override
    public boolean acceptFactory(NDArrayConverter nDArrayConverter, Object obj) {
        Pair<?, ?> pair = (Pair<?, ?>) obj;
        return nDArrayConverter.canConvert((NDArray) pair.getFirst(), (NDArrayFormat<?>) pair.getSecond());
    }

    /** Converters do not declare supported classes up front; always returns an empty set. */
    @Override
    public Set<Class<?>> supportedForFactory(NDArrayConverter nDArrayConverter) {
        return Collections.emptySet();
    }

    /** Static shortcut for {@link #getConverterForClass(NDArray, Class)} on the shared registry. */
    public static NDArrayConverter getConverterFor(NDArray nDArray, Class<?> cls) {
        return INSTANCE.getConverterForClass(nDArray, cls);
    }

    /** Static shortcut for {@link #getConverterForType(NDArray, NDArrayFormat)} on the shared registry. */
    public static NDArrayConverter getConverterFor(NDArray nDArray, NDArrayFormat<?> nDArrayFormat) {
        return INSTANCE.getConverterForType(nDArray, nDArrayFormat);
    }

    /**
     * Finds a converter that turns {@code nDArray} into an instance of {@code cls}.
     *
     * <p>Lookup order:
     * <ol>
     *   <li>the first entry in {@code factoriesMap} for {@code cls}, returned without calling
     *       {@code canConvert};</li>
     *   <li>the first registered converter whose {@code canConvert(nDArray, cls)} returns true;</li>
     *   <li>a {@link TwoStepNDArrayConverter} that goes through {@link SerializedNDArray}. This requires
     *       both a converter to {@code SerializedNDArray} and one from it to {@code cls}.</li>
     * </ol>
     *
     * <p>Note: step 3 performs one real conversion to {@code SerializedNDArray} during the lookup,
     * because the second stage is found by probing with an actual intermediate value.
     *
     * @return a suitable converter, or {@code null} if none (direct or two-step) exists
     */
    public NDArrayConverter getConverterForClass(NDArray nDArray, Class<?> cls) {
        if (this.factories == null) {
            init();
        }
        if (this.factoriesMap.containsKey(cls)) {
            return (NDArrayConverter) ((List<?>) this.factoriesMap.get(cls)).get(0);
        }
        for (NDArrayConverter converter : this.factories) {
            if (converter.canConvert(nDArray, cls)) {
                return converter;
            }
        }

        // Two-step fallback via SerializedNDArray. Skip it when the target or the source already is the
        // intermediate type: nothing would be gained, and it bounds the recursion below.
        if (cls == SerializedNDArray.class || nDArray.get() instanceof SerializedNDArray) {
            return null;
        }
        NDArrayConverter toSerialized = getConverterForClass(nDArray, SerializedNDArray.class);
        if (toSerialized == null) {
            return null;
        }
        NDArray serialized = NDArray.create(toSerialized.convert(nDArray, SerializedNDArray.class));
        NDArrayConverter fromSerialized = getConverterForClass(serialized, cls);
        if (fromSerialized == null) {
            // Without a second stage the chained converter would report canConvert == true but NPE on convert.
            return null;
        }
        return new TwoStepNDArrayConverter(nDArray.get().getClass(), cls, toSerialized, fromSerialized);
    }

    /**
     * Finds the first registered converter whose {@code canConvert(nDArray, nDArrayFormat)} returns true.
     *
     * @return the matching converter, or {@code null} if none matches
     */
    public NDArrayConverter getConverterForType(NDArray nDArray, NDArrayFormat<?> nDArrayFormat) {
        if (this.factories == null) {
            init();
        }
        for (NDArrayConverter converter : this.factories) {
            if (converter.canConvert(nDArray, nDArrayFormat)) {
                return converter;
            }
        }
        return null;
    }

    /** Registers an additional converter instance with the shared registry. */
    public static void addConverter(NDArrayConverter nDArrayConverter) {
        INSTANCE.addFactoryInstance(nDArrayConverter);
    }
}
