001/*
002 * Copyright (c) 2015-2022, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.classification.evaluation;
018
019import org.tribuo.classification.Label;
020import org.tribuo.evaluation.EvaluationRenderer;
021
022import java.util.ArrayList;
023import java.util.List;
024
025/**
026 * Adds multi-class classification specific metrics to {@link ClassifierEvaluation}.
027 */
028public interface LabelEvaluation extends ClassifierEvaluation<Label> {
029
030    /**
031     * The overall accuracy of the evaluation.
032     * @return The accuracy.
033     */
034    double accuracy();
035
036    /**
037     * The per label accuracy of the evaluation.
038     * @param label The target label.
039     * @return The per label accuracy.
040     */
041    double accuracy(Label label);
042
043    /**
044     * Area under the ROC curve.
045     *
046     * @param label target label
047     * @return AUC ROC score
048     *
049     * @implSpec Implementations of this class are expected to throw {@link UnsupportedOperationException} if the model
050     * corresponding to this evaluation does not generate probabilities, which are required to compute the ROC curve.
051     */
052    double AUCROC(Label label);
053
054    /**
055     * Area under the ROC curve averaged across labels.
056     * <p>
057     * If {@code weighted} is false, use a macro average, if true, weight by the evaluation's observed class counts.
058     * </p>
059     *
060     * @param weighted If true weight by the class counts, if false use a macro average.
061     * @return The average AUCROC.
062     *
063     * @implSpec Implementations of this class are expected to throw {@link UnsupportedOperationException} if the model
064     * corresponding to this evaluation does not generate probabilities, which are required to compute the ROC curve.
065     */
066    double averageAUCROC(boolean weighted);
067
068    /**
069     * Summarises a Precision-Recall Curve by taking the weighted mean of the
070     * precisions at a given threshold, where the weight is the recall achieved at
071     * that threshold.
072     *
073     * @see LabelEvaluationUtil#averagedPrecision(boolean[], double[])
074     *
075     * @param label The target label.
076     * @return The averaged precision for that label.
077     *
078     * @implSpec Implementations of this class are expected to throw {@link UnsupportedOperationException} if the model
079     * corresponding to this evaluation does not generate probabilities, which are required to compute the ROC curve.
080     */
081    double averagedPrecision(Label label);
082
083    /**
084     * Calculates the Precision Recall curve for a single label.
085     *
086     * @see LabelEvaluationUtil#generatePRCurve(boolean[], double[])
087     *
088     * @param label The target label.
089     * @return The precision recall curve for that label.
090     *
091     * @implSpec Implementations of this class are expected to throw {@link UnsupportedOperationException} if the model
092     * corresponding to this evaluation does not generate probabilities, which are required to compute the ROC curve.
093     */
094    LabelEvaluationUtil.PRCurve precisionRecallCurve(Label label);
095
096    /**
097     * Returns a HTML formatted String representing this evaluation.
098     * <p>
099     * Uses the label order of the confusion matrix, which can be used to display
100     * a subset of the per label metrics. When they are subset the total row
101     * represents only the subset selected, not all the predictions, however
102     * the accuracy and averaged metrics cover all the predictions.
103     * @return A HTML formatted String.
104     */
105    default String toHTML() {
106        return LabelEvaluation.toHTML(this);
107    }
108
109    /**
110     * This method produces a nicely formatted String output, with
111     * appropriate tabs and newlines, suitable for display on a terminal.
112     * It can be used as an implementation of the {@link EvaluationRenderer}
113     * functional interface.
114     * <p>
115     * Uses the label order of the confusion matrix, which can be used to display
116     * a subset of the per label metrics. When they are subset the total row
117     * represents only the subset selected, not all the predictions, however
118     * the accuracy and averaged metrics cover all the predictions.
119     * @param evaluation The evaluation to format.
120     * @return Formatted output showing the main results of the evaluation.
121     */
122    public static String toFormattedString(LabelEvaluation evaluation) {
123        ConfusionMatrix<Label> cm = evaluation.getConfusionMatrix();
124        List<Label> labelOrder = new ArrayList<>(cm.getLabelOrder());
125        labelOrder.retainAll(cm.observed());
126        StringBuilder sb = new StringBuilder();
127        int tp = 0;
128        int fn = 0;
129        int fp = 0;
130        int n = 0;
131        //
132        // Figure out the biggest class label and therefore the format string
133        // that we should use for them.
134        int maxLabelSize = "Balanced Error Rate".length();
135        for(Label label : labelOrder) {
136            maxLabelSize = Math.max(maxLabelSize, label.getLabel().length());
137        }
138        String labelFormatString = String.format("%%-%ds", maxLabelSize+2);
139        sb.append(String.format(labelFormatString, "Class"));
140        sb.append(String.format("%12s%12s%12s%12s", "n", "tp", "fn", "fp"));
141        sb.append(String.format("%12s%12s%12s%n", "recall", "prec", "f1"));
142        for (Label label : labelOrder) {
143            if (cm.support(label) == 0) {
144                continue;
145            }
146            n += cm.support(label);
147            tp += cm.tp(label);
148            fn += cm.fn(label);
149            fp += cm.fp(label);
150            sb.append(String.format(labelFormatString, label));
151            sb.append(String.format("%,12d%,12d%,12d%,12d",
152                    (int) cm.support(label),
153                    (int) cm.tp(label),
154                    (int) cm.fn(label),
155                    (int) cm.fp(label)
156            ));
157            sb.append(String.format("%12.3f%12.3f%12.3f%n",
158                    evaluation.recall(label),
159                    evaluation.precision(label),
160                    evaluation.f1(label)));
161        }
162        sb.append(String.format(labelFormatString, "Total"));
163        sb.append(String.format("%,12d%,12d%,12d%,12d%n", n, tp, fn, fp));
164        sb.append(String.format(labelFormatString, "Accuracy"));
165        sb.append(String.format("%60.3f%n", evaluation.accuracy()));
166        sb.append(String.format(labelFormatString, "Micro Average"));
167        sb.append(String.format("%60.3f%12.3f%12.3f%n",
168                evaluation.microAveragedRecall(),
169                evaluation.microAveragedPrecision(),
170                evaluation.microAveragedF1()));
171        sb.append(String.format(labelFormatString, "Macro Average"));
172        sb.append(String.format("%60.3f%12.3f%12.3f%n",
173                evaluation.macroAveragedRecall(),
174                evaluation.macroAveragedPrecision(),
175                evaluation.macroAveragedF1()));
176        sb.append(String.format(labelFormatString, "Balanced Error Rate"));
177        sb.append(String.format("%60.3f", evaluation.balancedErrorRate()));
178        return sb.toString();
179    }
180
181    /**
182     * This method produces a HTML formatted String output, with
183     * appropriate tabs and newlines, suitable for integration into a webpage.
184     * It can be used as an implementation of the {@link EvaluationRenderer}
185     * functional interface.
186     * <p>
187     * Uses the label order of the confusion matrix, which can be used to display
188     * a subset of the per label metrics. When they are subset the total row
189     * represents only the subset selected, not all the predictions, however
190     * the accuracy and averaged metrics cover all the predictions.
191     * @param evaluation The evaluation to format.
192     * @return Formatted HTML output showing the main results of the evaluation.
193     */
194    public static String toHTML(LabelEvaluation evaluation) {
195        ConfusionMatrix<Label> cm = evaluation.getConfusionMatrix();
196        List<Label> labelOrder = cm.getLabelOrder();
197        StringBuilder sb = new StringBuilder();
198        int tp = 0;
199        int fn = 0;
200        int fp = 0;
201        int tn = 0;
202        sb.append("<table>\n");
203        sb.append("<tr>\n");
204        sb.append("<th>Class</th><th>n</th> <th>%</th> <th>tp</th> <th>fn</th> <th>fp</th> <th>Recall</th> <th>Precision</th> <th>F1</th>");
205        sb.append("\n</tr>\n");
206        //
207        // Compute the total number of instances first, so we can show proportions.
208        for (Label label : labelOrder) {
209            //tn += occurrences.getOrDefault(label, 0);
210            tn += cm.tn(label);
211        }
212        for (Label label : labelOrder) {
213            if (cm.support(label) == 0) {
214                continue;
215            }
216            tp += cm.tp(label);
217            fn += cm.fn(label);
218            fp += cm.fp(label);
219            sb.append("<tr>");
220            sb.append("<td><code>").append(label).append("</code></td>");
221            int occurrence = (int) cm.support(label);
222            sb.append("<td style=\"text-align:right\">").append(String.format("%,d", occurrence)).append("</td>");
223            sb.append("<td style=\"text-align:right\">").append(String.format("%8.1f%%", (occurrence/ (double) tn)*100)).append("</td>");
224            sb.append("<td style=\"text-align:right\">").append(String.format("%,d", (int) cm.tp(label))).append("</td>");
225            sb.append("<td style=\"text-align:right\">").append(String.format("%,d", (int) cm.fn(label))).append("</td>");
226            sb.append("<td style=\"text-align:right\">").append(String.format("%,d", (int) cm.fp(label))).append("</td>");
227            sb.append(String.format("<td style=\"text-align:right\">%8.3f</td><td style=\"text-align:right\">%8.3f</td><td style=\"text-align:right\">%8.3f</td>%n",
228                    evaluation.recall(label), evaluation.precision(label), evaluation.f1(label)));
229            sb.append("</tr>");
230        }
231        sb.append("<tr>");
232        sb.append("<td>Total</td>");
233        sb.append(String.format("<td style=\"text-align:right\">%,12d</td><td style=\"text-align:right\"></td><td style=\"text-align:right\">%,12d</td><td style=\"text-align:right\">%,12d</td><td style=\"text-align:right\">%,12d</td>%n", tn, tp, fn, fp));
234        sb.append("<td colspan=\"4\"></td>");
235        sb.append("</tr>\n<tr>");
236        sb.append(String.format("<td>Accuracy</td><td style=\"text-align:right\" colspan=\"6\">%8.3f</td>%n", evaluation.accuracy()));
237        sb.append("<td colspan=\"4\"></td>");
238        sb.append("</tr>\n<tr>");
239        sb.append("<td>Micro Average</td>");
240        sb.append(String.format("<td style=\"text-align:right\" colspan=\"6\">%8.3f</td><td style=\"text-align:right\">%8.3f</td><td style=\"text-align:right\">%8.3f</td>%n",
241                evaluation.microAveragedRecall(),
242                evaluation.microAveragedPrecision(),
243                evaluation.microAveragedF1()));
244        sb.append("</tr></table>");
245        return sb.toString();
246    }
247
248}