001/*
002 * Copyright (c) 2015-2022, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.classification.evaluation;
018
019import org.tribuo.ImmutableOutputInfo;
020import org.tribuo.Model;
021import org.tribuo.Prediction;
022import org.tribuo.classification.Label;
023import org.tribuo.math.la.DenseMatrix;
024
025import java.util.ArrayList;
026import java.util.Collections;
027import java.util.HashMap;
028import java.util.HashSet;
029import java.util.LinkedHashSet;
030import java.util.List;
031import java.util.Map;
032import java.util.Set;
033import java.util.function.ToDoubleFunction;
034import java.util.logging.Logger;
035
036/**
037 * A confusion matrix for {@link Label}s.
038 * <p>
039 * We interpret it as follows:
040 *
041 * {@code
042 * C[i, j] = k
043 * }
044 *
045 * means "the TRUE class 'j' was PREDICTED to be class 'i' a total of 'k' times".
046 *
047 * <p>
048 * In other words, the row indices correspond to the model's predictions, and the column indices correspond to
049 * the ground truth.
050 * </p>
051 */
052public final class LabelConfusionMatrix implements ConfusionMatrix<Label> {
053
054    private static final Logger logger = Logger.getLogger(LabelConfusionMatrix.class.getName());
055
056    private final ImmutableOutputInfo<Label> domain;
057
058    private final int total;
059    private final Map<Label, Double> occurrences;
060
061    private final Set<Label> observed;
062
063    private final DenseMatrix cm;
064
065    private List<Label> labelOrder;
066
067    /**
068     * Creates a confusion matrix from the supplied predictions, using the label info
069     * from the supplied model.
070     *
071     * @param model       The model to use for the label information.
072     * @param predictions The predictions.
073     */
074    public LabelConfusionMatrix(Model<Label> model, List<Prediction<Label>> predictions) {
075        this(model.getOutputIDInfo(), predictions);
076    }
077
078    /**
079     * Creates a confusion matrix from the supplied predictions and label info.
080     *
081     * @param domain      The label information.
082     * @param predictions The predictions.
083     * @throws IllegalArgumentException If the domain doesn't contain all the predictions.
084     */
085    public LabelConfusionMatrix(ImmutableOutputInfo<Label> domain, List<Prediction<Label>> predictions) {
086        this.domain = domain;
087        this.total = predictions.size();
088        this.cm = new DenseMatrix(domain.size(), domain.size());
089        this.occurrences = new HashMap<>();
090        this.observed = new HashSet<>();
091        this.labelOrder = Collections.unmodifiableList(new ArrayList<>(domain.getDomain()));
092        tabulate(predictions);
093    }
094
095    /**
096     * Aggregate the predictions into this confusion matrix.
097     *
098     * @param predictions The predictions to aggregate.
099     */
100    private void tabulate(List<Prediction<Label>> predictions) {
101        predictions.forEach(prediction -> {
102            Label y = prediction.getExample().getOutput();
103            Label p = prediction.getOutput();
104            //
105            // Check that the ground truth label is valid
106            if (y.getLabel().equals(Label.UNKNOWN)) {
107                throw new IllegalArgumentException("Prediction with unknown ground truth. Unable to evaluate.");
108            }
109            occurrences.merge(y, 1d, Double::sum);
110            observed.add(y);
111            observed.add(p);
112            int iy = getIDOrThrow(y);
113            int ip = getIDOrThrow(p);
114            cm.add(ip, iy, 1d);
115        });
116    }
117
118    @Override
119    public ImmutableOutputInfo<Label> getDomain() {
120        return domain;
121    }
122
123    @Override
124    public Set<Label> observed() {
125        return Collections.unmodifiableSet(observed);
126    }
127
128    @Override
129    public double support() {
130        return total;
131    }
132
133    @Override
134    public double support(Label label) {
135        return occurrences.getOrDefault(label, 0d);
136    }
137
138    @Override
139    public double tp(Label cls) {
140        return compute(cls, (i) -> cm.get(i, i));
141    }
142
143    @Override
144    public double fp(Label cls) {
145        // Row-wise sum less true positives
146        return compute(cls, i -> cm.rowSum(i) - cm.get(i, i));
147    }
148
149    @Override
150    public double fn(Label cls) {
151        // Column-wise sum less true positives
152        return compute(cls, i -> cm.columnSum(i) - cm.get(i, i));
153    }
154
155    @Override
156    public double tn(Label cls) {
157        int n = getDomain().size();
158        int i = getDomain().getID(cls);
159        double total = 0d;
160        for (int j = 0; j < n; j++) {
161            if (j == i) {
162                continue;
163            }
164            for (int k = 0; k < n; k++) {
165                if (k == i) {
166                    continue;
167                }
168                total += cm.get(j, k);
169            }
170        }
171        return total;
172    }
173
174    @Override
175    public double confusion(Label predicted, Label trueClass) {
176        int i = getDomain().getID(predicted);
177        int j = getDomain().getID(trueClass);
178        return cm.get(i, j);
179    }
180
181    /**
182     * A convenience method for extracting the appropriate label statistic.
183     *
184     * @param cls    The label to check.
185     * @param getter The get function which accepts a label id.
186     * @return The statistic for that label id.
187     */
188    private double compute(Label cls, ToDoubleFunction<Integer> getter) {
189        int i = getDomain().getID(cls);
190        if (i < 0) {
191            logger.fine("Unknown Label " + cls);
192            return 0d;
193        }
194        return getter.applyAsDouble(i);
195    }
196
197    /**
198     * Gets the id for the supplied label, or throws an {@link IllegalArgumentException} if it's
199     * an unknown label.
200     *
201     * @param key The label.
202     * @return The int id for that label.
203     */
204    private int getIDOrThrow(Label key) {
205        int id = domain.getID(key);
206        if (id < 0) {
207            throw new IllegalArgumentException("Unknown label: " + key);
208        }
209        return id;
210    }
211
212    /**
213     * Sets the label order used in {@link #toString}.
214     * <p>
215     * If the label order is a subset of the labels in the domain, only the
216     * labels present in the label order will be displayed.
217     *
218     * @param newLabelOrder The label order to use.
219     */
220    @Override
221    public void setLabelOrder(List<Label> newLabelOrder) {
222        if (newLabelOrder == null || newLabelOrder.isEmpty()) {
223            throw new IllegalArgumentException("Label order must be non-null and non-empty.");
224        }
225        this.labelOrder = Collections.unmodifiableList(new ArrayList<>(newLabelOrder));
226    }
227
228    /**
229     * Gets the current label order.
230     *
231     * May trigger order instantiation if the label order has not been set.
232     * @return The label order.
233     */
234    public List<Label> getLabelOrder() {
235        return labelOrder;
236    }
237
238    @Override
239    public String toString() {
240        List<Label> curOrder = new ArrayList<>(labelOrder);
241        curOrder.retainAll(observed);
242
243        int maxLen = Integer.MIN_VALUE;
244        for (Label label : curOrder) {
245            maxLen = Math.max(label.getLabel().length(), maxLen);
246            maxLen = Math.max(String.format(" %,d", (int)(double)occurrences.getOrDefault(label,0.0)).length(), maxLen);
247        }
248
249        StringBuilder sb = new StringBuilder();
250        String trueLabelFormat = String.format("%%-%ds", maxLen + 2);
251        String predictedLabelFormat = String.format("%%%ds", maxLen + 2);
252        String countFormat = String.format("%%,%dd", maxLen + 2);
253
254        //
255        // Empty spot in first row for labels on subsequent rows.
256        sb.append(String.format(trueLabelFormat, ""));
257
258        //
259        // Labels across the top for predicted.
260        for (Label predictedLabel : curOrder) {
261            sb.append(String.format(predictedLabelFormat, predictedLabel.getLabel()));
262        }
263        sb.append('\n');
264
265        for (Label trueLabel : curOrder) {
266            sb.append(String.format(trueLabelFormat, trueLabel.getLabel()));
267            for (Label predictedLabel : curOrder) {
268                int confusion = (int) confusion(predictedLabel, trueLabel);
269                sb.append(String.format(countFormat, confusion));
270            }
271            sb.append('\n');
272        }
273        return sb.toString();
274    }
275
276    /**
277     * Emits a HTML table representation of the Confusion Matrix.
278     * @return The confusion matrix as a HTML table.
279     */
280    public String toHTML() {
281        Set<Label> labelsToPrint = new LinkedHashSet<>(labelOrder);
282        labelsToPrint.retainAll(observed);
283        StringBuilder sb = new StringBuilder();
284        sb.append("<table>\n");
285        sb.append(String.format("<tr><th>True Label</th><th style=\"text-align:center\" colspan=\"%d\">Predicted Labels</th></tr>%n", occurrences.size() + 1));
286        sb.append("<tr><th></th>");
287        for (Label predictedLabel : labelsToPrint) {
288            sb.append("<th style=\"text-align:right\">")
289                    .append(predictedLabel)
290                    .append("</th>");
291        }
292        sb.append("<th style=\"text-align:right\">Total</th>");
293        sb.append("</tr>\n");
294        for (Label trueLabel : labelsToPrint) {
295            sb.append("<tr><th>").append(trueLabel).append("</th>");
296            double count = occurrences.getOrDefault(trueLabel, 0d);
297            for (Label predictedLabel : labelsToPrint) {
298                double tlmc = confusion(predictedLabel,trueLabel);
299                double percent = (tlmc / count) * 100;
300                sb.append("<td style=\"text-align:right\">")
301                        .append(String.format("%,d (%.1f%%)", (int)tlmc, percent))
302                        .append("</td>");
303            }
304            sb.append("<td style=\"text-align:right\">").append(count).append("</td>");
305            sb.append("</tr>\n");
306        }
307        sb.append("</table>");
308        return sb.toString();
309    }
310}