001/* 002 * Copyright (c) 2015-2022, Oracle and/or its affiliates. All rights reserved. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.tribuo.classification.evaluation; 018 019import org.tribuo.ImmutableOutputInfo; 020import org.tribuo.classification.Classifiable; 021 022import java.util.ArrayList; 023import java.util.Collections; 024import java.util.List; 025import java.util.Set; 026import java.util.function.ToDoubleFunction; 027 028/** 029 * A confusion matrix for {@link Classifiable}s. 030 * 031 * <p> 032 * We interpret it as follows: 033 * 034 * {@code 035 * C[i, j] = k 036 * } 037 * 038 * means "the TRUE class 'j' was PREDICTED to be class 'i' a total of 'k' times". 039 * 040 * <p> 041 * In other words, the row indices correspond to the model's predictions, and the column indices correspond to 042 * the ground truth. 043 * </p> 044 * @param <T> The type of the output. 045 */ 046public interface ConfusionMatrix<T extends Classifiable<T>> { 047 048 /** 049 * Returns the classification domain that this confusion matrix operates over. 050 * @return The classification domain. 051 */ 052 public ImmutableOutputInfo<T> getDomain(); 053 054 /** 055 * The number of examples this confusion matrix has seen. 056 * @return The number of examples. 057 */ 058 public double support(); 059 060 /** 061 * The number of examples with this true label this confusion matrix has seen. 062 * @param cls The label. 063 * @return The number of examples. 064 */ 065 public double support(T cls); 066 067 /** 068 * The number of true positives for the supplied label. 069 * @param cls The label. 070 * @return The number of examples. 071 */ 072 public double tp(T cls); 073 074 /** 075 * The number of false positives for the supplied label. 076 * @param cls The label. 077 * @return The number of examples. 078 */ 079 public double fp(T cls); 080 081 /** 082 * The number of false negatives for the supplied label. 083 * @param cls The label. 084 * @return The number of examples. 085 */ 086 public double fn(T cls); 087 088 /** 089 * The number of true negatives for the supplied label. 090 * @param cls The label. 091 * @return The number of examples. 092 */ 093 public double tn(T cls); 094 095 /** 096 * The number of times the supplied predicted label was returned for the supplied true class. 097 * @param predictedLabel The predicted label. 098 * @param trueLabel The true label. 099 * @return The number of examples predicted as {@code predictedLabel} when the true label was {@code trueLabel}. 100 */ 101 public double confusion(T predictedLabel, T trueLabel); 102 103 /** 104 * The total number of true positives. 105 * @return The total true positives. 106 */ 107 public default double tp() { 108 return sumOverOutputs(getDomain(), this::tp); 109 } 110 111 /** 112 * The total number of false positives. 113 * @return The total false positives. 114 */ 115 public default double fp() { 116 return sumOverOutputs(getDomain(), this::fp); 117 } 118 119 /** 120 * The total number of false negatives. 121 * @return The total false negatives. 122 */ 123 public default double fn() { 124 return sumOverOutputs(getDomain(), this::fn); 125 } 126 127 /** 128 * The total number of true negatives. 129 * @return The total true negatives. 130 */ 131 public default double tn() { 132 return sumOverOutputs(getDomain(), this::tn); 133 } 134 135 /** 136 * The values this confusion matrix has seen. 137 * <p> 138 * The default implementation is provided for compatibility reasons and will be removed 139 * in a future major release. It defaults to returning the output domain. 140 * @return The set of observed outputs. 141 */ 142 default public Set<T> observed() { 143 return getDomain().getDomain(); 144 } 145 146 /** 147 * The label order this confusion matrix uses in {@code toString}. 148 * <p> 149 * The default implementation is provided for compatibility reasons and will be removed 150 * in a future major release. It defaults to the output domain iterated in hash order. 151 * @return An unmodifiable view on the label order. 152 */ 153 public default List<T> getLabelOrder() { 154 return Collections.unmodifiableList(new ArrayList<>(getDomain().getDomain())); 155 } 156 157 /** 158 * Sets the label order this confusion matrix uses in {@code toString}. 159 * <p> 160 * If the label order is a subset of the labels in the domain, only the 161 * labels present in the label order will be displayed. 162 * <p> 163 * The default implementation does not set the label order and is provided for 164 * backwards compatibility reasons. It should be overridden in all subclasses to 165 * ensure correct behaviour, and this default implementation will be removed in a 166 * future major release. 167 * @param labelOrder The label order. 168 */ 169 public default void setLabelOrder(List<T> labelOrder) {} 170 171 /** 172 * Sums the supplied getter over the domain. 173 * @param domain The domain to sum over. 174 * @param getter The getter to use. 175 * @param <T> The type of the output. 176 * @return The total summed over the domain. 177 */ 178 static <T extends Classifiable<T>> double sumOverOutputs(ImmutableOutputInfo<T> domain, ToDoubleFunction<T> getter) { 179 double total = 0; 180 for (T key : domain.getDomain()) { 181 total += getter.applyAsDouble(key); 182 } 183 return total; 184 } 185 186}