package com.github.chen0040.svmext.evaluators;

import com.github.chen0040.svmext.utils.NumberUtils;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;

/* loaded from: input_file:com/github/chen0040/svmext/evaluators/ClassifierEvaluator.class */
/**
 * Accumulates classifier outcomes in a {@link ConfusionMatrix} and derives common
 * classification metrics from it: accuracy, per-class precision / recall / fallout / F1,
 * and averaged F1 scores.
 *
 * <p>The metric methods read cells via {@code getCount(row, column)}. Precision uses column
 * sums and recall uses row sums, so rows are presumably the actual labels and columns the
 * predicted labels. NOTE(review): confirm this orientation against {@code ConfusionMatrix}.
 */
public class ClassifierEvaluator {
    private ConfusionMatrix confusionMatrix = new ConfusionMatrix();

    /**
     * Records one observation by incrementing the confusion-matrix cell for the label pair.
     *
     * @param actual label passed as the first argument of {@code ConfusionMatrix.incCount}
     *     (presumably the true label — confirm)
     * @param predicted label passed as the second argument of {@code ConfusionMatrix.incCount}
     *     (presumably the predicted label — confirm)
     */
    public void evaluate(String actual, String predicted) {
        this.confusionMatrix.incCount(actual, predicted);
    }

    /** Returns the labels currently known to the confusion matrix. */
    public List<String> classLabels() {
        return this.confusionMatrix.getLabels();
    }

    /** Clears the underlying confusion matrix. */
    public void reset() {
        this.confusionMatrix.reset();
    }

    /** Returns the underlying (mutable) confusion matrix. */
    public ConfusionMatrix getConfusionMatrix() {
        return this.confusionMatrix;
    }

    /** Replaces the underlying confusion matrix. */
    public void setConfusionMatrix(ConfusionMatrix confusionMatrix) {
        this.confusionMatrix = confusionMatrix;
    }

    /**
     * Returns the fraction of observations on the matrix diagonal (row label == column label).
     *
     * @return accuracy in [0, 1], or 0 when no observations have been recorded
     */
    public double getAccuracy() {
        List<String> labels = this.confusionMatrix.getLabels();
        long correct = 0;
        long total = 0;
        for (int row = 0; row < labels.size(); row++) {
            String rowLabel = labels.get(row);
            for (int col = 0; col < labels.size(); col++) {
                int count = this.confusionMatrix.getCount(rowLabel, labels.get(col));
                if (row == col) {
                    correct += count;
                }
                total += count;
            }
        }
        return ratio(correct, total);
    }

    /**
     * Returns {@code 1 - getAccuracy()}. Note that this is 1 when nothing has been recorded,
     * because accuracy is reported as 0 in that case.
     */
    public double getMisclassificationRate() {
        return 1.0d - getAccuracy();
    }

    /** Returns the diagonal cell count for {@code label}. */
    public int getTruePositiveCount(String label) {
        return this.confusionMatrix.getCount(label, label);
    }

    /** Returns the column sum for {@code label} minus its diagonal cell. */
    public int getFalsePositiveCount(String label) {
        return this.confusionMatrix.getColumnSum(label) - getTruePositiveCount(label);
    }

    /** Returns the mean true-positive count across all labels, or 0 when there are none. */
    public double avgTruePositive() {
        List<String> labels = classLabels();
        long sum = 0;
        for (String label : labels) {
            sum += getTruePositiveCount(label);
        }
        return ratio(sum, labels.size());
    }

    /** Returns the mean false-positive count across all labels, or 0 when there are none. */
    public double avgFalsePositive() {
        List<String> labels = classLabels();
        long sum = 0;
        for (String label : labels) {
            sum += getFalsePositiveCount(label);
        }
        return ratio(sum, labels.size());
    }

    /**
     * Returns, per label, diagonal count / column sum. A label whose column sum is 0 maps to 0.
     */
    public Map<String, Double> getPrecisionByClass() {
        Map<String, Double> result = new HashMap<>();
        for (String label : classLabels()) {
            int truePositives = this.confusionMatrix.getCount(label, label);
            int columnSum = this.confusionMatrix.getColumnSum(label);
            result.put(label, ratio(truePositives, columnSum));
        }
        return result;
    }

    /**
     * Returns, per label, diagonal count / row sum. A label whose row sum is 0 maps to 0.
     */
    public Map<String, Double> getRecallByClass() {
        Map<String, Double> result = new HashMap<>();
        for (String label : classLabels()) {
            int truePositives = this.confusionMatrix.getCount(label, label);
            int rowSum = this.confusionMatrix.getRowSum(label);
            result.put(label, ratio(truePositives, rowSum));
        }
        return result;
    }

