package org.linqs.psl.database;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import org.linqs.psl.database.atom.AtomCache;
import org.linqs.psl.model.atom.GroundAtom;
import org.linqs.psl.model.atom.ObservedAtom;
import org.linqs.psl.model.atom.QueryAtom;
import org.linqs.psl.model.atom.RandomVariableAtom;
import org.linqs.psl.model.predicate.StandardPredicate;
import org.linqs.psl.model.term.Constant;
import org.linqs.psl.util.IteratorUtils;
import org.linqs.psl.util.Parallel;

/* loaded from: input_file:org/linqs/psl/database/Database.class */
/**
 * Base class for databases: a view over one write partition plus zero or more read partitions
 * belonging to a {@link DataStore}.
 *
 * <p>The write partition's ID is never part of {@link #readIDs}; {@link #allPartitionIDs} is the
 * read IDs followed by the write ID. Each database owns an {@link AtomCache} built in the
 * constructor.
 */
public abstract class Database implements ReadableDatabase, WritableDatabase {
    /** Key under which each thread's reusable {@link QueryAtom} is stored via {@link Parallel}. */
    private static final String THREAD_QUERY_ATOM_KEY = Database.class.getName() + "::" + QueryAtom.class.getName();

    protected final DataStore parentDataStore;
    protected final Partition writePartition;
    protected final int writeID;
    /** The read partitions exactly as supplied to the constructor (copied; may include the write partition). */
    protected final List<Partition> readPartitions;
    /** IDs of the read partitions, with every occurrence of the write partition's ID removed. */
    protected final List<Integer> readIDs;
    /** {@link #readIDs} followed by {@link #writeID}. */
    protected final List<Integer> allPartitionIDs;
    protected final AtomCache cache;
    protected boolean closed;

    /**
     * Creates a database over the given partitions.
     *
     * @param dataStore the owning data store
     * @param writePartition the partition this database writes to
     * @param readPartitions the partitions this database reads from; the write partition may be
     *     listed here, but its ID is excluded from {@link #readIDs}
     */
    public Database(DataStore dataStore, Partition writePartition, Partition[] readPartitions) {
        this.parentDataStore = dataStore;
        this.writePartition = writePartition;
        this.writeID = writePartition.getID();

        // Copy rather than wrap: Arrays.asList is a live view of the caller's array.
        this.readPartitions = new ArrayList<>(Arrays.asList(readPartitions));

        // Skip every occurrence of the write ID. The previous List.remove(Object) only dropped
        // the first match, leaving the write partition readable as "observed" if it was listed twice.
        this.readIDs = new ArrayList<>(readPartitions.length);
        for (Partition readPartition : readPartitions) {
            int id = readPartition.getID();
            if (id != this.writeID) {
                this.readIDs.add(id);
            }
        }

        this.allPartitionIDs = new ArrayList<>(this.readIDs.size() + 1);
        this.allPartitionIDs.addAll(this.readIDs);
        this.allPartitionIDs.add(this.writeID);

        this.cache = new AtomCache(this);
    }

    /**
     * Looks up a ground atom.
     *
     * <p>NOTE(review): {@link #hasAtom} passes {@code false} for the flag and treats a {@code null}
     * result as "absent", which suggests {@code false} means "do not create a missing atom".
     * Confirm against the concrete implementations.
     *
     * @param predicate the atom's predicate
     * @param create flag forwarded to the implementation (see note above)
     * @param arguments the atom's arguments
     * @return the atom, or possibly {@code null} when the flag is {@code false} and no atom exists
     */
    public abstract GroundAtom getAtom(StandardPredicate predicate, boolean create, Constant... arguments);

    /** Returns {@code true} if {@code getAtom(predicate, false, arguments)} returns a non-null atom. */
    @Override
    public boolean hasAtom(StandardPredicate predicate, Constant... arguments) {
        return getAtom(predicate, false, arguments) != null;
    }

    /**
     * Checks the atom cache for the given atom.
     *
     * <p>A per-thread {@link QueryAtom} is reused (re-targeted via {@code assume}) rather than
     * allocating a new one on every call after the first on a thread.
     */
    public boolean hasCachedAtom(StandardPredicate predicate, Constant... arguments) {
        QueryAtom queryAtom;
        if (Parallel.hasThreadObject(THREAD_QUERY_ATOM_KEY)) {
            queryAtom = (QueryAtom) Parallel.getThreadObject(THREAD_QUERY_ATOM_KEY);
            queryAtom.assume(predicate, arguments);
        } else {
            queryAtom = new QueryAtom(predicate, arguments);
            Parallel.putThreadObject(THREAD_QUERY_ATOM_KEY, queryAtom);
        }
        return hasCachedAtom(queryAtom);
    }

    /** Returns {@code true} if the atom cache holds an atom for {@code queryAtom}. */
    public boolean hasCachedAtom(QueryAtom queryAtom) {
        return this.cache.getCachedAtom(queryAtom) != null;
    }

    /** Counts ground atoms of {@code predicate} across the read partitions and the write partition. */
    @Override
    public int countAllGroundAtoms(StandardPredicate predicate) {
        return countAllGroundAtoms(predicate, this.allPartitionIDs);
    }

