package org.neo4j.graphalgo.core.model;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;
import org.jetbrains.annotations.Nullable;
import org.neo4j.graphalgo.config.BaseConfig;
import org.neo4j.graphalgo.config.ModelConfig;
import org.neo4j.graphalgo.core.GdsEdition;
import org.neo4j.graphalgo.core.StringSimilarity;
import org.neo4j.graphalgo.utils.StringFormatting;

/* loaded from: input_file:org/neo4j/graphalgo/core/model/ModelCatalog.class */
/**
 * Static, in-memory registry of {@link Model}s.
 *
 * <p>Models are stored in one {@link UserCatalog} per creator (keyed by {@code Model#creator()})
 * plus one shared catalog of public models that is consulted as a fallback by the lookup
 * methods. All state lives in static fields, so the catalog is shared by every caller in the
 * JVM that loaded this class.
 */
public final class ModelCatalog {
    /** Per-user catalogs, keyed by the creator name of the models they hold. */
    private static final Map<String, UserCatalog> userCatalogs = new ConcurrentHashMap<>();
    /** Models that are visible to every user; filled by {@link #publish(String, String)}. */
    private static final UserCatalog publicModels = new UserCatalog();

    /**
     * A name-indexed collection of models. One instance exists per user, one more holds the
     * public models.
     */
    public static class UserCatalog {
        /** Maximum number of models per algorithm type enforced on the community edition. */
        private static final long ALLOWED_MODELS_COUNT = 1L;
        /**
         * Shared placeholder handed out for users that have no catalog yet. It is shared between
         * all such users and therefore must never be written to.
         */
        private static final UserCatalog EMPTY = new UserCatalog();
        private final Map<String, Model<?, ?>> userModels = new ConcurrentHashMap<>();

        UserCatalog() {
        }

        /**
         * Stores {@code model} under its name.
         *
         * @param model the model to store
         * @throws IllegalArgumentException if the community-edition limit for the model's
         *         algorithm type is reached, or if a model with the same name is already stored
         */
        public void set(Model<?, ?> model) {
            canStoreModel(model.algoType());
            // putIfAbsent closes the gap between "does it exist?" and "store it".
            if (this.userModels.putIfAbsent(model.name(), model) != null) {
                throw new IllegalArgumentException(StringFormatting.formatWithLocale("Model with name `%s` already exists", new Object[]{model.name()}));
            }
        }

        /**
         * Stores {@code model} under its name without any limit or duplicate check; an existing
         * model with the same name is replaced.
         *
         * @param model the model to store
         */
        public void setUnsafe(Model<?, ?> model) {
            this.userModels.put(model.name(), model);
        }

        /**
         * Looks up a model by name and verifies its data and train-config types.
         *
         * @param str  name of the model
         * @param cls  expected class of the model data
         * @param cls2 expected class of the train config
         * @return the model, or {@code null} if no model with that name is stored
         * @throws IllegalArgumentException if the stored model's data or train config is not an
         *         instance of the requested class
         */
        @Nullable
        public <D, C extends ModelConfig & BaseConfig> Model<D, C> get(String str, Class<D> cls, Class<C> cls2) {
            return cast(getUntyped(str), cls, cls2);
        }

        /**
         * Same as {@link #get(String, Class, Class)} but fails instead of returning {@code null}.
         *
         * @param str  name of the model
         * @param cls  expected class of the model data
         * @param cls2 expected class of the train config
         * @return the model, never {@code null}
         * @throws NoSuchElementException   if no model with that name is stored
         * @throws IllegalArgumentException if the stored model's data or train config is not an
         *         instance of the requested class
         */
        public <D, C extends ModelConfig & BaseConfig> Model<D, C> getChecked(String str, Class<D> cls, Class<C> cls2) {
            return cast(getUntypedChecked(str), cls, cls2);
        }

