package org.linqs.psl.evaluation.statistics;

import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Map;
import java.util.Set;
import org.linqs.psl.application.learning.weight.TrainingMap;
import org.linqs.psl.config.Options;
import org.linqs.psl.model.atom.GroundAtom;
import org.linqs.psl.model.predicate.StandardPredicate;
import org.linqs.psl.model.term.Constant;
import org.linqs.psl.util.Logger;
import org.linqs.psl.util.MathUtils;
import org.linqs.psl.util.RandUtils;
import org.linqs.psl.util.StringUtils;

/**
 * Evaluator that scores categorical predictions by accuracy.
 *
 * <p>Some arguments of the evaluated predicate are designated as "category" arguments
 * (see {@link #setVirtualCategoryIndexes(int...)}). Target atoms are grouped by their
 * remaining (non-category) arguments and, within each group, the atom with the highest
 * value is taken as the predicted category. A truth atom with value {@code >= 1.0} counts
 * as a hit when it is one of the predicted atoms and as a miss otherwise.
 */
public class CategoricalEvaluator extends Evaluator {
    private static final Logger log = Logger.getLogger(CategoricalEvaluator.class);

    /** Delimiter used when parsing the configured list of category indexes. */
    public static final String DELIM = ":";

    /**
     * Category argument positions as supplied by the caller.
     * Negative values are resolved relative to the predicate's arity at evaluation time.
     */
    private Set<Integer> virtualCategoryIndexes;
    private RepresentativeMetric representative;
    private String defaultPredicate;
    private int hits;
    private int misses;

    /** Metrics that can be reported as this evaluator's representative metric. */
    public enum RepresentativeMetric {
        ACCURACY
    }

    /**
     * Builds an evaluator from configuration: the representative metric comes from
     * {@code Options.EVAL_CAT_REPRESENTATIVE} and the category indexes from
     * {@code Options.EVAL_CAT_CATEGORY_INDEXES} (split on {@link #DELIM}).
     */
    public CategoricalEvaluator() {
        // Delegate to the String constructor so the metric name is normalized the same way
        // regardless of which constructor is used.
        this(Options.EVAL_CAT_REPRESENTATIVE.getString(),
                StringUtils.splitInt(Options.EVAL_CAT_CATEGORY_INDEXES.getString(), DELIM));
    }

    /**
     * Builds an evaluator with the configured representative metric and explicit category indexes.
     *
     * @param iArr the category argument indexes; must be non-empty
     */
    public CategoricalEvaluator(int... iArr) {
        this(Options.EVAL_CAT_REPRESENTATIVE.getString(), iArr);
    }

    /**
     * Builds an evaluator from a metric name (matched case-insensitively against
     * {@link RepresentativeMetric}) and explicit category indexes.
     *
     * @param str the name of the representative metric
     * @param iArr the category argument indexes; must be non-empty
     * @throws IllegalArgumentException if the name matches no metric or no indexes are given
     */
    public CategoricalEvaluator(String str, int... iArr) {
        this(RepresentativeMetric.valueOf(str.toUpperCase()), iArr);
    }

    /**
     * Builds an evaluator with an explicit representative metric and category indexes.
     * The default predicate is read from {@code Options.EVAL_CAT_DEFAULT_PREDICATE}.
     *
     * @param representativeMetric the metric returned by {@link #getRepMetric()}
     * @param iArr the category argument indexes; must be non-empty
     * @throws IllegalArgumentException if no indexes are given
     */
    public CategoricalEvaluator(RepresentativeMetric representativeMetric, int... iArr) {
        this.representative = representativeMetric;
        setVirtualCategoryIndexes(iArr);
        this.defaultPredicate = Options.EVAL_CAT_DEFAULT_PREDICATE.getString();
        this.hits = 0;
        this.misses = 0;
    }

    /**
     * Replaces the set of category argument indexes.
     * Negative indexes are allowed; they are offset by the predicate's arity when resolved.
     *
     * @param iArr the category argument indexes
     * @throws IllegalArgumentException if {@code iArr} is null or empty
     */
    public void setVirtualCategoryIndexes(int... iArr) {
        if (iArr == null || iArr.length == 0) {
            throw new IllegalArgumentException("Found no category indexes.");
        }

        this.virtualCategoryIndexes = new HashSet<Integer>(iArr.length);
        for (int index : iArr) {
            this.virtualCategoryIndexes.add(Integer.valueOf(index));
        }

        log.debug("Virtual category indexes: [{}].", StringUtils.join(", ", this.virtualCategoryIndexes.toArray()));
    }

    /**
     * Computes statistics for the configured default predicate.
     *
     * @throws UnsupportedOperationException if no default predicate is configured
     */
    @Override
    public void compute(TrainingMap trainingMap) {
        if (this.defaultPredicate == null) {
            throw new UnsupportedOperationException("CategoricalEvaluators must have a default predicate set (through config).");
        }

        compute(trainingMap, StandardPredicate.get(this.defaultPredicate));
    }

    /**
     * Recomputes the hit/miss counts for the given predicate.
     * Only truth atoms of that predicate with value {@code >= 1.0} are counted.
     */
    @Override
    public void compute(TrainingMap trainingMap, StandardPredicate standardPredicate) {
        assert standardPredicate != null;

        this.hits = 0;
        this.misses = 0;

        Set<GroundAtom> predictedCategories = getPredictedCategories(trainingMap, standardPredicate);

        for (GroundAtom truthAtom : trainingMap.getAllTruths()) {
            if (truthAtom.getPredicate() != standardPredicate || truthAtom.getValue() < 1.0d) {
                continue;
            }

            if (predictedCategories.contains(truthAtom)) {
                this.hits++;
            } else {
                this.misses++;
            }
        }
    }

