package org.linqs.psl.reasoner.term.streaming;

import java.io.File;
import java.nio.ByteBuffer;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import org.linqs.psl.config.Config;
import org.linqs.psl.database.atom.AtomManager;
import org.linqs.psl.model.atom.RandomVariableAtom;
import org.linqs.psl.model.rule.GroundRule;
import org.linqs.psl.model.rule.Rule;
import org.linqs.psl.model.rule.WeightedRule;
import org.linqs.psl.reasoner.term.HyperplaneTermGenerator;
import org.linqs.psl.reasoner.term.ReasonerTerm;
import org.linqs.psl.reasoner.term.VariableTermStore;
import org.linqs.psl.util.RandUtils;
import org.linqs.psl.util.SystemUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/linqs/psl/reasoner/term/streaming/StreamingTermStore.class */
/**
 * Base class for variable term stores whose terms are paged through files under {@link #pageDir}
 * (see {@link #getTermPagePath(int)} / {@link #getVolatilePagePath(int)}) and consumed through
 * {@link StreamingIterator}s rather than random access.
 *
 * <p>The first full pass (the "initial round") is served by {@link #getInitialRoundIterator()}.
 * After {@link #initialIterationComplete} is called, later passes come from
 * {@link #getCacheIterator()}, or from {@link #getNoWriteIterator()} for read-only passes.
 * Only one iterator may be active at a time.
 *
 * <p>Only weighted rules with non-negative weight that support individual grounding and that the
 * subclass accepts via {@link #supportsRule(Rule)} are retained; others are skipped, with a warning
 * when {@code streamingtermstore.warnunsupportedrules} is enabled.
 *
 * <p>Random access operations ({@link #add}, {@link #get}, {@link #ensureCapacity}) are unsupported.
 *
 * @param <T> the term type produced by the term generator.
 */
public abstract class StreamingTermStore<T extends ReasonerTerm> implements VariableTermStore<T, RandomVariableAtom> {
    public static final String CONFIG_PREFIX = "streamingtermstore";
    public static final String PAGE_LOCATION_KEY = "streamingtermstore.pagelocation";
    public static final String PAGE_SIZE_KEY = "streamingtermstore.pagesize";
    public static final int PAGE_SIZE_DEFAULT = 10000;
    public static final String SHUFFLE_PAGE_KEY = "streamingtermstore.shufflepage";
    public static final boolean SHUFFLE_PAGE_DEFAULT = true;
    public static final String RANDOMIZE_PAGE_ACCESS_KEY = "streamingtermstore.randomizepageaccess";
    public static final boolean RANDOMIZE_PAGE_ACCESS_DEFAULT = true;
    public static final String WARN_RULES_KEY = "streamingtermstore.warnunsupportedrules";
    public static final boolean WARN_RULES_DEFAULT = true;
    public static final int INITIAL_PATH_CACHE_SIZE = 100;

    private static final Logger log = LoggerFactory.getLogger(StreamingTermStore.class);

    public static final String PAGE_LOCATION_DEFAULT = SystemUtils.getTempDir("streaimg_term_cache_pages");

    protected AtomManager atomManager;

    // Atom -> dense index into variableValues / variableAtoms.
    protected Map<RandomVariableAtom, Integer> variables;
    private float[] variableValues;
    private RandomVariableAtom[] variableAtoms;

    // Lazily-extended caches of page file paths, indexed by page number.
    protected List<String> termPagePaths;
    protected List<String> volatilePagePaths;

    // True until initialIterationComplete() is called (and again after clear()).
    protected boolean initialRound;
    protected StreamingIterator<T> activeIterator;
    // Term count reported by size(); set when the initial round completes.
    protected int seenTermCount;
    protected int numPages;

    protected HyperplaneTermGenerator<T, RandomVariableAtom> termGenerator;
    protected ByteBuffer termBuffer;
    protected ByteBuffer volatileBuffer;
    protected List<T> termCache;
    protected List<T> termPool;
    protected int[] shuffleMap;

