001/* 002 * Copyright (c) 2015-2022, Oracle and/or its affiliates. All rights reserved. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.tribuo.classification.evaluation; 018 019import org.tribuo.ImmutableOutputInfo; 020import org.tribuo.Model; 021import org.tribuo.Prediction; 022import org.tribuo.classification.Label; 023import org.tribuo.math.la.DenseMatrix; 024 025import java.util.ArrayList; 026import java.util.Collections; 027import java.util.HashMap; 028import java.util.HashSet; 029import java.util.LinkedHashSet; 030import java.util.List; 031import java.util.Map; 032import java.util.Set; 033import java.util.function.ToDoubleFunction; 034import java.util.logging.Logger; 035 036/** 037 * A confusion matrix for {@link Label}s. 038 * <p> 039 * We interpret it as follows: 040 * 041 * {@code 042 * C[i, j] = k 043 * } 044 * 045 * means "the TRUE class 'j' was PREDICTED to be class 'i' a total of 'k' times". 046 * 047 * <p> 048 * In other words, the row indices correspond to the model's predictions, and the column indices correspond to 049 * the ground truth. 050 * </p> 051 */ 052public final class LabelConfusionMatrix implements ConfusionMatrix<Label> { 053 054 private static final Logger logger = Logger.getLogger(LabelConfusionMatrix.class.getName()); 055 056 private final ImmutableOutputInfo<Label> domain; 057 058 private final int total; 059 private final Map<Label, Double> occurrences; 060 061 private final Set<Label> observed; 062 063 private final DenseMatrix cm; 064 065 private List<Label> labelOrder; 066 067 /** 068 * Creates a confusion matrix from the supplied predictions, using the label info 069 * from the supplied model. 070 * 071 * @param model The model to use for the label information. 072 * @param predictions The predictions. 073 */ 074 public LabelConfusionMatrix(Model<Label> model, List<Prediction<Label>> predictions) { 075 this(model.getOutputIDInfo(), predictions); 076 } 077 078 /** 079 * Creates a confusion matrix from the supplied predictions and label info. 080 * 081 * @param domain The label information. 082 * @param predictions The predictions. 083 * @throws IllegalArgumentException If the domain doesn't contain all the predictions. 084 */ 085 public LabelConfusionMatrix(ImmutableOutputInfo<Label> domain, List<Prediction<Label>> predictions) { 086 this.domain = domain; 087 this.total = predictions.size(); 088 this.cm = new DenseMatrix(domain.size(), domain.size()); 089 this.occurrences = new HashMap<>(); 090 this.observed = new HashSet<>(); 091 this.labelOrder = Collections.unmodifiableList(new ArrayList<>(domain.getDomain())); 092 tabulate(predictions); 093 } 094 095 /** 096 * Aggregate the predictions into this confusion matrix. 097 * 098 * @param predictions The predictions to aggregate. 099 */ 100 private void tabulate(List<Prediction<Label>> predictions) { 101 predictions.forEach(prediction -> { 102 Label y = prediction.getExample().getOutput(); 103 Label p = prediction.getOutput(); 104 // 105 // Check that the ground truth label is valid 106 if (y.getLabel().equals(Label.UNKNOWN)) { 107 throw new IllegalArgumentException("Prediction with unknown ground truth. Unable to evaluate."); 108 } 109 occurrences.merge(y, 1d, Double::sum); 110 observed.add(y); 111 observed.add(p); 112 int iy = getIDOrThrow(y); 113 int ip = getIDOrThrow(p); 114 cm.add(ip, iy, 1d); 115 }); 116 } 117 118 @Override 119 public ImmutableOutputInfo<Label> getDomain() { 120 return domain; 121 } 122 123 @Override 124 public Set<Label> observed() { 125 return Collections.unmodifiableSet(observed); 126 } 127 128 @Override 129 public double support() { 130 return total; 131 } 132 133 @Override 134 public double support(Label label) { 135 return occurrences.getOrDefault(label, 0d); 136 } 137 138 @Override 139 public double tp(Label cls) { 140 return compute(cls, (i) -> cm.get(i, i)); 141 } 142 143 @Override 144 public double fp(Label cls) { 145 // Row-wise sum less true positives 146 return compute(cls, i -> cm.rowSum(i) - cm.get(i, i)); 147 } 148 149 @Override 150 public double fn(Label cls) { 151 // Column-wise sum less true positives 152 return compute(cls, i -> cm.columnSum(i) - cm.get(i, i)); 153 } 154 155 @Override 156 public double tn(Label cls) { 157 int n = getDomain().size(); 158 int i = getDomain().getID(cls); 159 double total = 0d; 160 for (int j = 0; j < n; j++) { 161 if (j == i) { 162 continue; 163 } 164 for (int k = 0; k < n; k++) { 165 if (k == i) { 166 continue; 167 } 168 total += cm.get(j, k); 169 } 170 } 171 return total; 172 } 173 174 @Override 175 public double confusion(Label predicted, Label trueClass) { 176 int i = getDomain().getID(predicted); 177 int j = getDomain().getID(trueClass); 178 return cm.get(i, j); 179 } 180 181 /** 182 * A convenience method for extracting the appropriate label statistic. 183 * 184 * @param cls The label to check. 185 * @param getter The get function which accepts a label id. 186 * @return The statistic for that label id. 187 */ 188 private double compute(Label cls, ToDoubleFunction<Integer> getter) { 189 int i = getDomain().getID(cls); 190 if (i < 0) { 191 logger.fine("Unknown Label " + cls); 192 return 0d; 193 } 194 return getter.applyAsDouble(i); 195 } 196 197 /** 198 * Gets the id for the supplied label, or throws an {@link IllegalArgumentException} if it's 199 * an unknown label. 200 * 201 * @param key The label. 202 * @return The int id for that label. 203 */ 204 private int getIDOrThrow(Label key) { 205 int id = domain.getID(key); 206 if (id < 0) { 207 throw new IllegalArgumentException("Unknown label: " + key); 208 } 209 return id; 210 } 211 212 /** 213 * Sets the label order used in {@link #toString}. 214 * <p> 215 * If the label order is a subset of the labels in the domain, only the 216 * labels present in the label order will be displayed. 217 * 218 * @param newLabelOrder The label order to use. 219 */ 220 @Override 221 public void setLabelOrder(List<Label> newLabelOrder) { 222 if (newLabelOrder == null || newLabelOrder.isEmpty()) { 223 throw new IllegalArgumentException("Label order must be non-null and non-empty."); 224 } 225 this.labelOrder = Collections.unmodifiableList(new ArrayList<>(newLabelOrder)); 226 } 227 228 /** 229 * Gets the current label order. 230 * 231 * May trigger order instantiation if the label order has not been set. 232 * @return The label order. 233 */ 234 public List<Label> getLabelOrder() { 235 return labelOrder; 236 } 237 238 @Override 239 public String toString() { 240 List<Label> curOrder = new ArrayList<>(labelOrder); 241 curOrder.retainAll(observed); 242 243 int maxLen = Integer.MIN_VALUE; 244 for (Label label : curOrder) { 245 maxLen = Math.max(label.getLabel().length(), maxLen); 246 maxLen = Math.max(String.format(" %,d", (int)(double)occurrences.getOrDefault(label,0.0)).length(), maxLen); 247 } 248 249 StringBuilder sb = new StringBuilder(); 250 String trueLabelFormat = String.format("%%-%ds", maxLen + 2); 251 String predictedLabelFormat = String.format("%%%ds", maxLen + 2); 252 String countFormat = String.format("%%,%dd", maxLen + 2); 253 254 // 255 // Empty spot in first row for labels on subsequent rows. 256 sb.append(String.format(trueLabelFormat, "")); 257 258 // 259 // Labels across the top for predicted. 260 for (Label predictedLabel : curOrder) { 261 sb.append(String.format(predictedLabelFormat, predictedLabel.getLabel())); 262 } 263 sb.append('\n'); 264 265 for (Label trueLabel : curOrder) { 266 sb.append(String.format(trueLabelFormat, trueLabel.getLabel())); 267 for (Label predictedLabel : curOrder) { 268 int confusion = (int) confusion(predictedLabel, trueLabel); 269 sb.append(String.format(countFormat, confusion)); 270 } 271 sb.append('\n'); 272 } 273 return sb.toString(); 274 } 275 276 /** 277 * Emits a HTML table representation of the Confusion Matrix. 278 * @return The confusion matrix as a HTML table. 279 */ 280 public String toHTML() { 281 Set<Label> labelsToPrint = new LinkedHashSet<>(labelOrder); 282 labelsToPrint.retainAll(observed); 283 StringBuilder sb = new StringBuilder(); 284 sb.append("<table>\n"); 285 sb.append(String.format("<tr><th>True Label</th><th style=\"text-align:center\" colspan=\"%d\">Predicted Labels</th></tr>%n", occurrences.size() + 1)); 286 sb.append("<tr><th></th>"); 287 for (Label predictedLabel : labelsToPrint) { 288 sb.append("<th style=\"text-align:right\">") 289 .append(predictedLabel) 290 .append("</th>"); 291 } 292 sb.append("<th style=\"text-align:right\">Total</th>"); 293 sb.append("</tr>\n"); 294 for (Label trueLabel : labelsToPrint) { 295 sb.append("<tr><th>").append(trueLabel).append("</th>"); 296 double count = occurrences.getOrDefault(trueLabel, 0d); 297 for (Label predictedLabel : labelsToPrint) { 298 double tlmc = confusion(predictedLabel,trueLabel); 299 double percent = (tlmc / count) * 100; 300 sb.append("<td style=\"text-align:right\">") 301 .append(String.format("%,d (%.1f%%)", (int)tlmc, percent)) 302 .append("</td>"); 303 } 304 sb.append("<td style=\"text-align:right\">").append(count).append("</td>"); 305 sb.append("</tr>\n"); 306 } 307 sb.append("</table>"); 308 return sb.toString(); 309 } 310}