        /**
         * Verifies at runtime that {@code model} carries the requested data and config types and
         * returns it with the matching static type. A {@code null} model is passed through.
         *
         * <p>NOTE(review): in both messages the first type argument is the model's actual type
         * and the second the requested one, while the template labels them "Expected" and
         * "invoked with" respectively. Left untouched; confirm the intended wording.
         */
        @SuppressWarnings("unchecked") // guarded by the isInstance checks on data and train config
        private <D, C extends ModelConfig & BaseConfig> Model<D, C> cast(Model<?, ?> model, Class<D> cls, Class<C> cls2) {
            if (model == null) {
                return null;
            }
            String name = model.name();
            Object data = model.data();
            if (!cls.isInstance(data)) {
                throw new IllegalArgumentException(StringFormatting.formatWithLocale("The model `%s` has data with different types than expected. Expected data type: `%s`, invoked with model data type: `%s`.", new Object[]{name, data.getClass().getName(), cls.getName()}));
            }
            ModelConfig trainConfig = model.trainConfig();
            if (!cls2.isInstance(trainConfig)) {
                throw new IllegalArgumentException(StringFormatting.formatWithLocale("The model `%s` has a training config with different types than expected. Expected train config type: `%s`, invoked with model config type: `%s`.", new Object[]{name, trainConfig.getClass().getName(), cls2.getName()}));
            }
            return (Model<D, C>) model;
        }

        /**
         * @param str name of the model
         * @return {@code true} if a model with that name is stored in this catalog
         */
        public boolean exists(String str) {
            return this.userModels.containsKey(str);
        }

        /**
         * @param str name of the model
         * @return the algorithm type of the named model, or empty if it is not stored here
         */
        public Optional<String> type(String str) {
            return Optional.ofNullable(this.userModels.get(str)).map(Model::algoType);
        }

        /**
         * Removes the named model from this catalog.
         *
         * @param str name of the model
         * @return the value returned by the underlying map removal
         * @throws NoSuchElementException if no model with that name is stored
         */
        public Model<?, ?> drop(String str) {
            return this.userModels.remove(getUntypedChecked(str).name());
        }

        /**
         * @return the values view of the backing map (not a copy)
         */
        public Collection<Model<?, ?>> list() {
            return this.userModels.values();
        }

        /**
         * @param str name of the model
         * @return the named model, or {@code null} if it is not stored here
         */
        @Nullable
        public Model<?, ?> list(String str) {
            return getUntyped(str);
        }

        /** Removes every model from this catalog. */
        public void removeAllLoadedModels() {
            this.userModels.clear();
        }

        /**
         * Copies all models of {@code userCatalog} into this catalog, replacing entries with the
         * same name. No limit check is applied.
         *
         * @param userCatalog the catalog to copy from
         * @return this catalog
         */
        public UserCatalog join(UserCatalog userCatalog) {
            this.userModels.putAll(userCatalog.userModels);
            return this;
        }

        /**
         * Tells whether the per-type limit is used up. Uses {@code >=} because
         * {@link #setUnsafe(Model)} and {@link #join(UserCatalog)} skip the limit check and can
         * push the count past the limit.
         */
        private boolean reachedModelsLimit(String str) {
            return modelsPerType(str) >= ALLOWED_MODELS_COUNT;
        }

        /** Counts the stored models whose algorithm type equals {@code str}. */
        private long modelsPerType(String str) {
            return this.userModels.values().stream()
                .filter(model -> model.algoType().equals(str))
                .count();
        }

        /**
         * Throws if another model of algorithm type {@code str} may not be stored, which is the
         * case on the community edition once the per-type limit is reached.
         */
        private void canStoreModel(String str) {
            if (GdsEdition.instance().isOnCommunityEdition() && reachedModelsLimit(str)) {
                throw new IllegalArgumentException("Community users can only store one model in the catalog");
            }
        }

        /** Plain map lookup; returns {@code null} when the name is unknown. */
        private Model<?, ?> getUntyped(String str) {
            return this.userModels.get(str);
        }

