001/* 002 * Copyright (c) 2015-2022, Oracle and/or its affiliates. All rights reserved. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.tribuo.classification.evaluation; 018 019import org.tribuo.classification.Label; 020import org.tribuo.evaluation.EvaluationRenderer; 021 022import java.util.ArrayList; 023import java.util.List; 024 025/** 026 * Adds multi-class classification specific metrics to {@link ClassifierEvaluation}. 027 */ 028public interface LabelEvaluation extends ClassifierEvaluation<Label> { 029 030 /** 031 * The overall accuracy of the evaluation. 032 * @return The accuracy. 033 */ 034 double accuracy(); 035 036 /** 037 * The per label accuracy of the evaluation. 038 * @param label The target label. 039 * @return The per label accuracy. 040 */ 041 double accuracy(Label label); 042 043 /** 044 * Area under the ROC curve. 045 * 046 * @param label target label 047 * @return AUC ROC score 048 * 049 * @implSpec Implementations of this class are expected to throw {@link UnsupportedOperationException} if the model 050 * corresponding to this evaluation does not generate probabilities, which are required to compute the ROC curve. 051 */ 052 double AUCROC(Label label); 053 054 /** 055 * Area under the ROC curve averaged across labels. 056 * <p> 057 * If {@code weighted} is false, use a macro average, if true, weight by the evaluation's observed class counts. 058 * </p> 059 * 060 * @param weighted If true weight by the class counts, if false use a macro average. 061 * @return The average AUCROC. 062 * 063 * @implSpec Implementations of this class are expected to throw {@link UnsupportedOperationException} if the model 064 * corresponding to this evaluation does not generate probabilities, which are required to compute the ROC curve. 065 */ 066 double averageAUCROC(boolean weighted); 067 068 /** 069 * Summarises a Precision-Recall Curve by taking the weighted mean of the 070 * precisions at a given threshold, where the weight is the recall achieved at 071 * that threshold. 072 * 073 * @see LabelEvaluationUtil#averagedPrecision(boolean[], double[]) 074 * 075 * @param label The target label. 076 * @return The averaged precision for that label. 077 * 078 * @implSpec Implementations of this class are expected to throw {@link UnsupportedOperationException} if the model 079 * corresponding to this evaluation does not generate probabilities, which are required to compute the ROC curve. 080 */ 081 double averagedPrecision(Label label); 082 083 /** 084 * Calculates the Precision Recall curve for a single label. 085 * 086 * @see LabelEvaluationUtil#generatePRCurve(boolean[], double[]) 087 * 088 * @param label The target label. 089 * @return The precision recall curve for that label. 090 * 091 * @implSpec Implementations of this class are expected to throw {@link UnsupportedOperationException} if the model 092 * corresponding to this evaluation does not generate probabilities, which are required to compute the ROC curve. 093 */ 094 LabelEvaluationUtil.PRCurve precisionRecallCurve(Label label); 095 096 /** 097 * Returns a HTML formatted String representing this evaluation. 098 * <p> 099 * Uses the label order of the confusion matrix, which can be used to display 100 * a subset of the per label metrics. When they are subset the total row 101 * represents only the subset selected, not all the predictions, however 102 * the accuracy and averaged metrics cover all the predictions. 103 * @return A HTML formatted String. 104 */ 105 default String toHTML() { 106 return LabelEvaluation.toHTML(this); 107 } 108 109 /** 110 * This method produces a nicely formatted String output, with 111 * appropriate tabs and newlines, suitable for display on a terminal. 112 * It can be used as an implementation of the {@link EvaluationRenderer} 113 * functional interface. 114 * <p> 115 * Uses the label order of the confusion matrix, which can be used to display 116 * a subset of the per label metrics. When they are subset the total row 117 * represents only the subset selected, not all the predictions, however 118 * the accuracy and averaged metrics cover all the predictions. 119 * @param evaluation The evaluation to format. 120 * @return Formatted output showing the main results of the evaluation. 121 */ 122 public static String toFormattedString(LabelEvaluation evaluation) { 123 ConfusionMatrix<Label> cm = evaluation.getConfusionMatrix(); 124 List<Label> labelOrder = new ArrayList<>(cm.getLabelOrder()); 125 labelOrder.retainAll(cm.observed()); 126 StringBuilder sb = new StringBuilder(); 127 int tp = 0; 128 int fn = 0; 129 int fp = 0; 130 int n = 0; 131 // 132 // Figure out the biggest class label and therefore the format string 133 // that we should use for them. 134 int maxLabelSize = "Balanced Error Rate".length(); 135 for(Label label : labelOrder) { 136 maxLabelSize = Math.max(maxLabelSize, label.getLabel().length()); 137 } 138 String labelFormatString = String.format("%%-%ds", maxLabelSize+2); 139 sb.append(String.format(labelFormatString, "Class")); 140 sb.append(String.format("%12s%12s%12s%12s", "n", "tp", "fn", "fp")); 141 sb.append(String.format("%12s%12s%12s%n", "recall", "prec", "f1")); 142 for (Label label : labelOrder) { 143 if (cm.support(label) == 0) { 144 continue; 145 } 146 n += cm.support(label); 147 tp += cm.tp(label); 148 fn += cm.fn(label); 149 fp += cm.fp(label); 150 sb.append(String.format(labelFormatString, label)); 151 sb.append(String.format("%,12d%,12d%,12d%,12d", 152 (int) cm.support(label), 153 (int) cm.tp(label), 154 (int) cm.fn(label), 155 (int) cm.fp(label) 156 )); 157 sb.append(String.format("%12.3f%12.3f%12.3f%n", 158 evaluation.recall(label), 159 evaluation.precision(label), 160 evaluation.f1(label))); 161 } 162 sb.append(String.format(labelFormatString, "Total")); 163 sb.append(String.format("%,12d%,12d%,12d%,12d%n", n, tp, fn, fp)); 164 sb.append(String.format(labelFormatString, "Accuracy")); 165 sb.append(String.format("%60.3f%n", evaluation.accuracy())); 166 sb.append(String.format(labelFormatString, "Micro Average")); 167 sb.append(String.format("%60.3f%12.3f%12.3f%n", 168 evaluation.microAveragedRecall(), 169 evaluation.microAveragedPrecision(), 170 evaluation.microAveragedF1())); 171 sb.append(String.format(labelFormatString, "Macro Average")); 172 sb.append(String.format("%60.3f%12.3f%12.3f%n", 173 evaluation.macroAveragedRecall(), 174 evaluation.macroAveragedPrecision(), 175 evaluation.macroAveragedF1())); 176 sb.append(String.format(labelFormatString, "Balanced Error Rate")); 177 sb.append(String.format("%60.3f", evaluation.balancedErrorRate())); 178 return sb.toString(); 179 } 180 181 /** 182 * This method produces a HTML formatted String output, with 183 * appropriate tabs and newlines, suitable for integration into a webpage. 184 * It can be used as an implementation of the {@link EvaluationRenderer} 185 * functional interface. 186 * <p> 187 * Uses the label order of the confusion matrix, which can be used to display 188 * a subset of the per label metrics. When they are subset the total row 189 * represents only the subset selected, not all the predictions, however 190 * the accuracy and averaged metrics cover all the predictions. 191 * @param evaluation The evaluation to format. 192 * @return Formatted HTML output showing the main results of the evaluation. 193 */ 194 public static String toHTML(LabelEvaluation evaluation) { 195 ConfusionMatrix<Label> cm = evaluation.getConfusionMatrix(); 196 List<Label> labelOrder = cm.getLabelOrder(); 197 StringBuilder sb = new StringBuilder(); 198 int tp = 0; 199 int fn = 0; 200 int fp = 0; 201 int tn = 0; 202 sb.append("<table>\n"); 203 sb.append("<tr>\n"); 204 sb.append("<th>Class</th><th>n</th> <th>%</th> <th>tp</th> <th>fn</th> <th>fp</th> <th>Recall</th> <th>Precision</th> <th>F1</th>"); 205 sb.append("\n</tr>\n"); 206 // 207 // Compute the total number of instances first, so we can show proportions. 208 for (Label label : labelOrder) { 209 //tn += occurrences.getOrDefault(label, 0); 210 tn += cm.tn(label); 211 } 212 for (Label label : labelOrder) { 213 if (cm.support(label) == 0) { 214 continue; 215 } 216 tp += cm.tp(label); 217 fn += cm.fn(label); 218 fp += cm.fp(label); 219 sb.append("<tr>"); 220 sb.append("<td><code>").append(label).append("</code></td>"); 221 int occurrence = (int) cm.support(label); 222 sb.append("<td style=\"text-align:right\">").append(String.format("%,d", occurrence)).append("</td>"); 223 sb.append("<td style=\"text-align:right\">").append(String.format("%8.1f%%", (occurrence/ (double) tn)*100)).append("</td>"); 224 sb.append("<td style=\"text-align:right\">").append(String.format("%,d", (int) cm.tp(label))).append("</td>"); 225 sb.append("<td style=\"text-align:right\">").append(String.format("%,d", (int) cm.fn(label))).append("</td>"); 226 sb.append("<td style=\"text-align:right\">").append(String.format("%,d", (int) cm.fp(label))).append("</td>"); 227 sb.append(String.format("<td style=\"text-align:right\">%8.3f</td><td style=\"text-align:right\">%8.3f</td><td style=\"text-align:right\">%8.3f</td>%n", 228 evaluation.recall(label), evaluation.precision(label), evaluation.f1(label))); 229 sb.append("</tr>"); 230 } 231 sb.append("<tr>"); 232 sb.append("<td>Total</td>"); 233 sb.append(String.format("<td style=\"text-align:right\">%,12d</td><td style=\"text-align:right\"></td><td style=\"text-align:right\">%,12d</td><td style=\"text-align:right\">%,12d</td><td style=\"text-align:right\">%,12d</td>%n", tn, tp, fn, fp)); 234 sb.append("<td colspan=\"4\"></td>"); 235 sb.append("</tr>\n<tr>"); 236 sb.append(String.format("<td>Accuracy</td><td style=\"text-align:right\" colspan=\"6\">%8.3f</td>%n", evaluation.accuracy())); 237 sb.append("<td colspan=\"4\"></td>"); 238 sb.append("</tr>\n<tr>"); 239 sb.append("<td>Micro Average</td>"); 240 sb.append(String.format("<td style=\"text-align:right\" colspan=\"6\">%8.3f</td><td style=\"text-align:right\">%8.3f</td><td style=\"text-align:right\">%8.3f</td>%n", 241 evaluation.microAveragedRecall(), 242 evaluation.microAveragedPrecision(), 243 evaluation.microAveragedF1())); 244 sb.append("</tr></table>"); 245 return sb.toString(); 246 } 247 248}