    protected int pageSize = Config.getInt(PAGE_SIZE_KEY, PAGE_SIZE_DEFAULT);
    protected String pageDir = Config.getString(PAGE_LOCATION_KEY, PAGE_LOCATION_DEFAULT);
    protected boolean shufflePage = Config.getBoolean(SHUFFLE_PAGE_KEY, SHUFFLE_PAGE_DEFAULT);
    protected boolean randomizePageAccess = Config.getBoolean(RANDOMIZE_PAGE_ACCESS_KEY, RANDOMIZE_PAGE_ACCESS_DEFAULT);
    protected boolean warnRules = Config.getBoolean(WARN_RULES_KEY, WARN_RULES_DEFAULT);
    protected List<WeightedRule> rules = new ArrayList<>();

    /**
     * Filters the candidate rules, sizes variable storage, and prepares a fresh page directory.
     *
     * <p>Note: this calls the abstract {@link #supportsRule(Rule)} before any subclass constructor
     * body has run, so implementations must not depend on subclass instance fields.
     *
     * @param candidateRules the model's rules; unsupported ones are skipped.
     * @param atomManager source of the initial variable capacity.
     * @param termGenerator generator used by subclasses to build terms.
     * @throws IllegalArgumentException if no rule is usable, or the configured page size is &lt;= 1.
     */
    public StreamingTermStore(List<Rule> candidateRules, AtomManager atomManager,
            HyperplaneTermGenerator<T, RandomVariableAtom> termGenerator) {
        for (Rule rule : candidateRules) {
            addRuleIfSupported(rule);
        }

        // Check what survived filtering, not the input list: a non-empty input whose rules were
        // all rejected leaves nothing to ground.
        if (this.rules.isEmpty()) {
            throw new IllegalArgumentException("Found no valid rules for a streaming term store.");
        }

        // Validate before touching the file system so a bad config does not wipe pageDir.
        if (this.pageSize <= 1) {
            throw new IllegalArgumentException("Page size is too small.");
        }

        this.atomManager = atomManager;
        this.termGenerator = termGenerator;
        ensureVariableCapacity(atomManager.getCachedRVACount());

        this.termPagePaths = new ArrayList<>(INITIAL_PATH_CACHE_SIZE);
        this.volatilePagePaths = new ArrayList<>(INITIAL_PATH_CACHE_SIZE);

        this.initialRound = true;
        this.activeIterator = null;
        this.numPages = 0;
        this.termBuffer = null;
        this.volatileBuffer = null;

        this.termCache = new ArrayList<>(this.pageSize);
        this.termPool = new ArrayList<>(this.pageSize);
        this.shuffleMap = new int[this.pageSize];

        // Start from an empty page directory.
        SystemUtils.recursiveDelete(this.pageDir);
        new File(this.pageDir).mkdirs();
    }

    /**
     * Appends {@code rule} to {@link #rules} if streaming can handle it; otherwise optionally warns.
     */
    private void addRuleIfSupported(Rule rule) {
        if (!rule.isWeighted()) {
            warnUnsupportedRule("Streaming term stores do not support hard constraints: {}", rule);
            return;
        }

        if (((WeightedRule) rule).getWeight() < 0.0d) {
            warnUnsupportedRule("Streaming term stores do not support negative weights: {}", rule);
            return;
        }

        if (!rule.supportsIndividualGrounding()) {
            warnUnsupportedRule("Streaming term stores do not support rules that cannot individually ground (arithmetic rules with summations): {}", rule);
            return;
        }

        if (!supportsRule(rule)) {
            warnUnsupportedRule("Rule not supported: {}", rule);
            return;
        }

        this.rules.add((WeightedRule) rule);
    }

    /** Logs a skipped-rule warning, unless warnings are disabled by configuration. */
    private void warnUnsupportedRule(String format, Rule rule) {
        if (this.warnRules) {
            log.warn(format, rule);
        }
    }

    /** Returns true once the initial round has completed (i.e. terms are available from pages). */
    @Override
    public boolean isLoaded() {
        return !this.initialRound;
    }

