package org.linqs.psl.evaluation.statistics;

import java.util.Locale;
import java.util.Map;
import org.linqs.psl.application.learning.weight.TrainingMap;
import org.linqs.psl.config.Option;
import org.linqs.psl.config.Options;
import org.linqs.psl.model.atom.ObservedAtom;
import org.linqs.psl.model.atom.RandomVariableAtom;
import org.linqs.psl.model.predicate.StandardPredicate;

/* loaded from: input_file:org/linqs/psl/evaluation/statistics/ContinuousEvaluator.class */
/**
 * Evaluator that accumulates continuous error statistics between the two atoms
 * of every entry in a {@link TrainingMap}'s label map.
 *
 * <p>Each call to {@link #compute(TrainingMap, StandardPredicate)} resets the accumulated
 * state and then records, per counted entry, the absolute and squared difference between
 * the entry's value atom and its key atom. The mean absolute error ({@link #mae()}) and
 * mean squared error ({@link #mse()}) are derived from those sums on demand.
 */
public class ContinuousEvaluator extends Evaluator {
    /** The statistics that can be reported by {@link #getRepMetric()}. */
    public enum RepresentativeMetric {
        MAE,
        MSE
    }

    /** Metric returned by {@link #getRepMetric()}; fixed at construction. */
    private final RepresentativeMetric representative;

    /** Number of label-map entries counted by the most recent compute() call. */
    private int count;

    /** Sum of absolute differences over the counted entries. */
    private double absoluteError;

    /** Sum of squared differences over the counted entries. */
    private double squaredError;

    /**
     * Creates an evaluator whose representative metric is named by the
     * {@code Options.EVAL_CONT_REPRESENTATIVE} option.
     */
    public ContinuousEvaluator() {
        this(Options.EVAL_CONT_REPRESENTATIVE.getString());
    }

    /**
     * Creates an evaluator from the (case-insensitive) name of a {@link RepresentativeMetric}.
     *
     * @param str name of the representative metric, e.g. {@code "mae"} or {@code "MSE"}
     * @throws IllegalArgumentException if {@code str} does not name a {@link RepresentativeMetric}
     * @throws NullPointerException if {@code str} is null
     */
    public ContinuousEvaluator(String str) {
        // Locale.ROOT keeps the case mapping independent of the JVM's default locale,
        // so the enum lookup behaves the same on every machine.
        this(RepresentativeMetric.valueOf(str.toUpperCase(Locale.ROOT)));
    }

    /**
     * Creates an evaluator with the given representative metric and empty statistics.
     *
     * @param representativeMetric the metric returned by {@link #getRepMetric()}
     */
    public ContinuousEvaluator(RepresentativeMetric representativeMetric) {
        this.representative = representativeMetric;
        this.count = 0;
        this.absoluteError = 0.0d;
        this.squaredError = 0.0d;
    }

    /**
     * Computes the statistics over every entry of the training map's label map.
     *
     * @param trainingMap source of the label map to evaluate
     */
    @Override
    public void compute(TrainingMap trainingMap) {
        compute(trainingMap, null);
    }

    /**
     * Resets the accumulated statistics and recomputes them from the training map's label map.
     *
     * @param trainingMap source of the label map to evaluate
     * @param standardPredicate if non-null, only entries whose key atom has this exact
     *     predicate instance (reference comparison) are counted; if null, all entries are counted
     */
    @Override
    public void compute(TrainingMap trainingMap, StandardPredicate standardPredicate) {
        this.count = 0;
        this.absoluteError = 0.0d;
        this.squaredError = 0.0d;

        for (Map.Entry<RandomVariableAtom, ObservedAtom> entry : trainingMap.getLabelMap().entrySet()) {
            // Predicates are compared by reference, exactly as before.
            if (standardPredicate != null && entry.getKey().getPredicate() != standardPredicate) {
                continue;
            }

            // Compute the residual once and reuse it for both accumulators.
            double difference = entry.getValue().getValue() - entry.getKey().getValue();

            this.count++;
            this.absoluteError += Math.abs(difference);
            this.squaredError += difference * difference;
        }
    }

    /**
     * Returns the statistic selected as representative at construction time.
     *
     * @return {@link #mae()} or {@link #mse()}, depending on the representative metric
     * @throws IllegalStateException if the representative metric is not a known constant
     */
    @Override
    public double getRepMetric() {
        // Switch on the enum itself rather than on ordinal-derived integer labels.
        switch (this.representative) {
            case MAE:
                return mae();
            case MSE:
                return mse();
            default:
                throw new IllegalStateException("Unknown representative metric: " + this.representative);
        }
    }

    /**
     * Reports whether a larger representative metric is better.
     *
     * @return always {@code false}
     */
    @Override
    public boolean isHigherRepBetter() {
        return false;
    }

    /**
     * Returns the mean absolute error of the most recent computation.
     *
     * @return the accumulated absolute error divided by the number of counted entries,
     *     or {@code 0.0} when no entries were counted
     */
    public double mae() {
        if (this.count == 0) {
            return 0.0d;
        }
        return this.absoluteError / this.count;
    }

    /**
     * Returns the mean squared error of the most recent computation.
     *
     * @return the accumulated squared error divided by the number of counted entries,
     *     or {@code 0.0} when no entries were counted
     */
    public double mse() {
        if (this.count == 0) {
            return 0.0d;
        }
        return this.squaredError / this.count;
    }

    /**
     * Formats both statistics as a single human-readable string.
     *
     * @return a string of the form {@code "MAE: <mae>, MSE: <mse>"}
     */
    @Override
    public String getAllStats() {
        return String.format("MAE: %f, MSE: %f", mae(), mse());
    }
}