        /** Map lookup that throws {@link NoSuchElementException} when the name is unknown. */
        private Model<?, ?> getUntypedChecked(String str) {
            Model<?, ?> model = this.userModels.get(str);
            if (model == null) {
                throw notFound(str);
            }
            return model;
        }

        /**
         * Builds the "model does not exist" exception for {@code str}, passing the names stored
         * in this catalog to {@code StringSimilarity.prettySuggestions} as candidates.
         */
        private NoSuchElementException notFound(String str) {
            return new NoSuchElementException(StringSimilarity.prettySuggestions(StringFormatting.formatWithLocale("Model with name `%s` does not exist.", new Object[]{str}), str, this.userModels.keySet()));
        }
    }

    private ModelCatalog() {
    }

    /**
     * Stores {@code model} in the catalog of its creator, creating that catalog on first use.
     *
     * @param model the model to store
     * @throws IllegalArgumentException if the creator's catalog rejects the model
     *         (see {@link UserCatalog#set(Model)})
     */
    public static void set(Model<?, ?> model) {
        // compute() keeps "create catalog if missing" and "store model" in one map operation.
        userCatalogs.compute(model.creator(), (creator, catalog) -> {
            UserCatalog target = catalog == null ? new UserCatalog() : catalog;
            target.set(model);
            return target;
        });
    }

    /**
     * Stores {@code model} in the catalog of its creator without limit or duplicate checks,
     * creating that catalog on first use.
     *
     * @param model the model to store
     */
    public static void setUnsafe(Model<?, ?> model) {
        userCatalogs.compute(model.creator(), (creator, catalog) -> {
            UserCatalog target = catalog == null ? new UserCatalog() : catalog;
            target.setUnsafe(model);
            return target;
        });
    }

    /**
     * Looks up a typed model, first in the user's catalog and then among the public models.
     *
     * @param str  name of the user whose catalog is searched first
     * @param str2 name of the model
     * @param cls  expected class of the model data
     * @param cls2 expected class of the train config
     * @return the model, never {@code null}
     * @throws NoSuchElementException   if neither catalog holds a model with that name
     * @throws IllegalArgumentException if the model's data or train config has another type
     */
    public static <D, C extends ModelConfig & BaseConfig> Model<D, C> get(String str, String str2, Class<D> cls, Class<C> cls2) {
        UserCatalog userCatalog = getUserCatalog(str);
        Model<D, C> model = userCatalog.get(str2, cls, cls2);
        if (model == null) {
            model = publicModels.get(str2, cls, cls2);
        }
        if (model == null) {
            throw userCatalog.notFound(str2);
        }
        return model;
    }

    /**
     * Looks up a model without type checks and fails if it does not exist.
     *
     * @param str  name of the user whose catalog is searched first
     * @param str2 name of the model
     * @return the model
     * @throws NoSuchElementException if neither the user's nor the public catalog holds it
     */
    @Nullable
    public static Model<?, ?> getUntyped(String str, String str2) {
        return getUntyped(str, str2, true);
    }

    /**
     * Looks up a model without type checks, first in the user's catalog and then among the
     * public models.
     *
     * @param str  name of the user whose catalog is searched first
     * @param str2 name of the model
     * @param z    whether a missing model raises an exception instead of yielding {@code null}
     * @return the model, or {@code null} if it is missing and {@code z} is {@code false}
     * @throws NoSuchElementException if the model is missing and {@code z} is {@code true}
     */
    @Nullable
    public static Model<?, ?> getUntyped(String str, String str2, boolean z) {
        UserCatalog userCatalog = getUserCatalog(str);
        Model<?, ?> untyped = userCatalog.getUntyped(str2);
        if (untyped == null) {
            untyped = publicModels.getUntyped(str2);
        }
        if (untyped == null && z) {
            throw userCatalog.notFound(str2);
        }
        return untyped;
    }

