package org.linqs.psl.application.learning.weight;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import org.linqs.psl.database.Database;
import org.linqs.psl.database.atom.PersistedAtomManager;
import org.linqs.psl.model.atom.GroundAtom;
import org.linqs.psl.model.atom.ObservedAtom;
import org.linqs.psl.model.atom.RandomVariableAtom;
import org.linqs.psl.model.predicate.FunctionalPredicate;
import org.linqs.psl.model.predicate.StandardPredicate;
import org.linqs.psl.util.IteratorUtils;
import org.linqs.psl.util.Logger;

/* loaded from: input_file:org/linqs/psl/application/learning/weight/TrainingMap.class */
/**
 * Pairs the atoms cached in a target database (reached through a
 * {@link PersistedAtomManager}) with the atoms cached in a second database
 * that holds the corresponding truth values.
 *
 * <p>Every non-functional target atom ends up in exactly one of four buckets,
 * unless its truth counterpart exists but is not an {@link ObservedAtom}, in
 * which case the target is skipped entirely:
 * <ul>
 *   <li><b>label map</b>: a {@link RandomVariableAtom} target with an observed truth atom;</li>
 *   <li><b>observed map</b>: an {@link ObservedAtom} target with an observed truth atom;</li>
 *   <li><b>latent variables</b>: a {@link RandomVariableAtom} target with no cached truth atom;</li>
 *   <li><b>missing labels</b>: an {@link ObservedAtom} target with no cached truth atom.</li>
 * </ul>
 * Observed truth atoms that were not matched to any target are collected as
 * <b>missing targets</b>.
 */
public class TrainingMap {
    private static final Logger log = Logger.getLogger(TrainingMap.class);

    /** Random variable targets mapped to their observed truth atom. */
    private final Map<RandomVariableAtom, ObservedAtom> labelMap;

    /** Observed targets mapped to their observed truth atom. */
    private final Map<ObservedAtom, ObservedAtom> observedMap = new HashMap<>();

    /** Random variable targets that have no cached truth atom. */
    private final List<RandomVariableAtom> latentVariables = new ArrayList<>();

    /** Observed targets that have no cached truth atom. */
    private final List<ObservedAtom> missingLabels = new ArrayList<>();

    /** Observed truth atoms that were not matched to any target. */
    private final List<ObservedAtom> missingTargets = new ArrayList<>();

    /**
     * Builds the mapping by walking the atoms cached in the target database and
     * then the atoms cached in the truth database.
     *
     * @param persistedAtomManager manager whose database supplies the target atoms;
     *        its persisted count is used as the initial capacity of the label map
     * @param database the database supplying the truth atoms
     * @throws IllegalStateException if an unmatched observed truth atom is reported
     *         by {@code hasAtom} on the target database (it exists there but was
     *         not among the cached target atoms)
     */
    public TrainingMap(PersistedAtomManager persistedAtomManager, Database database) {
        this.labelMap = new HashMap<>(persistedAtomManager.getPersistedCount());

        // Truth atoms that have been paired with a target; used below to find missing targets.
        HashSet<ObservedAtom> matchedTruthAtoms = new HashSet<>();

        // Populate the truth database's cache before it is queried with hasCachedAtom().
        prefetchTruthAtoms(database);

        for (GroundAtom targetAtom : persistedAtomManager.getDatabase().getAllCachedAtoms()) {
            if (targetAtom.getPredicate() instanceof FunctionalPredicate) {
                continue;
            }

            StandardPredicate predicate = (StandardPredicate) targetAtom.getPredicate();

            // Only look at truth atoms that are already cached.
            // NOTE(review): the (false, false, -1.0d) arguments are passed through unchanged;
            // their exact meaning is defined by Database.getAtom -- confirm there.
            GroundAtom truthAtom = null;
            if (database.hasCachedAtom(predicate, targetAtom.getArguments())) {
                truthAtom = database.getAtom(predicate, false, false, -1.0d, targetAtom.getArguments());
            }

            // A truth atom that exists but is not observed is ignored, together with its target.
            if (truthAtom != null && !(truthAtom instanceof ObservedAtom)) {
                continue;
            }

            if (targetAtom instanceof RandomVariableAtom) {
                RandomVariableAtom variable = (RandomVariableAtom) targetAtom;
                if (truthAtom == null) {
                    latentVariables.add(variable);
                } else {
                    ObservedAtom truth = (ObservedAtom) truthAtom;
                    matchedTruthAtoms.add(truth);
                    labelMap.put(variable, truth);
                }
            } else {
                ObservedAtom observedTarget = (ObservedAtom) targetAtom;
                if (truthAtom == null) {
                    missingLabels.add(observedTarget);
                } else {
                    ObservedAtom truth = (ObservedAtom) truthAtom;
                    matchedTruthAtoms.add(truth);
                    observedMap.put(observedTarget, truth);
                }
            }
        }

        // Any observed truth atom that was never paired with a target is a missing target.
        for (GroundAtom truthAtom : database.getAllCachedAtoms()) {
            if (!(truthAtom instanceof ObservedAtom) || matchedTruthAtoms.contains(truthAtom)) {
                continue;
            }

            // The target database knows this atom, yet it was not among the cached targets.
            if (persistedAtomManager.getDatabase().hasAtom(
                    (StandardPredicate) truthAtom.getPredicate(), truthAtom.getArguments())) {
                throw new IllegalStateException("Un-persisted target atom: " + truthAtom);
            }

            missingTargets.add((ObservedAtom) truthAtom);
        }

        if (!missingTargets.isEmpty()) {
            log.warn("Found {} missing targets (truth atoms without a matching target). Example: {}.",
                    missingTargets.size(), missingTargets.get(0));
        }
    }

