package org.linqs.psl.application.learning.weight;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.linqs.psl.database.Database;
import org.linqs.psl.database.atom.PersistedAtomManager;
import org.linqs.psl.model.atom.GroundAtom;
import org.linqs.psl.model.atom.ObservedAtom;
import org.linqs.psl.model.atom.RandomVariableAtom;
import org.linqs.psl.model.predicate.StandardPredicate;
import org.linqs.psl.util.IteratorUtils;

/* loaded from: input_file:org/linqs/psl/application/learning/weight/TrainingMap.class */
/**
 * Pairs the atoms cached in a target database with the observed atoms cached in a
 * truth database, and records every atom on either side that has no counterpart.
 *
 * <p>All collections are built once in the constructor and exposed only through
 * unmodifiable views.
 */
public class TrainingMap {
    /** Random-variable target atoms mapped to the observed truth atom found for them. */
    private final Map<RandomVariableAtom, ObservedAtom> labelMap;

    /** Non-random-variable (observed) target atoms mapped to the observed truth atom found for them. */
    private final Map<ObservedAtom, ObservedAtom> observedMap;

    /** Random-variable target atoms for which no truth atom was found. */
    private final List<RandomVariableAtom> latentVariables;

    /** Observed target atoms for which no truth atom was found. */
    private final List<ObservedAtom> missingLabels;

    /** Observed truth atoms that were not matched to any target atom. */
    private final List<ObservedAtom> missingTargets;

    /**
     * Builds the mapping between cached target atoms and cached truth atoms.
     *
     * @param persistedAtomManager atom manager whose database supplies the target atoms
     * @param database the truth database that is searched for matching observed atoms
     * @throws IllegalStateException if an unmatched observed truth atom is nevertheless
     *         reported as present by the target database
     */
    public TrainingMap(PersistedAtomManager persistedAtomManager, Database database) {
        Map<RandomVariableAtom, ObservedAtom> labels =
                new HashMap<RandomVariableAtom, ObservedAtom>(persistedAtomManager.getPersistedCount());
        Map<ObservedAtom, ObservedAtom> observed = new HashMap<ObservedAtom, ObservedAtom>();
        List<RandomVariableAtom> latent = new ArrayList<RandomVariableAtom>();
        List<ObservedAtom> unlabeledTargets = new ArrayList<ObservedAtom>();
        List<ObservedAtom> untargetedTruths = new ArrayList<ObservedAtom>();
        // Truth atoms that were paired with a target in the first pass.
        Set<ObservedAtom> matchedTruths = new HashSet<ObservedAtom>();

        prefetchTruthAtoms(database);

        // Pass 1: look up a truth atom for every atom cached in the target database.
        for (GroundAtom targetAtom : persistedAtomManager.getDatabase().getAllCachedAtoms()) {
            StandardPredicate predicate = (StandardPredicate) targetAtom.getPredicate();

            // Only fetch truth atoms that are already cached.
            // NOTE(review): the 'false' flag presumably prevents getAtom from creating a missing atom -- confirm.
            GroundAtom candidate = null;
            if (database.hasCachedAtom(predicate, targetAtom.getArguments())) {
                candidate = database.getAtom(predicate, false, targetAtom.getArguments());
            }

            // A truth atom that exists but is not observed is ignored, and its target is skipped entirely.
            if (candidate != null && !(candidate instanceof ObservedAtom)) {
                continue;
            }
            ObservedAtom truthAtom = (ObservedAtom) candidate;

            if (targetAtom instanceof RandomVariableAtom) {
                RandomVariableAtom variable = (RandomVariableAtom) targetAtom;
                if (truthAtom == null) {
                    latent.add(variable);
                } else {
                    matchedTruths.add(truthAtom);
                    labels.put(variable, truthAtom);
                }
            } else {
                // Every target atom that is not a random variable is treated as observed.
                ObservedAtom observedTarget = (ObservedAtom) targetAtom;
                if (truthAtom == null) {
                    unlabeledTargets.add(observedTarget);
                } else {
                    matchedTruths.add(truthAtom);
                    observed.put(observedTarget, truthAtom);
                }
            }
        }

        // Pass 2: collect the observed truth atoms that no target atom claimed.
        for (GroundAtom truthAtom : database.getAllCachedAtoms()) {
            if (!(truthAtom instanceof ObservedAtom) || matchedTruths.contains(truthAtom)) {
                continue;
            }

            // The target database knows this atom, yet pass 1 did not pair it with a cached target.
            if (persistedAtomManager.getDatabase().hasAtom((StandardPredicate) truthAtom.getPredicate(), truthAtom.getArguments())) {
                throw new IllegalStateException("Un-persisted target atom: " + truthAtom);
            }

            untargetedTruths.add((ObservedAtom) truthAtom);
        }

        this.labelMap = Collections.unmodifiableMap(labels);
        this.observedMap = Collections.unmodifiableMap(observed);
        this.latentVariables = Collections.unmodifiableList(latent);
        this.missingLabels = Collections.unmodifiableList(unlabeledTargets);
        this.missingTargets = Collections.unmodifiableList(untargetedTruths);
    }

