package com.feedzai.fos.impl.weka;

import com.feedzai.fos.api.FOSException;
import com.feedzai.fos.api.Scorer;
import com.feedzai.fos.common.validation.NotNull;
import com.feedzai.fos.impl.weka.config.WekaManagerConfig;
import com.feedzai.fos.impl.weka.config.WekaModelConfig;
import com.feedzai.fos.impl.weka.utils.AsyncScoringTask;
import com.feedzai.fos.impl.weka.utils.WekaThreadSafeScorer;
import com.feedzai.fos.impl.weka.utils.WekaThreadSafeScorerPassthrough;
import com.feedzai.fos.impl.weka.utils.WekaThreadSafeScorerPool;
import com.feedzai.fos.impl.weka.utils.WekaUtils;
import com.google.common.base.Preconditions;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.locks.ReentrantReadWriteLock;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:com/feedzai/fos/impl/weka/WekaScorer.class */
public class WekaScorer implements Scorer {
    private static final Logger logger = LoggerFactory.getLogger(WekaScorer.class);

    /** Executor used when several models or instances are scored in one call. */
    private final ExecutorService executorService;

    /** Manager-wide configuration handed to every scorer this class creates. */
    private final WekaManagerConfig wekaManagerConfig;

    /** Loaded scorers by model id; read under the read lock, modified under the write lock. */
    private final Map<UUID, WekaThreadSafeScorer> wekaThreadSafeScorers = new HashMap<>();

    /**
     * Fair read/write lock: scoring holds the read lock, while {@link #close()} and model
     * replacement/removal hold the write lock.
     */
    private final ReentrantReadWriteLock reloadModelsLock = new ReentrantReadWriteLock(true);

    /**
     * Looks up the scorer registered for a model.
     *
     * @param uuid the model id
     * @return the registered scorer (never {@code null})
     * @throws FOSException if no scorer is registered for {@code uuid}
     */
    private WekaThreadSafeScorer getScorer(UUID uuid) throws FOSException {
        WekaThreadSafeScorer wekaThreadSafeScorer = this.wekaThreadSafeScorers.get(uuid);
        if (wekaThreadSafeScorer != null) {
            return wekaThreadSafeScorer;
        }
        logger.error("No model with ID '{}'", uuid);
        throw new FOSException("No model with ID " + uuid);
    }

    /**
     * Waits for an asynchronous scoring task and returns its result.
     *
     * @param future the pending task
     * @param uuid   the model id, used only for logging
     * @return the task result
     * @throws FOSException if the wait is interrupted or the task itself failed
     */
    private <T> T getFuture(Future<T> future, UUID uuid) throws FOSException {
        try {
            return future.get();
        } catch (InterruptedException e) {
            // Restore the interrupt status so code further up the stack can still see it.
            Thread.currentThread().interrupt();
            logger.error("Could not score on model '{}'", uuid, e);
            throw new FOSException(e);
        } catch (ExecutionException e) {
            logger.error("Could not score on model '{}'", uuid, e);
            throw new FOSException(e);
        }
    }

    /**
     * Creates the scorer implementation for a model: a passthrough wrapper when the model
     * configuration declares its classifier thread-safe, otherwise a pool.
     *
     * @param modelConfig the model configuration
     * @return a new scorer for the model
     * @throws FOSException if the scorer cannot be built
     */
    private WekaThreadSafeScorer newScorer(WekaModelConfig modelConfig) throws FOSException {
        if (modelConfig.isClassifierThreadSafe()) {
            return new WekaThreadSafeScorerPassthrough(modelConfig, this.wekaManagerConfig);
        }
        return new WekaThreadSafeScorerPool(modelConfig, this.wekaManagerConfig);
    }