    /**
     * @param str  name of the user
     * @param str2 name of the model
     * @return {@code true} if the user's catalog or the public catalog holds the model
     */
    public static boolean exists(String str, String str2) {
        return getUserCatalog(str).exists(str2) || publicModels.exists(str2);
    }

    /**
     * Returns the algorithm type of a model in the user's catalog.
     *
     * <p>NOTE(review): unlike {@link #exists(String, String)} and the {@code get} methods this
     * does not fall back to the public models; confirm whether that is intended.
     *
     * @param str  name of the user
     * @param str2 name of the model
     * @return the algorithm type, or empty if the user's catalog does not hold the model
     */
    public static Optional<String> type(String str, String str2) {
        return getUserCatalog(str).type(str2);
    }

    /**
     * Removes a model. If a public model with that name exists, it is removed provided the
     * caller is its creator; otherwise the model is removed from the user's own catalog.
     *
     * @param str  name of the user requesting the removal
     * @param str2 name of the model
     * @return the removed model
     * @throws IllegalStateException  if a public model with that name exists but was created by
     *         someone else
     * @throws NoSuchElementException if the model to remove does not exist
     */
    public static Model<?, ?> drop(String str, String str2) {
        // Single lookup: avoids dereferencing null if the public model vanishes between
        // an exists() check and a subsequent get.
        Model<?, ?> publicModel = publicModels.getUntyped(str2);
        if (publicModel == null) {
            return getUserCatalog(str).drop(str2);
        }
        if (publicModel.creator().equals(str)) {
            return publicModels.drop(str2);
        }
        throw new IllegalStateException(StringFormatting.formatWithLocale("Only the creator of model %s can drop it.", new Object[]{str2}));
    }

    /**
     * @param str name of the user
     * @return a new collection holding the user's models followed by all public models
     */
    public static Collection<Model<?, ?>> list(String str) {
        Collection<Model<?, ?>> models = new ArrayList<>(getUserCatalog(str).list());
        models.addAll(publicModels.list());
        return models;
    }

    /**
     * @param str  name of the user
     * @param str2 name of the model
     * @return the model from the user's or the public catalog, or {@code null} if absent
     */
    @Nullable
    public static Model<?, ?> list(String str, String str2) {
        return getUntyped(str, str2, false);
    }

    /**
     * Makes a model public: stores the result of {@code Model#publish()} in the public catalog
     * and then drops the original via {@link #drop(String, String)}. A model whose
     * {@code sharedWith()} already contains {@code "*"} is returned unchanged.
     *
     * @param str  name of the user
     * @param str2 name of the model
     * @return the published model
     * @throws IllegalArgumentException if running on the community edition, or if the public
     *         catalog rejects the published model
     * @throws NoSuchElementException   if the model does not exist
     */
    public static Model<?, ?> publish(String str, String str2) {
        if (GdsEdition.instance().isOnCommunityEdition()) {
            throw new IllegalArgumentException("Publishing a model is only available with the Graph Data Science library Enterprise Edition.");
        }
        Model<?, ?> untyped = getUntyped(str, str2);
        if (untyped.sharedWith().contains("*")) {
            return untyped;
        }
        Model<?, ?> published = untyped.publish();
        publicModels.set(published);
        drop(str, str2);
        return published;
    }

    /** Removes every user catalog and every public model. */
    public static void removeAllLoadedModels() {
        userCatalogs.clear();
        publicModels.removeAllLoadedModels();
    }

    /**
     * Checks whether the user could store another model of the given algorithm type.
     *
     * @param str  name of the user
     * @param str2 algorithm type of the model to be stored
     * @throws IllegalArgumentException if the community-edition limit is reached
     */
    public static void checkStorable(String str, String str2) {
        getUserCatalog(str).canStoreModel(str2);
    }

    /** Returns the user's catalog, or the shared empty placeholder if the user has none. */
    private static UserCatalog getUserCatalog(String str) {
        return userCatalogs.getOrDefault(str, UserCatalog.EMPTY);
    }
}
