001/*
002 * Copyright (c) 2015-2022, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.classification.evaluation;
018
019import org.tribuo.ImmutableOutputInfo;
020import org.tribuo.classification.Classifiable;
021
022import java.util.ArrayList;
023import java.util.Collections;
024import java.util.List;
025import java.util.Set;
026import java.util.function.ToDoubleFunction;
027
028/**
029 * A confusion matrix for {@link Classifiable}s.
030 *
031 * <p>
032 * We interpret it as follows:
033 *
034 * {@code
035 * C[i, j] = k
036 * }
037 *
038 * means "the TRUE class 'j' was PREDICTED to be class 'i' a total of 'k' times".
039 *
040 * <p>
041 * In other words, the row indices correspond to the model's predictions, and the column indices correspond to
042 * the ground truth.
043 * </p>
044 * @param <T> The type of the output.
045 */
046public interface ConfusionMatrix<T extends Classifiable<T>> {
047
048    /**
049     * Returns the classification domain that this confusion matrix operates over.
050     * @return The classification domain.
051     */
052    public ImmutableOutputInfo<T> getDomain();
053
054    /**
055     * The number of examples this confusion matrix has seen.
056     * @return The number of examples.
057     */
058    public double support();
059
060    /**
061     * The number of examples with this true label this confusion matrix has seen.
062     * @param cls The label.
063     * @return The number of examples.
064     */
065    public double support(T cls);
066
067    /**
068     * The number of true positives for the supplied label.
069     * @param cls The label.
070     * @return The number of examples.
071     */
072    public double tp(T cls);
073
074    /**
075     * The number of false positives for the supplied label.
076     * @param cls The label.
077     * @return The number of examples.
078     */
079    public double fp(T cls);
080
081    /**
082     * The number of false negatives for the supplied label.
083     * @param cls The label.
084     * @return The number of examples.
085     */
086    public double fn(T cls);
087
088    /**
089     * The number of true negatives for the supplied label.
090     * @param cls The label.
091     * @return The number of examples.
092     */
093    public double tn(T cls);
094
095    /**
096     * The number of times the supplied predicted label was returned for the supplied true class.
097     * @param predictedLabel The predicted label.
098     * @param trueLabel The true label.
099     * @return The number of examples predicted as {@code predictedLabel} when the true label was {@code trueLabel}.
100     */
101    public double confusion(T predictedLabel, T trueLabel);
102
103    /**
104     * The total number of true positives.
105     * @return The total true positives.
106     */
107    public default double tp() {
108        return sumOverOutputs(getDomain(), this::tp);
109    }
110
111    /**
112     * The total number of false positives.
113     * @return The total false positives.
114     */
115    public default double fp() {
116        return sumOverOutputs(getDomain(), this::fp);
117    }
118
119    /**
120     * The total number of false negatives.
121     * @return The total false negatives.
122     */
123    public default double fn() {
124        return sumOverOutputs(getDomain(), this::fn);
125    }
126
127    /**
128     * The total number of true negatives.
129     * @return The total true negatives.
130     */
131    public default double tn() {
132        return sumOverOutputs(getDomain(), this::tn);
133    }
134
135    /**
136     * The values this confusion matrix has seen.
137     * <p>
138     * The default implementation is provided for compatibility reasons and will be removed
139     * in a future major release. It defaults to returning the output domain.
140     * @return The set of observed outputs.
141     */
142    default public Set<T> observed() {
143        return getDomain().getDomain();
144    }
145
146    /**
147     * The label order this confusion matrix uses in {@code toString}.
148     * <p>
149     * The default implementation is provided for compatibility reasons and will be removed
150     * in a future major release. It defaults to the output domain iterated in hash order.
151     * @return An unmodifiable view on the label order.
152     */
153    public default List<T> getLabelOrder() {
154        return Collections.unmodifiableList(new ArrayList<>(getDomain().getDomain()));
155    }
156
157    /**
158     * Sets the label order this confusion matrix uses in {@code toString}.
159     * <p>
160     * If the label order is a subset of the labels in the domain, only the
161     * labels present in the label order will be displayed.
162     * <p>
163     * The default implementation does not set the label order and is provided for
164     * backwards compatibility reasons. It should be overridden in all subclasses to
165     * ensure correct behaviour, and this default implementation will be removed in a
166     * future major release.
167     * @param labelOrder The label order.
168     */
169    public default void setLabelOrder(List<T> labelOrder) {}
170
171    /**
172     * Sums the supplied getter over the domain.
173     * @param domain The domain to sum over.
174     * @param getter The getter to use.
175     * @param <T> The type of the output.
176     * @return The total summed over the domain.
177     */
178    static <T extends Classifiable<T>> double sumOverOutputs(ImmutableOutputInfo<T> domain, ToDoubleFunction<T> getter) {
179        double total = 0;
180        for (T key : domain.getDomain()) {
181            total += getter.applyAsDouble(key);
182        }
183        return total;
184    }
185
186}