package de.unidue.ltl.evaluation.visualization;

import de.tudarmstadt.ukp.dkpro.core.api.frequency.util.ConditionalFrequencyDistribution;
import de.tudarmstadt.ukp.dkpro.core.api.frequency.util.FrequencyDistribution;
import de.unidue.ltl.evaluation.core.AbstractConfusionMatrix;
import de.unidue.ltl.evaluation.core.EvaluationData;
import de.unidue.ltl.evaluation.core.EvaluationEntry;
import de.vandermeer.asciitable.v2.V2_AsciiTable;
import de.vandermeer.asciitable.v2.render.V2_AsciiTableRenderer;
import de.vandermeer.asciitable.v2.render.WidthLongestWord;
import de.vandermeer.asciitable.v2.themes.V2_E_TableThemes;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Objects;
import java.util.Set;

/* loaded from: input_file:de/unidue/ltl/evaluation/visualization/ConfusionMatrix.class */
/**
 * Confusion matrix counting how often each gold label was predicted as each label.
 *
 * <p>Counts are kept in a {@link ConditionalFrequencyDistribution} whose condition is the gold
 * label and whose sample is the predicted label, so {@code cfd.getCount(gold, predicted)} is the
 * number of evaluation entries with exactly that gold/predicted pair.
 *
 * @param <T> the label type; labels are ordered by their string representation
 */
public class ConfusionMatrix<T> extends AbstractConfusionMatrix<T> {
    /** Every label that occurred as either a gold or a predicted value. */
    Set<T> allLabels = new HashSet<>();

    /** Gold label (condition) -> predicted label (sample) -> count. */
    ConditionalFrequencyDistribution<T, T> cfd = new ConditionalFrequencyDistribution<>();

    /**
     * Builds the matrix by adding one count per evaluation entry.
     *
     * @param evaluationData the gold/predicted pairs to count
     */
    public ConfusionMatrix(EvaluationData<T> evaluationData) {
        for (EvaluationEntry<T> entry : evaluationData) {
            T gold = entry.getGold();
            T predicted = entry.getPredicted();
            cfd.addSample(gold, predicted, 1L);
            allLabels.add(gold);
            allLabels.add(predicted);
        }
    }

    /**
     * Renders the matrix as an ASCII table. Each row is a gold label; each column after the first
     * is a predicted label, in {@link #getLabels()} order.
     *
     * @return the rendered table
     * @throws ArithmeticException if a single cell count does not fit into an {@code int}
     */
    public String toText() {
        List<T> labels = getLabels();
        // Build the matrix from the same label list so row/column headers and cells always agree.
        int[][] matrix = toMatrix(labels);

        V2_AsciiTable table = new V2_AsciiTable();
        table.addStrongRule();
        table.addRow(getTableHeader("Predicted"));
        table.addRow(getLabelHeader(labels));
        table.addStrongRule();
        for (int row = 0; row < labels.size(); row++) {
            Object[] cells = new Object[labels.size() + 1];
            cells[0] = labels.get(row);
            for (int col = 0; col < labels.size(); col++) {
                cells[col + 1] = Integer.valueOf(matrix[row][col]);
            }
            table.addRow(cells);
            table.addRule();
        }
        table.addStrongRule();

        V2_AsciiTableRenderer renderer = new V2_AsciiTableRenderer();
        renderer.setTheme(V2_E_TableThemes.NO_BORDERS.get());
        renderer.setWidth(new WidthLongestWord());
        return renderer.render(table).toString();
    }

    /**
     * Builds the caption row: an empty first cell, {@code null} in the middle cells, and
     * {@code str} in the last cell.
     *
     * <p>NOTE(review): the {@code null} cells presumably make asciitable v2 merge the caption
     * across the label columns — confirm against the library's column-span semantics.
     *
     * @param str the caption text
     * @return a row with one cell per label plus the leading row-label column
     */
    private Object[] getTableHeader(String str) {
        // Java arrays start out filled with null, so only the two ends need to be set.
        Object[] cells = new Object[allLabels.size() + 1];
        cells[0] = "";
        cells[cells.length - 1] = str;
        return cells;
    }