    @Override
    public int getNumVariables() {
        return this.variables.size();
    }

    @Override
    public Iterable<RandomVariableAtom> getVariables() {
        return this.variables.keySet();
    }

    /**
     * Returns the live backing array of variable values (not a copy). Its length is the current
     * capacity, which may exceed {@link #getNumVariables()}.
     */
    @Override
    public float[] getVariableValues() {
        return this.variableValues;
    }

    /**
     * Returns the dense index of {@code atom}.
     *
     * @throws NullPointerException if the atom has not been registered with this store.
     */
    @Override
    public int getVariableIndex(RandomVariableAtom atom) {
        return this.variables.get(atom).intValue();
    }

    /** Copies each registered variable's current value back onto its atom. */
    @Override
    public void syncAtoms() {
        int count = this.variables.size();
        for (int index = 0; index < count; index++) {
            this.variableAtoms[index].setValue(this.variableValues[index]);
        }
    }

    /**
     * Registers {@code atom} as a variable (if new) with a random initial value.
     *
     * @return the same atom instance.
     */
    @Override
    public synchronized RandomVariableAtom createLocalVariable(RandomVariableAtom atom) {
        if (this.variables.containsKey(atom)) {
            return atom;
        }

        int index = this.variables.size();
        if (index >= this.variableAtoms.length) {
            // Math.max covers zero initial capacity, where plain doubling would stay at zero
            // and the writes below would go out of bounds.
            ensureVariableCapacity(Math.max(1, index * 2));
        }

        this.variables.put(atom, index);
        this.variableValues[index] = RandUtils.nextFloat();
        this.variableAtoms[index] = atom;
        return atom;
    }

    /**
     * Ensures room for at least {@code capacity} variables.
     *
     * <p>With no registered variables the storage is reallocated to exactly {@code capacity};
     * otherwise it grows (preserving contents) to at least double the current variable count.
     *
     * @throws IllegalArgumentException if {@code capacity} is negative.
     */
    @Override
    public void ensureVariableCapacity(int capacity) {
        if (capacity < 0) {
            throw new IllegalArgumentException("Variable capacity must be non-negative. Got: " + capacity);
        }

        if (this.variables == null || this.variables.isEmpty()) {
            // Nothing to preserve.
            this.variables = new HashMap<>((int) Math.ceil(capacity / 0.75d));
            this.variableValues = new float[capacity];
            this.variableAtoms = new RandomVariableAtom[capacity];
            return;
        }

        if (this.variables.size() >= capacity) {
            return;
        }

        // Grow geometrically so repeated small requests stay amortized.
        int newCapacity = Math.max(capacity, this.variables.size() * 2);

        Map<RandomVariableAtom, Integer> resized = new HashMap<>((int) Math.ceil(newCapacity / 0.75d));
        resized.putAll(this.variables);
        this.variables = resized;
        this.variableValues = Arrays.copyOf(this.variableValues, newCapacity);
        this.variableAtoms = Arrays.copyOf(this.variableAtoms, newCapacity);
    }

    /** Returns the term count recorded when the initial round completed (0 before that). */
    @Override
    public int size() {
        return this.seenTermCount;
    }

    @Override
    public void add(GroundRule groundRule, T term) {
        throw new UnsupportedOperationException();
    }

    @Override
    public T get(int index) {
        throw new UnsupportedOperationException();
    }

    @Override
    public void ensureCapacity(int capacity) {
        throw new UnsupportedOperationException();
    }

    /** Returns the path of term page {@code index}, e.g. {@code <pageDir>/00000003_term.page}. */
    public String getTermPagePath(int index) {
        return getPagePath(this.termPagePaths, index, "term");
    }

    /** Returns the path of volatile page {@code index}, e.g. {@code <pageDir>/00000003_volatile.page}. */
    public String getVolatilePagePath(int index) {
        return getPagePath(this.volatilePagePaths, index, "volatile");
    }

