001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.classification.sequence;
018
019import org.tribuo.Prediction;
020import org.tribuo.classification.Label;
021import org.tribuo.classification.evaluation.ConfusionMatrix;
022import org.tribuo.classification.evaluation.LabelMetric;
023import org.tribuo.classification.evaluation.LabelMetrics;
024import org.tribuo.evaluation.metrics.EvaluationMetric;
025import org.tribuo.evaluation.metrics.MetricID;
026import org.tribuo.evaluation.metrics.MetricTarget;
027import org.tribuo.provenance.EvaluationProvenance;
028import org.tribuo.sequence.SequenceEvaluation;
029
030import java.util.ArrayList;
031import java.util.Collections;
032import java.util.List;
033import java.util.Map;
034import java.util.logging.Logger;
035
036/**
037 * A class that can be used to evaluate a sequence label classification model element wise on a given set of data.
038 */
039public class LabelSequenceEvaluation implements SequenceEvaluation<Label> {
040
041    private static final Logger logger = Logger.getLogger(LabelSequenceEvaluation.class.getName());
042    
043    private final Map<MetricID<Label>, Double> results;
044    private final LabelMetric.Context ctx;
045    private final ConfusionMatrix<Label> cm;
046    private final EvaluationProvenance provenance;
047
048    /**
049     * Constructs a LabelSequenceEvaluation using the supplied parameters.
050     * @param results The metric values.
051     * @param ctx The context.
052     * @param provenance The evaluation provenance.
053     */
054    protected LabelSequenceEvaluation(Map<MetricID<Label>, Double> results,
055                                      LabelMetric.Context ctx,
056                                      EvaluationProvenance provenance) {
057        this.results = results;
058        this.ctx = ctx;
059        this.cm = ctx.getCM();
060        this.provenance = provenance;
061    }
062
063    /**
064     * Gets the flattened predictions.
065     * @return The flattened predictions.
066     */
067    public List<Prediction<Label>> getPredictions() {
068        return ctx.getPredictions();
069    }
070
071    /**
072     * Gets the confusion matrix backing this evaluation.
073     * @return The confusion matrix.
074     */
075    public ConfusionMatrix<Label> getConfusionMatrix() {
076        return cm;
077    }
078
079    @Override
080    public Map<MetricID<Label>, Double> asMap() {
081        return Collections.unmodifiableMap(results);
082    }
083
084    /**
085     * Note: confusion is not stored in the underlying map, so it won't show up in aggregation.
086     * @param predictedLabel The predicted label.
087     * @param trueLabel The true label.
088     * @return The number of times that {@code predictedLabel} was predicted for <code>trueLabel</code>.
089     */
090    public double confusion(Label predictedLabel, Label trueLabel) {
091        return cm.confusion(predictedLabel, trueLabel);
092    }
093
094    /**
095     * Gets the true positive count for that label.
096     * @param label The label.
097     * @return The true positive count.
098     */
099    public double tp(Label label) {
100        return get(label, LabelMetrics.TP);
101    }
102
103    /**
104     * Gets the micro averaged true positive count.
105     * @return The micro averaged true positive count.
106     */
107    public double tp() {
108        return get(EvaluationMetric.Average.MICRO, LabelMetrics.TP);
109    }
110
111    /**
112     * Gets the macro averaged true positive count.
113     * @return The macro averaged true positive count.
114     */
115    public double macroTP() {
116        return get(EvaluationMetric.Average.MACRO, LabelMetrics.TP);
117    }
118
119    /**
120     * The false positive count for this label.
121     * @param label The label.
122     * @return The false positive count.
123     */
124    public double fp(Label label) {
125        return get(label, LabelMetrics.FP);
126    }
127
128    /**
129     * Gets the micro averaged false positive count.
130     * @return The micro averaged false positive count.
131     */
132    public double fp() {
133        return get(EvaluationMetric.Average.MICRO, LabelMetrics.FP);
134    }
135
136    /**
137     * Gets the macro averaged false positive count.
138     * @return The macro averaged false positive count.
139     */
140    public double macroFP() {
141        return get(EvaluationMetric.Average.MACRO, LabelMetrics.FP);
142    }
143
144    /**
145     * The true negative count for this label.
146     * @param label The label.
147     * @return The true negative count.
148     */
149    public double tn(Label label) {
150        return get(label, LabelMetrics.TN);
151    }
152
153    /**
154     * Gets the micro averaged true negative count.
155     * @return The micro averaged true negative count.
156     */
157    public double tn() {
158        return get(EvaluationMetric.Average.MICRO, LabelMetrics.TN);
159    }
160
161    /**
162     * Gets the macro averaged true negative count.
163     * @return The macro averaged true negative count.
164     */
165    public double macroTN() {
166        return get(EvaluationMetric.Average.MACRO, LabelMetrics.TN);
167    }
168
169    /**
170     * The false negative count for this label.
171     * @param label The label.
172     * @return The false negative count.
173     */
174    public double fn(Label label) {
175        return get(label, LabelMetrics.FN);
176    }
177
178    /**
179     * Gets the micro averaged false negative count.
180     * @return The micro averaged false negative count.
181     */
182    public double fn() {
183        return get(EvaluationMetric.Average.MICRO, LabelMetrics.FN);
184    }
185
186    /**
187     * Gets the macro averaged false negative count.
188     * @return The macro averaged false negative count.
189     */
190    public double macroFN() {
191        return get(EvaluationMetric.Average.MACRO, LabelMetrics.FN);
192    }
193
194    /**
195     * The precision for this label.
196     * @param label The label.
197     * @return The precision.
198     */
199    public double precision(Label label) {
200        return get(label, LabelMetrics.PRECISION);
201    }
202
203    /**
204     * The micro averaged precision.
205     * @return The micro averaged precision.
206     */
207    public double microAveragedPrecision() {
208        return get(EvaluationMetric.Average.MICRO, LabelMetrics.PRECISION);
209    }
210
211    /**
212     * The macro averaged precision.
213     * @return The macro averaged precision.
214     */
215    public double macroAveragedPrecision() {
216        return get(EvaluationMetric.Average.MACRO, LabelMetrics.PRECISION);
217    }
218
219    /**
220     * The recall for this label.
221     * @param label The label.
222     * @return The recall.
223     */
224    public double recall(Label label) {
225        return get(label, LabelMetrics.RECALL);
226    }
227
228    /**
229     * The micro averaged recall.
230     * @return The micro averaged recall.
231     */
232    public double microAveragedRecall() {
233        return get(EvaluationMetric.Average.MICRO, LabelMetrics.RECALL);
234    }
235
236    /**
237     * The macro averaged recall.
238     * @return The macro averaged recall.
239     */
240    public double macroAveragedRecall() {
241        return get(EvaluationMetric.Average.MACRO, LabelMetrics.RECALL);
242    }
243
244    /**
245     * The F1 for this label.
246     * @param label The label.
247     * @return The F1.
248     */
249    public double f1(Label label) {
250        return get(label, LabelMetrics.RECALL);
251    }
252
253    /**
254     * The micro averaged F1.
255     * @return The micro averaged F1.
256     */
257    public double microAveragedF1() {
258        return get(EvaluationMetric.Average.MICRO, LabelMetrics.F1);
259    }
260
261    /**
262     * The macro averaged F1.
263     * @return The macro averaged F1.
264     */
265    public double macroAveragedF1() {
266        return get(EvaluationMetric.Average.MACRO, LabelMetrics.F1);
267    }
268
269    /**
270     * The accuracy.
271     * @return The accuracy.
272     */
273    public double accuracy() {
274        return get(EvaluationMetric.Average.MICRO, LabelMetrics.ACCURACY);
275    }
276
277    /**
278     * Gets the accuracy for this label.
279     * @param label The label.
280     * @return The accuracy.
281     */
282    public double accuracy(Label label) {
283        return get(label, LabelMetrics.ACCURACY);
284    }
285
286    /**
287     * Gets the balanced error rate.
288     * <p>
289     * Also known as 1 - the macro averaged recall.
290     * @return The balanced error rate.
291     */
292    public double balancedErrorRate() {
293        // Target doesn't matter for balanced error rate, so we just use Average.macro
294        // as it's the macro averaged recall.
295        return get(EvaluationMetric.Average.MACRO, LabelMetrics.BALANCED_ERROR_RATE);
296    }
297
298    @Override
299    public EvaluationProvenance getProvenance() { return provenance; }
300
301    @Override
302    public String toString() {
303        List<Label> labelOrder = new ArrayList<>(cm.getDomain().getDomain());
304        StringBuilder sb = new StringBuilder();
305        int tp = 0;
306        int fn = 0;
307        int fp = 0;
308        int n = 0;
309        //
310        // Figure out the biggest class label and therefore the format string
311        // that we should use for them.
312        int maxLabelSize = "Balanced Error Rate".length();
313        for(Label label : labelOrder) {
314            maxLabelSize = Math.max(maxLabelSize, label.getLabel().length());
315        }
316        String labelFormatString = String.format("%%-%ds", maxLabelSize+2);
317        sb.append(String.format(labelFormatString, "Class"));
318        sb.append(String.format("%12s%12s%12s%12s", "n", "tp", "fn", "fp"));
319        sb.append(String.format("%12s%12s%12s%n", "recall", "prec", "f1"));
320        for (Label label : labelOrder) {
321            if (cm.support(label) == 0) {
322                continue;
323            }
324            n += cm.support(label);
325            tp += cm.tp(label);
326            fn += cm.fn(label);
327            fp += cm.fp(label);
328            sb.append(String.format(labelFormatString, label));
329            sb.append(String.format("%,12d%,12d%,12d%,12d",
330                    (int) cm.support(label),
331                    (int) cm.tp(label),
332                    (int) cm.fn(label),
333                    (int) cm.fp(label)
334            ));
335            sb.append(String.format("%12.3f%12.3f%12.3f%n", recall(label), precision(label), f1(label)));
336        }
337        sb.append(String.format(labelFormatString, "Total"));
338        sb.append(String.format("%,12d%,12d%,12d%,12d%n", n, tp, fn, fp));
339        sb.append(String.format(labelFormatString, "Accuracy"));
340        sb.append(String.format("%60.3f%n", (double) tp / n));
341        sb.append(String.format(labelFormatString, "Micro Average"));
342        sb.append(String.format("%60.3f%12.3f%12.3f%n", microAveragedRecall(), microAveragedPrecision(), microAveragedF1()));
343        sb.append(String.format(labelFormatString, "Macro Average"));
344        sb.append(String.format("%60.3f%12.3f%12.3f%n", macroAveragedRecall(), macroAveragedPrecision(), macroAveragedF1()));
345        sb.append(String.format(labelFormatString, "Balanced Error Rate"));
346        sb.append(String.format("%60.3f", balancedErrorRate()));
347        return sb.toString();
348    }
349
350    private double get(MetricTarget<Label> tgt, LabelMetrics metric) {
351        return get(metric.forTarget(tgt).getID());
352    }
353
354    private double get(Label label, LabelMetrics metric) {
355        return get(metric
356                .forTarget(new MetricTarget<>(label))
357                .getID());
358    }
359
360    private double get(EvaluationMetric.Average avg, LabelMetrics metric) {
361        return get(metric
362                .forTarget(new MetricTarget<>(avg))
363                .getID());
364    }
365}