    /**
     * Loads a scorer for every configured model and creates the scoring thread pool.
     * Models that fail to load are logged and skipped; the remaining ones are still loaded.
     *
     * @param map               model configurations (each scorer is registered under the config's own id)
     * @param wekaManagerConfig manager-wide configuration, including the thread pool size
     * @throws NullPointerException if either argument is {@code null}
     */
    public WekaScorer(Map<UUID, WekaModelConfig> map, WekaManagerConfig wekaManagerConfig) {
        Preconditions.checkNotNull(map, "Model configuration map cannot be null");
        Preconditions.checkNotNull(wekaManagerConfig, "Manager config cannot be null");
        this.wekaManagerConfig = wekaManagerConfig;
        for (Map.Entry<UUID, WekaModelConfig> entry : map.entrySet()) {
            try {
                WekaModelConfig modelConfig = entry.getValue();
                this.wekaThreadSafeScorers.put(modelConfig.getId(), newScorer(modelConfig));
            } catch (Exception e) {
                logger.error("Could not load from '{}' (continuing to load others)", entry.getKey(), e);
            }
        }
        this.executorService = Executors.newFixedThreadPool(wekaManagerConfig.getThreadPoolSize());
    }

    /**
     * Shuts down the scoring executor and closes every loaded scorer.
     * Runs under the write lock, so it waits for in-progress scoring calls to release the read lock.
     */
    public void close() {
        this.reloadModelsLock.writeLock().lock();
        try {
            this.executorService.shutdown();
            for (WekaThreadSafeScorer wekaThreadSafeScorer : this.wekaThreadSafeScorers.values()) {
                if (wekaThreadSafeScorer != null) {
                    wekaThreadSafeScorer.close();
                }
            }
        } finally {
            this.reloadModelsLock.writeLock().unlock();
        }
    }

    /**
     * Scores one instance against each of the given models.
     * A single model is scored on the calling thread; several models are scored in parallel.
     *
     * @param list   ids of the models to use
     * @param objArr the instance to score
     * @return one score array per entry of {@code list}, in the same order
     * @throws FOSException         if a model id is unknown or a scoring task fails
     * @throws NullPointerException if either argument is {@code null}
     */
    @NotNull
    public List<double[]> score(List<UUID> list, Object[] objArr) throws FOSException {
        Preconditions.checkNotNull(list, "Models to score cannot be null");
        Preconditions.checkNotNull(objArr, "Instance cannot be null");
        List<double[]> scores = new ArrayList<>(list.size());
        this.reloadModelsLock.readLock().lock();
        try {
            if (list.size() == 1) {
                scores.add(getScorer(list.get(0)).score(objArr));
            } else {
                // Resolve every model first so an unknown id fails before any task is queued.
                List<WekaThreadSafeScorer> scorers = new ArrayList<>(list.size());
                for (UUID uuid : list) {
                    scorers.add(getScorer(uuid));
                }
                // Scatter: futures are kept positionally, so duplicate ids are each awaited.
                List<Future<double[]>> futures = new ArrayList<>(list.size());
                for (WekaThreadSafeScorer scorer : scorers) {
                    futures.add(this.executorService.submit(new AsyncScoringTask(scorer, objArr)));
                }
                // Gather in request order.
                int index = 0;
                for (UUID uuid : list) {
                    scores.add(getFuture(futures.get(index++), uuid));
                }
            }
            return scores;
        } finally {
            this.reloadModelsLock.readLock().unlock();
        }
    }

    /**
     * Scores each instance against the model it is keyed by.
     * A single entry is scored on the calling thread; several entries are scored in parallel.
     *
     * @param map instances to score, keyed by model id
     * @return the score array for each model id in {@code map}
     * @throws FOSException         if a model id is unknown or a scoring task fails
     * @throws NullPointerException if {@code map} is {@code null}
     */
    @NotNull
    public Map<UUID, double[]> score(Map<UUID, Object[]> map) throws FOSException {
        Preconditions.checkNotNull(map, "Map of instances cannot be null");
        Map<UUID, double[]> scores = new HashMap<>(map.size());
        this.reloadModelsLock.readLock().lock();
        try {
            if (map.size() == 1) {
                for (Map.Entry<UUID, Object[]> entry : map.entrySet()) {
                    scores.put(entry.getKey(), getScorer(entry.getKey()).score(entry.getValue()));
                }
            } else {
                // Resolve every model first so an unknown id fails before any task is queued.
                Map<UUID, WekaThreadSafeScorer> scorers = new HashMap<>(map.size());
                for (UUID uuid : map.keySet()) {
                    scorers.put(uuid, getScorer(uuid));
                }
                Map<UUID, Future<double[]>> futures = new HashMap<>(map.size());
                for (Map.Entry<UUID, Object[]> entry : map.entrySet()) {
                    UUID uuid = entry.getKey();
                    futures.put(uuid, this.executorService.submit(new AsyncScoringTask(scorers.get(uuid), entry.getValue())));
                }
                for (Map.Entry<UUID, Future<double[]>> entry : futures.entrySet()) {
                    scores.put(entry.getKey(), getFuture(entry.getValue(), entry.getKey()));
                }
            }
            return scores;
        } finally {
            this.reloadModelsLock.readLock().unlock();
        }
    }