    /** Extends {@code pathCache} up to {@code index} with "%08d_<suffix>.page" paths and returns entry {@code index}. */
    private String getPagePath(List<String> pathCache, int index, String suffix) {
        for (int pageIndex = pathCache.size(); pageIndex <= index; pageIndex++) {
            pathCache.add(Paths.get(this.pageDir, String.format("%08d_%s.page", pageIndex, suffix)).toString());
        }
        return pathCache.get(index);
    }

    /**
     * Records the results of the initial round and leaves the initial-round state.
     *
     * @param termCount number of terms seen; reported afterwards by {@link #size()}.
     * @param pageCount number of pages written.
     * @param termBuffer buffer kept for later iterators.
     * @param volatileBuffer buffer kept for later iterators.
     */
    public void initialIterationComplete(int termCount, int pageCount, ByteBuffer termBuffer, ByteBuffer volatileBuffer) {
        this.seenTermCount = termCount;
        this.numPages = pageCount;
        this.termBuffer = termBuffer;
        this.volatileBuffer = volatileBuffer;
        this.initialRound = false;
        this.activeIterator = null;
    }

    /** Marks the active cache iterator as finished so a new one may be requested. */
    public void cacheIterationComplete() {
        this.activeIterator = null;
    }

    /**
     * Returns a read-only iterator over the paged terms.
     *
     * @throws IllegalStateException if an iterator is already active or the initial round has not completed.
     */
    @Override
    public Iterator<T> noWriteIterator() {
        ensureNoActiveIterator();
        if (this.initialRound) {
            throw new IllegalStateException("A full iteration must have already been completed before asking for a read-only iterator.");
        }

        this.activeIterator = getNoWriteIterator();
        return this.activeIterator;
    }

    /**
     * Returns the initial-round iterator on the first pass, and a cache iterator afterwards.
     *
     * @throws IllegalStateException if an iterator is already active.
     */
    @Override
    public Iterator<T> iterator() {
        ensureNoActiveIterator();
        this.activeIterator = this.initialRound ? getInitialRoundIterator() : getCacheIterator();
        return this.activeIterator;
    }

    /** Enforces the one-active-iterator-at-a-time rule. */
    private void ensureNoActiveIterator() {
        if (this.activeIterator != null) {
            throw new IllegalStateException("Iterator already exists for this StreamingTermStore. Exhaust the iterator first.");
        }
    }

    /**
     * Returns the store to its initial-round state: closes any active iterator, forgets variables
     * and cached terms, resets the counters, and deletes the page directory.
     */
    @Override
    public void clear() {
        this.initialRound = true;
        this.numPages = 0;
        // A cleared store has not seen any terms yet; matches a freshly constructed store.
        this.seenTermCount = 0;

        if (this.activeIterator != null) {
            this.activeIterator.close();
            this.activeIterator = null;
        }

        if (this.variables != null) {
            this.variables.clear();
        }

        if (this.termCache != null) {
            this.termCache.clear();
        }

        if (this.termPool != null) {
            this.termPool.clear();
        }

        SystemUtils.recursiveDelete(this.pageDir);
    }

    /** Clears the store and drops references to its collections and buffers. */
    @Override
    public void close() {
        clear();

        if (this.termBuffer != null) {
            this.termBuffer.clear();
        }

        if (this.volatileBuffer != null) {
            this.volatileBuffer.clear();
        }

        this.variables = null;
        this.termBuffer = null;
        this.volatileBuffer = null;
        this.termCache = null;
        this.termPool = null;
    }

    /** Returns true if this concrete store can generate terms for {@code rule}. Called from the constructor. */
    protected abstract boolean supportsRule(Rule rule);

    /** Creates the iterator used for the first full pass over the rules. */
    protected abstract StreamingIterator<T> getInitialRoundIterator();

    /** Creates the iterator used for passes after the initial round. */
    protected abstract StreamingIterator<T> getCacheIterator();

    /** Creates the iterator used for read-only passes after the initial round. */
    protected abstract StreamingIterator<T> getNoWriteIterator();
}
