package org.icij.datashare.text.nlp;

import java.io.IOException;
import java.net.URL;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Collections;
import java.util.HashMap;
import java.util.Locale;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.Semaphore;
import org.icij.datashare.DynamicClassLoader;
import org.icij.datashare.io.RemoteFiles;
import org.icij.datashare.text.Language;
import org.icij.datashare.text.nlp.Pipeline;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/icij/datashare/text/nlp/AbstractModels.class */
/**
 * Base class for per-language NLP model holders.
 *
 * <p>Models are loaded lazily on the first {@link #get(Language)} call for a language. Loading
 * and unloading of a given language are serialized by a per-language fair semaphore. When
 * {@link #isSync()} is true, model files are first synchronized from remote storage via
 * {@link #getRemoteFiles()} before {@link #loadModelFile(Language)} is invoked.
 *
 * @param <T> the in-memory model type produced by {@link #loadModelFile(Language)}
 */
public abstract class AbstractModels<T> {
    /** System property read by {@link #isSync()}; treated as {@code "true"} when unset. */
    public static final String JVM_PROPERTY_NAME = "DS_SYNC_NLP_MODELS";
    /** Absolute, normalized working directory used as the download target for remote files. */
    private static final Path BASE_DIR = Paths.get(".").toAbsolutePath().normalize();
    /** Classpath-relative root under which model resources are looked up. */
    protected static final Path BASE_CLASSPATH = Paths.get("models");
    /** Path prefix prepended to the classpath-relative model path to form the remote/filesystem path. */
    private static final String PREFIX = "dist";

    public final NlpStage stage;
    protected final Pipeline.Type type;
    protected final Logger LOGGER = LoggerFactory.getLogger(getClass());
    /** One fair single-permit semaphore per {@link Language}; guards load/unload for that language. */
    protected final ConcurrentHashMap<Language, Semaphore> modelLock = newModelLocks();
    /**
     * Loaded models keyed by language.
     *
     * <p>Synchronized because different languages are loaded under different semaphores (so puts
     * may happen concurrently) and {@link #get(Language)} / {@link #isLoaded(Language)} read it
     * without holding any lock. A synchronized map (rather than ConcurrentHashMap) is used so that
     * a {@code null} model returned by a subclass is still accepted as before.
     */
    protected final Map<Language, T> models = Collections.synchronizedMap(new HashMap<>());

    protected AbstractModels(Pipeline.Type type, NlpStage nlpStage) {
        this.stage = nlpStage;
        this.type = type;
    }

    /** Builds the per-language lock table: one fair binary semaphore for every {@link Language}. */
    private static ConcurrentHashMap<Language, Semaphore> newModelLocks() {
        ConcurrentHashMap<Language, Semaphore> locks = new ConcurrentHashMap<>();
        for (Language language : Language.values()) {
            locks.put(language, new Semaphore(1, true));
        }
        return locks;
    }

    /**
     * Reads the model for {@code language} from its local location.
     *
     * @throws IOException if the model cannot be read
     */
    protected abstract T loadModelFile(Language language) throws IOException;

    /** Returns the model version string; dots are replaced by dashes when building model paths. */
    protected abstract String getVersion();

    /**
     * Returns the model for {@code language}, loading it first if it is not already loaded.
     *
     * @return the loaded model, or {@code null} if loading failed with an {@link IOException}
     *     (the failure is logged)
     * @throws InterruptedException if interrupted while waiting for the language's lock
     */
    public T get(Language language) throws InterruptedException {
        if (!isLoaded(language)) {
            load(language);
        }
        return models.get(language);
    }

    /**
     * Loads the model for {@code language} under its semaphore, optionally syncing files first.
     * An {@link IOException} from loading is logged and swallowed; the permit is always released.
     */
    private void load(Language language) throws InterruptedException {
        Semaphore semaphore = modelLock.get(language);
        semaphore.acquire();
        try {
            // Re-check under the lock: another thread may have loaded it while we were waiting.
            if (isLoaded(language)) {
                return;
            }
            if (isSync()) {
                downloadIfNecessary(language);
            }
            models.put(language, loadModelFile(language));
            LOGGER.info("loaded {} model for {}", stage, language);
        } catch (IOException e) {
            LOGGER.error("failed loading " + stage, e);
        } finally {
            semaphore.release();
        }
    }