    /**
     * @return an unmodifiable view of the random variable target to truth atom mapping
     */
    public Map<RandomVariableAtom, ObservedAtom> getLabelMap() {
        return Collections.unmodifiableMap(labelMap);
    }

    /**
     * @return an unmodifiable view of the observed target to truth atom mapping
     */
    public Map<ObservedAtom, ObservedAtom> getObservedMap() {
        return Collections.unmodifiableMap(observedMap);
    }

    /**
     * @return an unmodifiable view of the random variable targets without a truth atom
     */
    public List<RandomVariableAtom> getLatentVariables() {
        return Collections.unmodifiableList(latentVariables);
    }

    /**
     * @return an unmodifiable view of the observed targets without a truth atom
     */
    public List<ObservedAtom> getMissingLabels() {
        return Collections.unmodifiableList(missingLabels);
    }

    /**
     * @return an unmodifiable view of the observed truth atoms without a matching target
     */
    public List<ObservedAtom> getMissingTargets() {
        return Collections.unmodifiableList(missingTargets);
    }

    /**
     * @return every random variable target: the label map keys followed by the latent variables
     */
    public Iterable<RandomVariableAtom> getAllPredictions() {
        return IteratorUtils.join(labelMap.keySet(), latentVariables);
    }

    /**
     * @return every target atom: label map keys, observed map keys, latent variables,
     *         and missing labels, joined in that order
     */
    public Iterable<GroundAtom> getAllTargets() {
        return IteratorUtils.join(labelMap.keySet(), observedMap.keySet(), latentVariables, missingLabels);
    }

    /**
     * @return every truth atom: label map values, observed map values, and missing targets,
     *         joined in that order
     */
    public Iterable<GroundAtom> getAllTruths() {
        return IteratorUtils.join(labelMap.values(), observedMap.values(), missingTargets);
    }

    /**
     * Registers a random variable target after construction.
     *
     * <p>If an element equal to the atom (per {@link List#indexOf(Object)}) is found in the
     * missing targets, that truth atom is removed from the missing targets and becomes the
     * atom's label. Otherwise the atom is recorded as a latent variable, replacing an equal
     * entry if one is already present.
     *
     * @param randomVariableAtom the target atom to add
     */
    public void addRandomVariableTargetAtom(RandomVariableAtom randomVariableAtom) {
        int truthIndex = missingTargets.indexOf(randomVariableAtom);
        if (truthIndex != -1) {
            // remove(int) is intended here: it returns the truth atom that was at that position.
            labelMap.put(randomVariableAtom, missingTargets.remove(truthIndex));
            return;
        }

        int latentIndex = latentVariables.indexOf(randomVariableAtom);
        if (latentIndex == -1) {
            latentVariables.add(randomVariableAtom);
        } else {
            latentVariables.set(latentIndex, randomVariableAtom);
        }
    }

    /**
     * Removes an atom from every collection that can hold an atom of its kind.
     * A {@link RandomVariableAtom} is removed from the label map keys and latent variables;
     * any other atom is cast to {@link ObservedAtom} and removed from the observed map keys,
     * missing labels, and missing targets.
     *
     * @param groundAtom the atom to remove
     */
    public void deleteAtom(GroundAtom groundAtom) {
        if (groundAtom instanceof RandomVariableAtom) {
            RandomVariableAtom variable = (RandomVariableAtom) groundAtom;
            labelMap.remove(variable);
            latentVariables.remove(variable);
            return;
        }

        ObservedAtom observed = (ObservedAtom) groundAtom;
        observedMap.remove(observed);
        missingLabels.remove(observed);
        missingTargets.remove(observed);
    }

    /**
     * Returns the target/truth pairs of the label map followed by those of the observed map.
     * Latent variables, missing labels, and missing targets are not part of either map and
     * so are not included.
     *
     * @return the joined entries of the label map and the observed map
     */
    public Iterable<Map.Entry<GroundAtom, GroundAtom>> getFullMap() {
        // NOTE(review): the entry sets are typed Map.Entry<RandomVariableAtom, ObservedAtom> and
        // Map.Entry<ObservedAtom, ObservedAtom>; whether this needs an explicit cast depends on
        // the signature of IteratorUtils.join -- confirm against that class.
        return IteratorUtils.join(this.labelMap.entrySet(), this.observedMap.entrySet());
    }

    /**
     * @return a one-line summary with the size of each of the five collections
     */
    @Override
    public String toString() {
        return String.format(
                "Training Map -- Label Map: %d, Observed Map: %d, Latent Variables: %d, Missing Labels: %d, Missing Targets: %d",
                labelMap.size(), observedMap.size(), latentVariables.size(),
                missingLabels.size(), missingTargets.size());
    }

    /**
     * Requests all ground atoms of every predicate registered with the database's data store.
     * The returned atoms are discarded; the call is made before the constructor starts
     * querying the database with {@code hasCachedAtom}.
     *
     * @param database the truth database to prefetch from
     */
    private void prefetchTruthAtoms(Database database) {
        for (StandardPredicate predicate : database.getDataStore().getRegisteredPredicates()) {
            database.getAllGroundAtoms(predicate);
        }
    }
}
