/*
 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.tribuo.classification.sequence;

import org.tribuo.Prediction;
import org.tribuo.classification.Label;
import org.tribuo.classification.evaluation.ConfusionMatrix;
import org.tribuo.classification.evaluation.LabelMetric;
import org.tribuo.classification.evaluation.LabelMetrics;
import org.tribuo.evaluation.metrics.EvaluationMetric;
import org.tribuo.evaluation.metrics.MetricID;
import org.tribuo.evaluation.metrics.MetricTarget;
import org.tribuo.provenance.EvaluationProvenance;
import org.tribuo.sequence.SequenceEvaluation;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.logging.Logger;

/**
 * A class that can be used to evaluate a sequence label classification model element wise on a given set of data.
038 */ 039public class LabelSequenceEvaluation implements SequenceEvaluation<Label> { 040 041 private static final Logger logger = Logger.getLogger(LabelSequenceEvaluation.class.getName()); 042 043 private final Map<MetricID<Label>, Double> results; 044 private final LabelMetric.Context ctx; 045 private final ConfusionMatrix<Label> cm; 046 private final EvaluationProvenance provenance; 047 048 /** 049 * Constructs a LabelSequenceEvaluation using the supplied parameters. 050 * @param results The metric values. 051 * @param ctx The context. 052 * @param provenance The evaluation provenance. 053 */ 054 protected LabelSequenceEvaluation(Map<MetricID<Label>, Double> results, 055 LabelMetric.Context ctx, 056 EvaluationProvenance provenance) { 057 this.results = results; 058 this.ctx = ctx; 059 this.cm = ctx.getCM(); 060 this.provenance = provenance; 061 } 062 063 /** 064 * Gets the flattened predictions. 065 * @return The flattened predictions. 066 */ 067 public List<Prediction<Label>> getPredictions() { 068 return ctx.getPredictions(); 069 } 070 071 /** 072 * Gets the confusion matrix backing this evaluation. 073 * @return The confusion matrix. 074 */ 075 public ConfusionMatrix<Label> getConfusionMatrix() { 076 return cm; 077 } 078 079 @Override 080 public Map<MetricID<Label>, Double> asMap() { 081 return Collections.unmodifiableMap(results); 082 } 083 084 /** 085 * Note: confusion is not stored in the underlying map, so it won't show up in aggregation. 086 * @param predictedLabel The predicted label. 087 * @param trueLabel The true label. 088 * @return The number of times that {@code predictedLabel} was predicted for <code>trueLabel</code>. 089 */ 090 public double confusion(Label predictedLabel, Label trueLabel) { 091 return cm.confusion(predictedLabel, trueLabel); 092 } 093 094 /** 095 * Gets the true positive count for that label. 096 * @param label The label. 097 * @return The true positive count. 
098 */ 099 public double tp(Label label) { 100 return get(label, LabelMetrics.TP); 101 } 102 103 /** 104 * Gets the micro averaged true positive count. 105 * @return The micro averaged true positive count. 106 */ 107 public double tp() { 108 return get(EvaluationMetric.Average.MICRO, LabelMetrics.TP); 109 } 110 111 /** 112 * Gets the macro averaged true positive count. 113 * @return The macro averaged true positive count. 114 */ 115 public double macroTP() { 116 return get(EvaluationMetric.Average.MACRO, LabelMetrics.TP); 117 } 118 119 /** 120 * The false positive count for this label. 121 * @param label The label. 122 * @return The false positive count. 123 */ 124 public double fp(Label label) { 125 return get(label, LabelMetrics.FP); 126 } 127 128 /** 129 * Gets the micro averaged false positive count. 130 * @return The micro averaged false positive count. 131 */ 132 public double fp() { 133 return get(EvaluationMetric.Average.MICRO, LabelMetrics.FP); 134 } 135 136 /** 137 * Gets the macro averaged false positive count. 138 * @return The macro averaged false positive count. 139 */ 140 public double macroFP() { 141 return get(EvaluationMetric.Average.MACRO, LabelMetrics.FP); 142 } 143 144 /** 145 * The true negative count for this label. 146 * @param label The label. 147 * @return The true negative count. 148 */ 149 public double tn(Label label) { 150 return get(label, LabelMetrics.TN); 151 } 152 153 /** 154 * Gets the micro averaged true negative count. 155 * @return The micro averaged true negative count. 156 */ 157 public double tn() { 158 return get(EvaluationMetric.Average.MICRO, LabelMetrics.TN); 159 } 160 161 /** 162 * Gets the macro averaged true negative count. 163 * @return The macro averaged true negative count. 164 */ 165 public double macroTN() { 166 return get(EvaluationMetric.Average.MACRO, LabelMetrics.TN); 167 } 168 169 /** 170 * The false negative count for this label. 171 * @param label The label. 172 * @return The false negative count. 
173 */ 174 public double fn(Label label) { 175 return get(label, LabelMetrics.FN); 176 } 177 178 /** 179 * Gets the micro averaged false negative count. 180 * @return The micro averaged false negative count. 181 */ 182 public double fn() { 183 return get(EvaluationMetric.Average.MICRO, LabelMetrics.FN); 184 } 185 186 /** 187 * Gets the macro averaged false negative count. 188 * @return The macro averaged false negative count. 189 */ 190 public double macroFN() { 191 return get(EvaluationMetric.Average.MACRO, LabelMetrics.FN); 192 } 193 194 /** 195 * The precision for this label. 196 * @param label The label. 197 * @return The precision. 198 */ 199 public double precision(Label label) { 200 return get(label, LabelMetrics.PRECISION); 201 } 202 203 /** 204 * The micro averaged precision. 205 * @return The micro averaged precision. 206 */ 207 public double microAveragedPrecision() { 208 return get(EvaluationMetric.Average.MICRO, LabelMetrics.PRECISION); 209 } 210 211 /** 212 * The macro averaged precision. 213 * @return The macro averaged precision. 214 */ 215 public double macroAveragedPrecision() { 216 return get(EvaluationMetric.Average.MACRO, LabelMetrics.PRECISION); 217 } 218 219 /** 220 * The recall for this label. 221 * @param label The label. 222 * @return The recall. 223 */ 224 public double recall(Label label) { 225 return get(label, LabelMetrics.RECALL); 226 } 227 228 /** 229 * The micro averaged recall. 230 * @return The micro averaged recall. 231 */ 232 public double microAveragedRecall() { 233 return get(EvaluationMetric.Average.MICRO, LabelMetrics.RECALL); 234 } 235 236 /** 237 * The macro averaged recall. 238 * @return The macro averaged recall. 239 */ 240 public double macroAveragedRecall() { 241 return get(EvaluationMetric.Average.MACRO, LabelMetrics.RECALL); 242 } 243 244 /** 245 * The F1 for this label. 246 * @param label The label. 247 * @return The F1. 
248 */ 249 public double f1(Label label) { 250 return get(label, LabelMetrics.RECALL); 251 } 252 253 /** 254 * The micro averaged F1. 255 * @return The micro averaged F1. 256 */ 257 public double microAveragedF1() { 258 return get(EvaluationMetric.Average.MICRO, LabelMetrics.F1); 259 } 260 261 /** 262 * The macro averaged F1. 263 * @return The macro averaged F1. 264 */ 265 public double macroAveragedF1() { 266 return get(EvaluationMetric.Average.MACRO, LabelMetrics.F1); 267 } 268 269 /** 270 * The accuracy. 271 * @return The accuracy. 272 */ 273 public double accuracy() { 274 return get(EvaluationMetric.Average.MICRO, LabelMetrics.ACCURACY); 275 } 276 277 /** 278 * Gets the accuracy for this label. 279 * @param label The label. 280 * @return The accuracy. 281 */ 282 public double accuracy(Label label) { 283 return get(label, LabelMetrics.ACCURACY); 284 } 285 286 /** 287 * Gets the balanced error rate. 288 * <p> 289 * Also known as 1 - the macro averaged recall. 290 * @return The balanced error rate. 291 */ 292 public double balancedErrorRate() { 293 // Target doesn't matter for balanced error rate, so we just use Average.macro 294 // as it's the macro averaged recall. 295 return get(EvaluationMetric.Average.MACRO, LabelMetrics.BALANCED_ERROR_RATE); 296 } 297 298 @Override 299 public EvaluationProvenance getProvenance() { return provenance; } 300 301 @Override 302 public String toString() { 303 List<Label> labelOrder = new ArrayList<>(cm.getDomain().getDomain()); 304 StringBuilder sb = new StringBuilder(); 305 int tp = 0; 306 int fn = 0; 307 int fp = 0; 308 int n = 0; 309 // 310 // Figure out the biggest class label and therefore the format string 311 // that we should use for them. 
312 int maxLabelSize = "Balanced Error Rate".length(); 313 for(Label label : labelOrder) { 314 maxLabelSize = Math.max(maxLabelSize, label.getLabel().length()); 315 } 316 String labelFormatString = String.format("%%-%ds", maxLabelSize+2); 317 sb.append(String.format(labelFormatString, "Class")); 318 sb.append(String.format("%12s%12s%12s%12s", "n", "tp", "fn", "fp")); 319 sb.append(String.format("%12s%12s%12s%n", "recall", "prec", "f1")); 320 for (Label label : labelOrder) { 321 if (cm.support(label) == 0) { 322 continue; 323 } 324 n += cm.support(label); 325 tp += cm.tp(label); 326 fn += cm.fn(label); 327 fp += cm.fp(label); 328 sb.append(String.format(labelFormatString, label)); 329 sb.append(String.format("%,12d%,12d%,12d%,12d", 330 (int) cm.support(label), 331 (int) cm.tp(label), 332 (int) cm.fn(label), 333 (int) cm.fp(label) 334 )); 335 sb.append(String.format("%12.3f%12.3f%12.3f%n", recall(label), precision(label), f1(label))); 336 } 337 sb.append(String.format(labelFormatString, "Total")); 338 sb.append(String.format("%,12d%,12d%,12d%,12d%n", n, tp, fn, fp)); 339 sb.append(String.format(labelFormatString, "Accuracy")); 340 sb.append(String.format("%60.3f%n", (double) tp / n)); 341 sb.append(String.format(labelFormatString, "Micro Average")); 342 sb.append(String.format("%60.3f%12.3f%12.3f%n", microAveragedRecall(), microAveragedPrecision(), microAveragedF1())); 343 sb.append(String.format(labelFormatString, "Macro Average")); 344 sb.append(String.format("%60.3f%12.3f%12.3f%n", macroAveragedRecall(), macroAveragedPrecision(), macroAveragedF1())); 345 sb.append(String.format(labelFormatString, "Balanced Error Rate")); 346 sb.append(String.format("%60.3f", balancedErrorRate())); 347 return sb.toString(); 348 } 349 350 private double get(MetricTarget<Label> tgt, LabelMetrics metric) { 351 return get(metric.forTarget(tgt).getID()); 352 } 353 354 private double get(Label label, LabelMetrics metric) { 355 return get(metric 356 .forTarget(new MetricTarget<>(label)) 
357 .getID()); 358 } 359 360 private double get(EvaluationMetric.Average avg, LabelMetrics metric) { 361 return get(metric 362 .forTarget(new MetricTarget<>(avg)) 363 .getID()); 364 } 365}