    /**
     * Returns the classpath-relative model directory:
     * {@code models/<type lower-cased>/<version with '.' replaced by '-'>/<ISO 639-1 code>}.
     */
    public Path getModelsBasePath(Language language) {
        // Locale.ROOT: the default locale could lower-case 'I' to a dotless 'ı' (e.g. Turkish),
        // yielding a path that does not exist.
        return BASE_CLASSPATH
                .resolve(type.name().toLowerCase(Locale.ROOT))
                .resolve(getVersion().replace('.', '-'))
                .resolve(language.iso6391Code());
    }

    /** Returns {@link #getModelsBasePath(Language)} prefixed with {@code dist}. */
    public Path getModelsFilesystemPath(Language language) {
        return Paths.get(PREFIX).resolve(getModelsBasePath(language));
    }

    /**
     * Looks up {@code path} as a resource and adds the resulting URL to the system class loader.
     *
     * <p>NOTE(review): despite the method name this targets the <em>system</em> class loader and
     * casts it to {@link DynamicClassLoader}; the cast fails if the JVM was not started with that
     * loader. If the resource is not found, {@code null} is passed to {@code add} — confirm that
     * {@code DynamicClassLoader.add} tolerates it.
     */
    public void addResourceToContextClassLoader(Path path) {
        DynamicClassLoader classLoader = (DynamicClassLoader) ClassLoader.getSystemClassLoader();
        URL resource = classLoader.getResource(path.toString());
        LOGGER.info("adding {} to system classloader", resource == null ? null : resource.getPath());
        classLoader.add(resource);
    }

    /** Returns whether the model directory for {@code language} is visible to the context class loader. */
    protected boolean isPresent(Language language) {
        return Thread.currentThread().getContextClassLoader()
                .getResource(getModelsBasePath(language).toString()) != null;
    }

    /**
     * Downloads the models for {@code language} unless they are already present and in sync with
     * the remote copy. Failures are logged, not thrown; the {@link RemoteFiles} client is always
     * shut down. If interrupted, the thread's interrupt status is restored after shutdown.
     */
    protected void downloadIfNecessary(Language language) {
        // Remote keys use forward slashes regardless of the platform separator.
        String remotePath = getModelsFilesystemPath(language).toString().replace("\\", "/");
        RemoteFiles remoteFiles = getRemoteFiles();
        boolean interrupted = false;
        try {
            if (isPresent(language) && remoteFiles.isSync(remotePath, BASE_DIR.toFile())) {
                return;
            }
            LOGGER.info("downloading models for language {} under {}", language, remotePath);
            remoteFiles.download(remotePath, BASE_DIR.toFile());
            LOGGER.info("models successfully downloaded for language {}", language);
        } catch (InterruptedException e) {
            interrupted = true;
            LOGGER.error("failed downloading models for " + language, e);
        } catch (IOException e) {
            LOGGER.error("failed downloading models for " + language, e);
        } finally {
            remoteFiles.shutdown();
            // Restore interrupt status only after shutdown so shutdown behaves as before.
            if (interrupted) {
                Thread.currentThread().interrupt();
            }
        }
    }

    /**
     * Removes the model for {@code language}, waiting for any in-progress load to finish.
     *
     * @throws InterruptedException if interrupted while waiting for the language's lock
     */
    public void unload(Language language) throws InterruptedException {
        Semaphore semaphore = modelLock.get(language);
        semaphore.acquire();
        try {
            models.remove(language);
        } finally {
            semaphore.release();
        }
    }

    /** Sets the JVM-wide {@value #JVM_PROPERTY_NAME} property that controls remote model sync. */
    public static void syncModels(boolean sync) {
        LoggerFactory.getLogger(AbstractModels.class).info("synchronize models is set to {}", sync);
        System.setProperty(JVM_PROPERTY_NAME, String.valueOf(sync));
    }

    /** Returns whether remote model sync is enabled; {@code true} unless the property says otherwise. */
    public static boolean isSync() {
        return Boolean.parseBoolean(System.getProperty(JVM_PROPERTY_NAME, "true"));
    }

    /** Returns whether an entry for {@code language} exists in {@link #models}. */
    public boolean isLoaded(Language language) {
        return models.containsKey(language);
    }

    /** Returns the remote file client used for model sync; overridable (e.g. for tests). */
    protected RemoteFiles getRemoteFiles() {
        return RemoteFiles.getDefault();
    }
}