    /**
     * Scores several instances against a single model.
     * A single instance is scored on the calling thread; several instances are scored in parallel.
     * An empty list yields an empty result without looking the model up.
     *
     * @param uuid the model id
     * @param list instances to score
     * @return one score array per entry of {@code list}, in the same order
     * @throws FOSException         if the model id is unknown or a scoring task fails
     * @throws NullPointerException if {@code list} is {@code null}
     */
    @NotNull
    public List<double[]> score(UUID uuid, List<Object[]> list) throws FOSException {
        Preconditions.checkNotNull(list, "List of scorables cannot be null");
        List<double[]> scores = new ArrayList<>(list.size());
        if (list.isEmpty()) {
            return scores;
        }
        this.reloadModelsLock.readLock().lock();
        try {
            // Look the model up once; an unknown id now fails with FOSException on every path.
            WekaThreadSafeScorer scorer = getScorer(uuid);
            if (list.size() == 1) {
                scores.add(scorer.score(list.get(0)));
            } else {
                // Futures are kept positionally; keying by Object[] (identity hash) would lose
                // results when the same array instance appears more than once.
                List<Future<double[]>> futures = new ArrayList<>(list.size());
                for (Object[] scorable : list) {
                    futures.add(this.executorService.submit(new AsyncScoringTask(scorer, scorable)));
                }
                for (Future<double[]> future : futures) {
                    scores.add(getFuture(future, uuid));
                }
            }
            return scores;
        } finally {
            this.reloadModelsLock.readLock().unlock();
        }
    }

    /**
     * Builds a scorer for the given model and registers it, closing any scorer it replaces.
     * The new scorer is built before taking the write lock, so scoring is only blocked for the swap.
     *
     * @param wekaModelConfig the model configuration
     * @throws FOSException         if the scorer cannot be built
     * @throws NullPointerException if {@code wekaModelConfig} is {@code null}
     */
    public void addOrUpdate(WekaModelConfig wekaModelConfig) throws FOSException {
        Preconditions.checkNotNull(wekaModelConfig, "Model config cannot be null");
        WekaThreadSafeScorer replacement = newScorer(wekaModelConfig);
        WekaUtils.closeSilently(quickSwitch(wekaModelConfig.getId(), replacement));
    }

    /**
     * Unregisters a model and closes its scorer, if one was registered.
     *
     * @param uuid the model id
     */
    public void removeModel(UUID uuid) {
        WekaUtils.closeSilently(quickSwitch(uuid, null));
    }

    /**
     * Replaces (or, when {@code wekaThreadSafeScorer} is {@code null}, removes) the scorer for a
     * model under the write lock.
     *
     * @param uuid                 the model id
     * @param wekaThreadSafeScorer the new scorer, or {@code null} to remove the model
     * @return the previously registered scorer, or {@code null} if there was none
     */
    private WekaThreadSafeScorer quickSwitch(UUID uuid, WekaThreadSafeScorer wekaThreadSafeScorer) {
        this.reloadModelsLock.writeLock().lock();
        try {
            if (wekaThreadSafeScorer == null) {
                return this.wekaThreadSafeScorers.remove(uuid);
            }
            return this.wekaThreadSafeScorers.put(uuid, wekaThreadSafeScorer);
        } finally {
            this.reloadModelsLock.writeLock().unlock();
        }
    }
}