    /**
     * Builds the header row of predicted labels, preceded by an empty cell above the gold-label
     * column.
     *
     * @param list the labels in column order
     * @return the header cells
     */
    private Object[] getLabelHeader(List<T> list) {
        List<Object> cells = new ArrayList<>(list.size() + 1);
        cells.add("");
        cells.addAll(list);
        return cells.toArray();
    }

    /**
     * @return the total number of counted evaluation entries
     */
    public long getNumberOfEntries() {
        return cfd.getN();
    }

    /**
     * @param t the gold label
     * @param t2 the predicted label
     * @return how many entries had gold label {@code t} and were predicted as {@code t2}
     */
    public long getNumberOfConfusions(T t, T t2) {
        return cfd.getCount(t, t2);
    }

    /**
     * @param t the label
     * @return how many entries with gold label {@code t} were also predicted as {@code t}
     */
    public long getTruePositives(T t) {
        return cfd.getCount(t, t);
    }

    /**
     * @param t the label
     * @return how many entries with gold label {@code t} were predicted as some other label;
     *     0 if {@code t} never occurred as a gold label
     */
    public long getFalseNegatives(T t) {
        FrequencyDistribution<T> predictions = cfd.getFrequencyDistribution(t);
        if (predictions == null) {
            return 0L;
        }
        long count = 0;
        for (T predicted : predictions.getKeys()) {
            if (!Objects.equals(predicted, t)) {
                count += predictions.getCount(predicted);
            }
        }
        return count;
    }

    /**
     * @param t the label
     * @return how many entries with a gold label other than {@code t} were predicted as
     *     {@code t}
     */
    public long getFalsePositives(T t) {
        long count = 0;
        for (T gold : cfd.getConditions()) {
            if (!Objects.equals(gold, t)) {
                count += cfd.getCount(gold, t);
            }
        }
        return count;
    }

    /**
     * @param t the label
     * @return how many entries had neither gold label {@code t} nor predicted label {@code t}
     */
    public long getTrueNegatives(T t) {
        long count = 0;
        for (T gold : cfd.getConditions()) {
            if (Objects.equals(gold, t)) {
                continue;
            }
            FrequencyDistribution<T> predictions = cfd.getFrequencyDistribution(gold);
            for (T predicted : predictions.getKeys()) {
                if (!Objects.equals(predicted, t)) {
                    count += predictions.getCount(predicted);
                }
            }
        }
        return count;
    }

    /**
     * Returns every gold or predicted label, sorted by its string representation.
     *
     * @return a new, mutable list of the distinct labels
     */
    public List<T> getLabels() {
        List<T> labels = new ArrayList<>(allLabels);
        // String.valueOf keeps the comparator from throwing on a null label.
        Comparator<T> byText = Comparator.comparing(label -> String.valueOf(label));
        labels.sort(byText);
        return labels;
    }

    /**
     * Returns the counts as a square matrix indexed by {@link #getLabels()} order:
     * {@code [gold][predicted]}.
     *
     * @return the count matrix; rows for labels that never occurred as gold are all zero
     * @throws ArithmeticException if a single cell count does not fit into an {@code int}
     */
    public int[][] getTwoDimensionalArray() {
        return toMatrix(getLabels());
    }

    /** Builds the {@code [gold][predicted]} count matrix for the given label order. */
    private int[][] toMatrix(List<T> labels) {
        int size = labels.size();
        int[][] matrix = new int[size][size];
        for (int row = 0; row < size; row++) {
            FrequencyDistribution<T> predictions = cfd.getFrequencyDistribution(labels.get(row));
            if (predictions == null) {
                // Label only ever appeared as a prediction: its gold row stays all zeros.
                continue;
            }
            for (int col = 0; col < size; col++) {
                // Fail loudly instead of silently wrapping a count that exceeds int range.
                matrix[row][col] = Math.toIntExact(predictions.getCount(labels.get(col)));
            }
        }
        return matrix;
    }
}