    /** Returns the value of the configured representative metric. */
    @Override
    public double getRepMetric() {
        switch (this.representative) {
            case ACCURACY:
                return accuracy();
            default:
                throw new IllegalStateException("Unknown representative metric: " + this.representative);
        }
    }

    /** Returns the best achievable value of the configured representative metric. */
    @Override
    public double getBestRepScore() {
        switch (this.representative) {
            case ACCURACY:
                return 1.0d;
            default:
                throw new IllegalStateException("Unknown representative metric: " + this.representative);
        }
    }

    @Override
    public boolean isHigherRepBetter() {
        return true;
    }

    /**
     * Returns the fraction of counted truth atoms that were predicted, in {@code [0, 1]}.
     *
     * @return {@code hits / (hits + misses)}, or {@code 0.0} when nothing has been counted
     */
    public double accuracy() {
        int total = this.hits + this.misses;
        if (total == 0) {
            return 0.0d;
        }

        // Promote to double before dividing; int / int would truncate the ratio to 0 or 1.
        return this.hits / (double) total;
    }

    @Override
    public String getAllStats() {
        return String.format("Categorical Accuracy: %f", Double.valueOf(accuracy()));
    }

    /**
     * Resolves the virtual category indexes against a concrete predicate:
     * negative indexes are offset by the predicate's arity, then range-checked.
     *
     * @throws RuntimeException if a resolved index falls outside {@code [0, arity)}
     */
    private Set<Integer> getTrueCategoryIndexes(StandardPredicate standardPredicate) {
        int arity = standardPredicate.getArity();
        Set<Integer> trueIndexes = new HashSet<Integer>();

        for (Integer virtualIndex : this.virtualCategoryIndexes) {
            int index = virtualIndex.intValue();
            if (index < 0) {
                index += arity;
            }

            if (index < 0 || index >= arity) {
                throw new RuntimeException(String.format("Categorical index (%d) out of bounds for %s/%d.", Integer.valueOf(index), standardPredicate.getName(), Integer.valueOf(arity)));
            }

            trueIndexes.add(Integer.valueOf(index));
        }

        log.trace("True category indexes for {}: [{}].", standardPredicate.getName(), StringUtils.join(", ", trueIndexes.toArray()));
        return trueIndexes;
    }

    /**
     * Picks, for every distinct combination of non-category arguments, the target atom
     * of the given predicate with the highest value.
     *
     * @return the predicted atoms; empty when there are no targets for the predicate
     */
    protected Set<GroundAtom> getPredictedCategories(TrainingMap trainingMap, StandardPredicate standardPredicate) {
        Set<Integer> categoryIndexes = getTrueCategoryIndexes(standardPredicate);

        // Root of a tree keyed by successive non-category arguments. Interior nodes are
        // Map<Constant, Object>; leaves are the current best GroundAtom for that key path.
        // The root is itself a leaf when every argument is a category argument, and stays
        // null when no target atom matches the predicate.
        Object root = null;

        for (GroundAtom atom : getTargets(trainingMap)) {
            if (atom.getPredicate() == standardPredicate) {
                root = putPredictedCategories(root, atom, 0, categoryIndexes);
            }
        }

        Set<GroundAtom> predictions = new HashSet<GroundAtom>();
        collectPredictedCategories(root, predictions);
        return predictions;
    }

    /**
     * Inserts an atom into the (sub)tree rooted at {@code obj}, starting at argument {@code i}.
     *
     * @param obj the existing node for this position: null, a {@code Map<Constant, Object>}, or a GroundAtom leaf
     * @param groundAtom the atom being inserted
     * @param i the argument index currently being processed
     * @param set the resolved category argument indexes, which are skipped as tree keys
     * @return the node that should occupy this position after insertion
     */
    private Object putPredictedCategories(Object obj, GroundAtom groundAtom, int i, Set<Integer> set) {
        assert i <= groundAtom.getArity();

        // Category arguments do not contribute to the grouping key.
        if (set.contains(Integer.valueOf(i))) {
            return putPredictedCategories(obj, groundAtom, i + 1, set);
        }

        if (i != groundAtom.getArity()) {
            @SuppressWarnings("unchecked")
            Map<Constant, Object> children = (obj == null) ? new HashMap<Constant, Object>() : (Map<Constant, Object>) obj;

            Constant key = groundAtom.getArguments()[i];
            children.put(key, putPredictedCategories(children.get(key), groundAtom, i + 1, set));
            return children;
        }

        // All arguments consumed: this position holds the best atom seen so far.
        if (obj == null) {
            return groundAtom;
        }

        GroundAtom currentBest = (GroundAtom) obj;
        if (groundAtom.getValue() > currentBest.getValue()) {
            return groundAtom;
        }

        // Break ties randomly; the random draw only happens when the values are equal.
        if (MathUtils.equals(groundAtom.getValue(), currentBest.getValue()) && RandUtils.nextBoolean()) {
            return groundAtom;
        }

        return currentBest;
    }

    /**
     * Gathers every GroundAtom leaf reachable from {@code node} into {@code set}.
     *
     * @param node null (nothing to collect), a GroundAtom leaf, or a map of child nodes
     * @param set the accumulator for the collected atoms
     */
    private void collectPredictedCategories(Object node, Set<GroundAtom> set) {
        if (node == null) {
            return;
        }

        if (node instanceof GroundAtom) {
            set.add((GroundAtom) node);
            return;
        }

        for (Object child : ((Map<?, ?>) node).values()) {
            collectPredictedCategories(child, set);
        }
    }
}
