package org.linqs.psl.grounding.collective;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.linqs.psl.model.rule.Rule;
import org.linqs.psl.util.MathUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/linqs/psl/grounding/collective/Coverage.class */
/**
 * Selects a set of {@link CandidateQuery} instances that, together, cover every given collective rule.
 *
 * <p>Non-instantiable utility class; the entry point is {@link #compute(List, Set)}.
 */
public class Coverage {
    private static final Logger log = LoggerFactory.getLogger(Coverage.class);

    // Static utility; never instantiated.
    private Coverage() {
    }

    /**
     * Compute a set of candidate queries whose covered rules include every rule in {@code collectiveRules}.
     *
     * <p>First calls {@code Containment.computeContainement(collectiveRules, candidates)}
     * (presumably populating each candidate's covered rules -- confirm against {@code Containment}),
     * then selects a cover with a greedy, score-aware strategy (lower scores are preferred).
     *
     * @param collectiveRules the rules that must all be covered.
     * @param candidates the candidate queries to choose from.
     * @return the selected candidates.
     * @throws IllegalStateException if the candidates cannot cover every rule.
     */
    public static Set<CandidateQuery> compute(List<Rule> collectiveRules, Set<CandidateQuery> candidates) {
        Containment.computeContainement(collectiveRules, candidates);

        Set<CandidateQuery> coverage = greedySmartCoverage(collectiveRules, candidates);
        if (coverage == null) {
            throw new IllegalStateException(String.format(
                    "Could not compute coverage. Collective Rules: %s, Candidates: %s.",
                    collectiveRules, candidates));
        }

        return coverage;
    }

    /**
     * Greedily build a cover, weighing each candidate's score against the best score available
     * for the rules it would newly cover.
     *
     * <p>For every rule, the lowest score among all candidates covering it is recorded. On each
     * round, every unselected candidate that covers at least one not-yet-covered rule gets an
     * adjusted score: its own score minus the recorded best score of each rule it would newly cover.
     * The candidate with the lowest adjusted score is selected; on ties the first one seen in
     * iteration order wins.
     *
     * @param rules the rules that must all be covered.
     * @param candidates the candidate queries to choose from.
     * @return the selected candidates, or null if at some point no remaining candidate covers
     *     any still-uncovered rule (i.e. full coverage is impossible).
     */
    private static Set<CandidateQuery> greedySmartCoverage(List<Rule> rules, Set<CandidateQuery> candidates) {
        // Lowest score of any candidate that covers each rule.
        Map<Rule, Double> bestScores = new HashMap<>();
        for (CandidateQuery candidate : candidates) {
            for (Rule rule : candidate.getCoveredRules()) {
                Double best = bestScores.get(rule);
                if (best == null || candidate.getScore() < best.doubleValue()) {
                    bestScores.put(rule, Double.valueOf(candidate.getScore()));
                }
            }
        }

        Set<CandidateQuery> selected = new HashSet<>();
        Set<Rule> coveredRules = new HashSet<>();

        // Test membership rather than comparing sizes: candidates may cover rules outside the
        // requested list (or the list may hold duplicates), which would make a size check unreliable.
        while (!coveredRules.containsAll(rules)) {
            CandidateQuery bestCandidate = null;
            double bestAdjustedScore = 0.0;

            for (CandidateQuery candidate : candidates) {
                if (selected.contains(candidate)) {
                    continue;
                }

                double adjustedScore = candidate.getScore();
                boolean coversNewRule = false;
                for (Rule rule : candidate.getCoveredRules()) {
                    if (!coveredRules.contains(rule)) {
                        // Every covered rule was registered in bestScores above, so this is non-null.
                        adjustedScore -= bestScores.get(rule).doubleValue();
                        coversNewRule = true;
                    }
                }

                if (coversNewRule && (bestCandidate == null || adjustedScore < bestAdjustedScore)) {
                    bestCandidate = candidate;
                    bestAdjustedScore = adjustedScore;
                }
            }

            // No remaining candidate makes progress, so the rules cannot be fully covered.
            // Returning null (instead of dereferencing a null candidate) lets the caller report it.
            if (bestCandidate == null) {
                return null;
            }

            selected.add(bestCandidate);
            coveredRules.addAll(bestCandidate.getCoveredRules());
        }

        return selected;
    }

    /**
     * Greedily build a cover by taking candidates in ascending score order, skipping any whose
     * covered rules are already all covered.
     *
     * <p>Not currently used by {@link #compute(List, Set)}; kept as a simpler alternative strategy.
     *
     * @param rules the rules that must all be covered.
     * @param candidates the candidate queries to choose from.
     * @return the selected candidates, or null if the candidates cannot cover every rule.
     */
    private static Set<CandidateQuery> greedyNaiveCoverage(List<Rule> rules, Set<CandidateQuery> candidates) {
        List<CandidateQuery> sortedCandidates = new ArrayList<>(candidates);
        Collections.sort(sortedCandidates, new Comparator<CandidateQuery>() {
            @Override
            public int compare(CandidateQuery a, CandidateQuery b) {
                return MathUtils.compare(a.getScore(), b.getScore());
            }
        });

        Set<CandidateQuery> selected = new HashSet<>();
        Set<Rule> coveredRules = new HashSet<>();

        for (CandidateQuery candidate : sortedCandidates) {
            if (coveredRules.containsAll(rules)) {
                return selected;
            }

            if (!coveredRules.containsAll(candidate.getCoveredRules())) {
                selected.add(candidate);
                coveredRules.addAll(candidate.getCoveredRules());
            }
        }

        return coveredRules.containsAll(rules) ? selected : null;
    }
}