    /**
     * @return unmodifiable map from random-variable target atoms to their observed truth atoms
     */
    public Map<RandomVariableAtom, ObservedAtom> getLabelMap() {
        return this.labelMap;
    }

    /**
     * @return unmodifiable map from observed target atoms to their observed truth atoms
     */
    public Map<ObservedAtom, ObservedAtom> getObservedMap() {
        return this.observedMap;
    }

    /**
     * @return unmodifiable list of random-variable target atoms that have no truth atom
     */
    public List<RandomVariableAtom> getLatentVariables() {
        return this.latentVariables;
    }

    /**
     * @return unmodifiable list of observed target atoms that have no truth atom
     */
    public List<ObservedAtom> getMissingLabels() {
        return this.missingLabels;
    }

    /**
     * @return unmodifiable list of observed truth atoms that have no target atom
     */
    public List<ObservedAtom> getMissingTargets() {
        return this.missingTargets;
    }

    /**
     * @return every random-variable target atom: the label-map keys followed by the latent variables
     */
    public Iterable<RandomVariableAtom> getAllPredictions() {
        return IteratorUtils.join(this.labelMap.keySet(), this.latentVariables);
    }

    /**
     * @return every target atom: label-map keys, observed-map keys, latent variables and missing labels
     */
    public Iterable<GroundAtom> getAllTargets() {
        return IteratorUtils.join(this.labelMap.keySet(), this.observedMap.keySet(), this.latentVariables, this.missingLabels);
    }

    /**
     * @return every truth atom: label-map values, observed-map values and missing targets
     */
    public Iterable<GroundAtom> getAllTruths() {
        return IteratorUtils.join(this.labelMap.values(), this.observedMap.values(), this.missingTargets);
    }

    /**
     * @return every (target, truth) pair: the label-map entries followed by the observed-map entries
     */
    public Iterable<Map.Entry<GroundAtom, GroundAtom>> getFullMap() {
        return IteratorUtils.join(this.labelMap.entrySet(), this.observedMap.entrySet());
    }

    /**
     * @return a one-line summary giving the size of each of the five collections
     */
    @Override
    public String toString() {
        return String.format(
                "Training Map -- Label Map: %d, Observed Map: %d, Latent Variables: %d, Missing Labels: %d, Missing Targets: %d",
                this.labelMap.size(),
                this.observedMap.size(),
                this.latentVariables.size(),
                this.missingLabels.size(),
                this.missingTargets.size());
    }

    /**
     * Requests all ground atoms of every predicate registered with the database's data store.
     * The returned atoms are discarded; presumably the call populates the database's atom
     * cache so the constructor's cache-only lookups can see them -- TODO confirm.
     *
     * @param database the truth database to query
     */
    private static void prefetchTruthAtoms(Database database) {
        for (StandardPredicate predicate : database.getDataStore().getRegisteredPredicates()) {
            database.getAllGroundAtoms(predicate);
        }
    }
}