    /** Counts ground atoms of {@code predicate} stored in the partitions with the given IDs. */
    public abstract int countAllGroundAtoms(StandardPredicate predicate, List<Integer> partitionIDs);

    /**
     * Counts ground atoms of {@code predicate} in the write partition only.
     *
     * @return 0 if the predicate is closed
     */
    @Override
    public int countAllGroundRandomVariableAtoms(StandardPredicate predicate) {
        if (isClosed(predicate)) {
            return 0;
        }
        return countAllGroundAtoms(predicate, writePartitionIDs());
    }

    /** Returns every atom currently held in the atom cache. */
    @Override
    public Iterable<GroundAtom> getAllCachedAtoms() {
        return this.cache.getCachedAtoms();
    }

    /** Returns every random variable atom currently held in the atom cache. */
    @Override
    public Iterable<RandomVariableAtom> getAllCachedRandomVariableAtoms() {
        return this.cache.getCachedRandomVariableAtoms();
    }

    /** Returns ground atoms of {@code predicate} across the read partitions and the write partition. */
    @Override
    public List<GroundAtom> getAllGroundAtoms(StandardPredicate predicate) {
        return getAllGroundAtoms(predicate, this.allPartitionIDs);
    }

    /** Returns ground atoms of {@code predicate} stored in the partitions with the given IDs. */
    public abstract List<GroundAtom> getAllGroundAtoms(StandardPredicate predicate, List<Integer> partitionIDs);

    /**
     * Returns the ground atoms of {@code predicate} in the write partition as random variable atoms.
     *
     * @return a new mutable list; empty if the predicate is closed
     * @throws IllegalStateException if an {@link ObservedAtom} is found in the write partition
     */
    @Override
    public List<RandomVariableAtom> getAllGroundRandomVariableAtoms(StandardPredicate predicate) {
        if (isClosed(predicate)) {
            return new ArrayList<>();
        }
        List<GroundAtom> atoms = getAllGroundAtoms(predicate, writePartitionIDs());
        List<RandomVariableAtom> result = new ArrayList<>(atoms.size());
        for (GroundAtom atom : atoms) {
            if (atom instanceof ObservedAtom) {
                throw new IllegalStateException(String.format("Found a ground atom (%s) that is both observed and a target. An atom can only be one at a time. Check your data files.", atom));
            }
            result.add((RandomVariableAtom) atom);
        }
        return result;
    }

    /**
     * Returns the ground atoms of {@code predicate} in the read partitions (write partition excluded).
     *
     * @return a new mutable list; empty if there are no read partitions
     * @throws IllegalStateException if a {@link RandomVariableAtom} is found in a read partition
     */
    @Override
    public List<ObservedAtom> getAllGroundObservedAtoms(StandardPredicate predicate) {
        if (this.readIDs.isEmpty()) {
            return new ArrayList<>();
        }
        List<GroundAtom> atoms = getAllGroundAtoms(predicate, this.readIDs);
        List<ObservedAtom> result = new ArrayList<>(atoms.size());
        for (GroundAtom atom : atoms) {
            if (atom instanceof RandomVariableAtom) {
                throw new IllegalStateException(String.format("Found a ground atom (%s) that is both observed and a target. An atom can only be one at a time. Check your data files.", atom));
            }
            result.add((ObservedAtom) atom);
        }
        return result;
    }

    /** Commits a single atom by delegating to the collection overload of {@code commit}. */
    @Override
    public void commit(RandomVariableAtom atom) {
        List<RandomVariableAtom> atoms = new ArrayList<>(1);
        atoms.add(atom);
        commit(atoms);
    }

    /** Commits every cached random variable atom; equivalent to {@code commitCachedAtoms(false)}. */
    @Override
    public void commitCachedAtoms() {
        commitCachedAtoms(false);
    }

    /**
     * Commits cached random variable atoms.
     *
     * @param onlyPersisted if {@code true}, commit only atoms whose {@code getPersisted()} is true
     */
    @Override
    public void commitCachedAtoms(boolean onlyPersisted) {
        Iterable<RandomVariableAtom> atoms = getAllCachedRandomVariableAtoms();
        if (onlyPersisted) {
            atoms = IteratorUtils.filter(atoms, new IteratorUtils.FilterFunction<RandomVariableAtom>() {
                @Override
                public boolean keep(RandomVariableAtom atom) {
                    return atom.getPersisted();
                }
            });
        }
        commit(atoms);
    }

    public DataStore getDataStore() {
        return this.parentDataStore;
    }

    /** Returns an unmodifiable view of the read partitions supplied at construction. */
    public List<Partition> getReadPartitions() {
        return Collections.unmodifiableList(this.readPartitions);
    }

    public Partition getWritePartition() {
        return this.writePartition;
    }

    /** Returns the number of random variable atoms held in the atom cache. */
    @Override
    public int getCachedRVACount() {
        return this.cache.getRVACount();
    }

    /** Builds a fresh mutable list containing only the write partition's ID. */
    private List<Integer> writePartitionIDs() {
        List<Integer> ids = new ArrayList<>(1);
        ids.add(this.writeID);
        return ids;
    }
}