    /**
     * Returns, per label L, the fallout (false-positive rate). The numerator is the total of
     * cells in column L over every other row. The denominator is the total of those other
     * rows' sums. A label with an empty denominator maps to 0.
     */
    public Map<String, Double> getFalloutByClass() {
        Map<String, Double> result = new HashMap<>();
        List<String> labels = classLabels();
        for (int i = 0; i < labels.size(); i++) {
            String label = labels.get(i);
            long falsePositives = 0;
            long negatives = 0;
            for (int j = 0; j < labels.size(); j++) {
                if (i == j) {
                    continue;
                }
                String other = labels.get(j);
                falsePositives += this.confusionMatrix.getCount(other, label);
                negatives += this.confusionMatrix.getRowSum(other);
            }
            result.put(label, ratio(falsePositives, negatives));
        }
        return result;
    }

    /**
     * Returns, per label, the harmonic mean of its precision and recall.
     *
     * <p>A label whose precision + recall is zero (per {@code NumberUtils.isZero}) maps to 0
     * rather than being omitted. This way every label is present, and macro averaging counts
     * such labels as 0.
     */
    public Map<String, Double> getF1ScoreByClass() {
        Map<String, Double> precisionByClass = getPrecisionByClass();
        Map<String, Double> recallByClass = getRecallByClass();
        Map<String, Double> result = new HashMap<>();
        for (String label : classLabels()) {
            double precision = precisionByClass.get(label).doubleValue();
            double recall = recallByClass.get(label).doubleValue();
            double sum = precision + recall;
            double f1 = NumberUtils.isZero(Double.valueOf(sum)) ? 0.0d : 2.0d * precision * recall / sum;
            result.put(label, Double.valueOf(f1));
        }
        return result;
    }

    /** Returns the unweighted mean of the per-label F1 scores, or 0 when there are no labels. */
    public double getMacroF1Score() {
        Map<String, Double> f1ByClass = getF1ScoreByClass();
        if (f1ByClass.isEmpty()) {
            return 0.0d;
        }
        double sum = 0.0d;
        for (double f1 : f1ByClass.values()) {
            sum += f1;
        }
        return sum / f1ByClass.size();
    }

    /**
     * Returns the harmonic mean of the label-averaged precision and the label-averaged recall.
     *
     * <p>NOTE(review): this is not the pooled-count "micro F1" used by most libraries. For
     * single-label multi-class data, that pooled metric equals accuracy. The existing formula
     * is kept for compatibility.
     *
     * @return the score, or 0 when there are no labels or averaged precision + recall is zero
     */
    public double getMicroF1Score() {
        List<String> labels = classLabels();
        if (labels.isEmpty()) {
            return 0.0d;
        }
        Map<String, Double> precisionByClass = getPrecisionByClass();
        Map<String, Double> recallByClass = getRecallByClass();
        double precisionSum = 0.0d;
        double recallSum = 0.0d;
        for (String label : labels) {
            precisionSum += precisionByClass.get(label).doubleValue();
            recallSum += recallByClass.get(label).doubleValue();
        }
        double avgPrecision = precisionSum / labels.size();
        double avgRecall = recallSum / labels.size();
        double denominator = avgPrecision + avgRecall;
        if (NumberUtils.isZero(Double.valueOf(denominator))) {
            return 0.0d;
        }
        return 2.0d * avgPrecision * avgRecall / denominator;
    }

    /** Returns a multi-line text summary of accuracy, mis-classification and F1 scores. */
    public String getSummary() {
        StringBuilder sb = new StringBuilder();
        sb.append("accuracy: ").append(getAccuracy());
        sb.append("\nmis-classification: ").append(getMisclassificationRate());
        sb.append("\nmacro f1-score: ").append(getMacroF1Score());
        sb.append("\nmicro f1-score: ").append(getMicroF1Score());
        return sb.toString();
    }

    /** Prints {@link #getSummary()} to standard output. */
    public void report() {
        System.out.println(getSummary());
    }

    /**
     * Returns {@code numerator / denominator} computed in floating point (avoiding integer
     * truncation), or 0 when the denominator is not positive.
     */
    private static double ratio(long numerator, long denominator) {
        return denominator > 0 ? (double) numerator / denominator : 0.0d;
    }
}
