package org.linqs.psl.evaluation.statistics;

import java.util.AbstractMap;
import java.util.Map;
import java.util.Set;
import org.linqs.psl.application.learning.weight.TrainingMap;
import org.linqs.psl.config.Options;
import org.linqs.psl.database.Database;
import org.linqs.psl.model.atom.GroundAtom;
import org.linqs.psl.model.atom.ObservedAtom;
import org.linqs.psl.model.atom.RandomVariableAtom;
import org.linqs.psl.model.atom.UnmanagedObservedAtom;
import org.linqs.psl.model.predicate.StandardPredicate;
import org.linqs.psl.util.IteratorUtils;

/**
 * Base class for evaluation statistics computed over a {@link TrainingMap}.
 *
 * <p>Subclasses implement {@link #compute(TrainingMap, StandardPredicate)} and expose one
 * representative ("rep") metric through {@link #getRepMetric()}.
 *
 * <p>Two flags, both initialized from {@link Options} at construction time, control which atoms
 * are considered (see {@link #getMap(TrainingMap)} and {@link #getTargets(TrainingMap)}):
 * <ul>
 *   <li>{@code includeObserved} ({@code Options.EVAL_INCLUDE_OBS}): also use the training map's
 *       observed entries / all targets.</li>
 *   <li>{@code closeTruth} ({@code Options.EVAL_CLOSE_TRUTH}): pair every latent variable with a
 *       synthetic truth atom of value 0.0.</li>
 * </ul>
 */
public abstract class Evaluator {
    protected boolean includeObserved = Options.EVAL_INCLUDE_OBS.getBoolean();
    protected boolean closeTruth = Options.EVAL_CLOSE_TRUTH.getBoolean();

    /** @return whether observed entries are included in evaluation. */
    public boolean getIncludeObserved() {
        return includeObserved;
    }

    /** @param includeObserved whether observed entries should be included in evaluation. */
    public void setIncludeObserved(boolean includeObserved) {
        this.includeObserved = includeObserved;
    }

    /** @return whether latent variables are paired with a 0.0 truth value during evaluation. */
    public boolean getCloseTruth() {
        return closeTruth;
    }

    /** @param closeTruth whether latent variables should be paired with a 0.0 truth value. */
    public void setCloseTruth(boolean closeTruth) {
        this.closeTruth = closeTruth;
    }

    /**
     * Compute this evaluator's statistics over the given training map.
     * NOTE(review): presumably covers every predicate in the map; confirm in subclasses.
     */
    public abstract void compute(TrainingMap trainingMap);

    /** Compute this evaluator's statistics over the given training map, restricted to one predicate. */
    public abstract void compute(TrainingMap trainingMap, StandardPredicate standardPredicate);

    /** @return the representative metric from the most recent {@code compute} call. */
    public abstract double getRepMetric();

    /** @return {@code true} if larger values of the representative metric are better. */
    public abstract boolean isHigherRepBetter();

    /**
     * Get the representative metric, oriented so that higher is always better.
     *
     * @return {@link #getRepMetric()}, negated when {@link #isHigherRepBetter()} is {@code false}.
     */
    public double getNormalizedRepMetric() {
        return normalize(getRepMetric());
    }

    /**
     * Get the best representative score, oriented so that higher is always better.
     *
     * @return {@link #getBestRepScore()}, negated when {@link #isHigherRepBetter()} is {@code false}.
     */
    public double getNormalizedMaxRepMetric() {
        return normalize(getBestRepScore());
    }

    /** @return the best value the representative metric can take (presumably; confirm in subclasses). */
    public abstract double getBestRepScore();

    /** @return a human-readable summary of all statistics this evaluator computes. */
    public abstract String getAllStats();

    /**
     * Convenience overload: build a {@link TrainingMap} from the two databases and compute
     * statistics for one predicate.
     *
     * <p>NOTE(review): the databases are passed to {@code new TrainingMap(...)} in the given order;
     * presumably {@code database} holds predictions and {@code database2} holds truth. Confirm
     * against {@code TrainingMap}'s constructor.
     */
    public void compute(Database database, Database database2, StandardPredicate standardPredicate) {
        compute(new TrainingMap(database, database2), standardPredicate);
    }

    /**
     * Get the atom pairs this evaluator should score.
     *
     * <p>The result always contains the entries of {@code trainingMap.getLabelMap()}. When
     * {@code includeObserved} is set, the entries of {@code trainingMap.getObservedMap()} are
     * appended. When {@code closeTruth} is set, every latent variable is appended, paired with a
     * new {@link UnmanagedObservedAtom} that has the same predicate and arguments and a value of 0.0.
     *
     * @param trainingMap the source of the atom pairs.
     * @return the concatenated pairs, typed as {@code GroundAtom -> GroundAtom} entries.
     */
    // The label/observed maps are keyed by GroundAtom subtypes, and Map.Entry is invariant in its
    // type parameters, so the entry sets are widened through a raw cast.
    @SuppressWarnings({"unchecked", "rawtypes"})
    public Iterable<Map.Entry<GroundAtom, GroundAtom>> getMap(TrainingMap trainingMap) {
        // NOTE(review): the widening is only safe for reads. Calling Entry.setValue() on a returned
        // entry could store a non-ObservedAtom into the underlying typed map (heap pollution).
        Iterable<Map.Entry<GroundAtom, GroundAtom>> mapping =
                (Iterable) trainingMap.getLabelMap().entrySet();

        if (includeObserved) {
            mapping = IteratorUtils.join(mapping, (Iterable) trainingMap.getObservedMap().entrySet());
        }

        if (closeTruth) {
            // Close the world over latent variables: each one is compared against a truth value of 0.0.
            mapping = IteratorUtils.join(mapping, IteratorUtils.map(trainingMap.getLatentVariables(),
                    new IteratorUtils.MapFunction<GroundAtom, Map.Entry<GroundAtom, GroundAtom>>() {
                        @Override
                        public Map.Entry<GroundAtom, GroundAtom> map(GroundAtom atom) {
                            return new AbstractMap.SimpleEntry<GroundAtom, GroundAtom>(atom,
                                    new UnmanagedObservedAtom(atom.getPredicate(), atom.getArguments(), 0.0f));
                        }
                    }));
        }

        return mapping;
    }

    /**
     * Get the atoms to evaluate.
     *
     * @return {@code trainingMap.getAllTargets()} when {@code includeObserved} is set, otherwise
     *     {@code trainingMap.getAllPredictions()}.
     */
    public Iterable<GroundAtom> getTargets(TrainingMap trainingMap) {
        return includeObserved ? trainingMap.getAllTargets() : trainingMap.getAllPredictions();
    }

    /** Orient a score so that higher is always better (negate it when lower is better). */
    private double normalize(double score) {
        return isHigherRepBetter() ? score : -score;
    }
}
