package org.fbk.cit.hlt.core.mylibsvm;

import java.io.File;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.io.LineNumberReader;
import java.io.PrintWriter;
import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import org.apache.log4j.Logger;
import org.apache.log4j.PropertyConfigurator;
import org.apache.xerces.impl.xs.SchemaSymbols;
import org.fbk.cit.hlt.thewikimachine.util.StringTable;

/**
 * One-versus-all (OVA) multi-class classification built from binary SVM runs.
 *
 * <p>For every class label seen in the training data, a binary problem is written. That class
 * gets label {@code 1} and every other class gets {@code 0}. A model is trained with
 * {@code svm_train}, and the test problem is scored with {@code svm_multiclass_predict}. Each
 * test row is then assigned to the class whose model predicted it positive with the largest
 * absolute score, or to {@code 0} when no model predicted it positive.
 *
 * <p>Dataset files hold one instance per line: a class label, a single space, then the rest of
 * the line (the features). The features are copied verbatim into the generated problem files.
 * Labels must parse as numbers, because predictions are compared with them as {@code double}s.
 */
public class OVA {
    static final Logger logger = Logger.getLogger(OVA.class.getName());

    /** One dataset line: class label {@code c} and the remaining feature string {@code e}. */
    public static class Pair {
        final String c;
        final String e;

        public Pair(String label, String features) {
            this.c = label;
            this.e = features;
        }
    }

    /**
     * Trains on {@code trainFile}, evaluates on {@code testFile} and logs the scores.
     *
     * @param trainFile training dataset
     * @param testFile  test dataset
     * @param prefix    path prefix for every intermediate file written by {@link #run}
     * @param weight    value passed to {@code svm_train} as {@code -w1}
     * @throws IOException if a dataset or intermediate file cannot be read or written
     */
    public OVA(File trainFile, File testFile, String prefix, double weight) throws IOException {
        run(trainFile, testFile, prefix, weight);
    }

    /**
     * Runs {@code folds}-fold cross validation on {@code dataset} and logs the pooled scores.
     *
     * <p>Fold {@code f} is written to {@code <dataset>_train_f} and {@code <dataset>_test_f}.
     * The tp, fp, fn and size counts are summed over all folds before precision, recall and F1
     * are derived.
     *
     * @param dataset full dataset to split
     * @param folds   number of folds; must be at least 1
     * @param prefix  path prefix for the intermediate files written by {@link #run}
     * @param weight  value passed to {@code svm_train} as {@code -w1}
     * @throws IllegalArgumentException if {@code folds < 1}
     * @throws IOException if a dataset or intermediate file cannot be read or written
     */
    public OVA(File dataset, int folds, String prefix, double weight) throws IOException {
        if (folds < 1) {
            throw new IllegalArgumentException("folds must be >= 1, got " + folds);
        }
        // Pooled tp, fp, fn, size over all folds (same layout as eval's result).
        double[] totals = new double[4];
        List<Pair> data = readDataset(dataset);
        for (int fold = 0; fold < folds; fold++) {
            logger.info("CROSS VALIDATION START " + fold + "/" + folds);
            // One file pair per fold index, so each fold's split stays inspectable afterwards.
            File trainFile = new File(dataset.getAbsolutePath() + "_train_" + fold);
            File testFile = new File(dataset.getAbsolutePath() + "_test_" + fold);
            split(data, trainFile, testFile, fold, folds);
            double[] scores = run(trainFile, testFile, prefix, weight);
            for (int k = 0; k < scores.length; k++) {
                totals[k] += scores[k];
            }
            logger.info("CROSS VALIDATION END " + fold + "/" + folds);
        }
        logger.info("tp\tfp\tfn\tsize\tp\tr\tf1");
        logger.info(formatScores(totals[0], totals[1], totals[2], (int) totals[3]));
    }

    /**
     * Writes one cross-validation split.
     *
     * <p>Row {@code idx} goes to {@code testFile} when {@code (idx + fold) % folds == 0}, and to
     * {@code trainFile} otherwise. Across {@code fold = 0..folds-1}, every row therefore lands in
     * exactly one test split.
     */
    private void split(List<Pair> data, File trainFile, File testFile, int fold, int folds)
            throws IOException {
        try (PrintWriter test = new PrintWriter(new FileWriter(testFile));
             PrintWriter train = new PrintWriter(new FileWriter(trainFile))) {
            for (int idx = 0; idx < data.size(); idx++) {
                Pair row = data.get(idx);
                PrintWriter target = (idx + fold) % folds == 0 ? test : train;
                target.println(row.c + " " + row.e);
            }
            ensureWritten(test, testFile);
            ensureWritten(train, trainFile);
        }
    }

    /**
     * Trains one binary model per training class, combines their predictions on the test set and
     * evaluates the result.
     *
     * <p>For each class {@code L}, the following files are written: {@code prefix_train_L},
     * {@code prefix_test_L}, {@code prefix_mdl_L} and {@code prefix_out_L}. The combined
     * predictions go to {@code prefix_result}.
     *
     * @param file   training dataset
     * @param file2  test dataset
     * @param str    path prefix for all intermediate files
     * @param d      value passed to {@code svm_train} as {@code -w1}
     * @return {@code {tp, fp, fn, size}} as computed by {@link #eval}
     * @throws IOException if a dataset or intermediate file cannot be read or written
     */
    public double[] run(File file, File file2, String str, double d) throws IOException {
        List<Pair> train = readDataset(file);
        List<Pair> test = readDataset(file2);
        String[] labels = classes(train).toArray(new String[0]);
        List<List<Double[]>> outputs = new ArrayList<>();
        for (String label : labels) {
            logger.info(label);
            String trainProblem = str + "_train_" + label;
            String testProblem = str + "_test_" + label;
            String model = str + "_mdl_" + label;
            String output = str + "_out_" + label;
            writeProblem(train, label, trainProblem);
            writeProblem(test, label, testProblem);
            // Options presumably follow libsvm's svm-train: -t 0 linear kernel, -m cache size (MB),
            // -w1 weight of class 1 (the "this class" side). Confirm against this fork.
            String[] trainArgs = {"-t", "0", "-m", "2000", "-w1", Double.toString(d), trainProblem, model};
            logger.info("class " + label + " => " + trainArgs[5]);
            svm_train.main(trainArgs);
            svm_multiclass_predict.main(new String[]{testProblem, model, output});
            outputs.add(readOutput(new File(output)));
        }
        // Without any class there are no model outputs; every test row then gets prediction 0.
        int rows = outputs.isEmpty() ? test.size() : outputs.get(0).size();
        List<Double> predictions = combine(labels, outputs, rows, str + "_result");
        return eval(predictions, test);
    }

    /**
     * Picks a class for each test row.
     *
     * <p>The winner is the class whose model predicted the row positive (column 0 > 0) with the
     * largest absolute score (column 1). The line {@code label\tscore}, or {@code 0\t0} when no
     * model fired, is written to {@code resultPath}.
     *
     * @return one numeric prediction per row ({@code 0.0} when no model fired)
     */
    private List<Double> combine(String[] labels, List<List<Double[]>> outputs, int rows,
                                 String resultPath) throws IOException {
        List<Double> predictions = new ArrayList<>(rows);
        try (PrintWriter result = new PrintWriter(new FileWriter(resultPath))) {
            for (int row = 0; row < rows; row++) {
                int best = -1;
                double bestScore = 0.0d;
                // Compare magnitudes against a magnitude; storing the signed score here would let
                // any later candidate beat a negative winner.
                double bestMagnitude = 0.0d;
                for (int k = 0; k < labels.length; k++) {
                    Double[] prediction = outputs.get(k).get(row);
                    double magnitude = Math.abs(prediction[1].doubleValue());
                    if (prediction[0].doubleValue() > 0.0d && magnitude > bestMagnitude) {
                        best = k;
                        bestScore = prediction[1].doubleValue();
                        bestMagnitude = magnitude;
                    }
                }
                if (best > -1) {
                    result.println(labels[best] + StringTable.HORIZONTAL_TABULATION + bestScore);
                    predictions.add(Double.valueOf(labels[best]));
                } else {
                    result.println("0\t0");
                    predictions.add(Double.valueOf(0.0d));
                }
            }
            ensureWritten(result, resultPath);
        }
        return predictions;
    }

    /**
     * Compares predictions with gold labels, logs accuracy and tp/fp/fn/P/R/F1, and returns the
     * counts.
     *
     * <p>A prediction of {@code 0} counts as a false negative. A wrong non-zero prediction counts
     * as both a false positive and a false negative.
     *
     * @return {@code {tp, fp, fn, size}}
     */
    private double[] eval(List<Double> list, List<Pair> list2) {
        int correct = 0;
        double tp = 0.0d;
        double fp = 0.0d;
        double fn = 0.0d;
        logger.info("eval " + list.size() + ", " + list2.size());
        for (int k = 0; k < list.size(); k++) {
            double expected = Double.parseDouble(list2.get(k).c);
            double predicted = list.get(k).doubleValue();
            if (predicted == expected) {
                correct++;
            }
            if (predicted == 0.0d) {
                fn += 1.0d;
            } else if (predicted == expected) {
                tp += 1.0d;
            } else {
                fp += 1.0d;
                fn += 1.0d;
            }
        }
        int total = list.size();
        // Floating-point division: integer division truncated this to 0% or 100%.
        double accuracy = (double) correct / total * 100.0d;
        logger.info("Accuracy = " + accuracy + "% (" + correct + "/" + total + ") (classification)\n");
        logger.info("===\ntp\tfp\tfn\tsize\tp\tr\tf1");
        logger.info(formatScores(tp, fp, fn, total) + "\n===");
        return new double[]{tp, fp, fn, total};
    }

    /** Formats tp, fp, fn, size, precision, recall and F1 as one tab-separated line. */
    private static String formatScores(double tp, double fp, double fn, int size) {
        DecimalFormat df = new DecimalFormat("0.00");
        double precision = tp / (tp + fp);
        double recall = tp / (tp + fn);
        double f1 = (2.0d * precision * recall) / (precision + recall);
        String tab = StringTable.HORIZONTAL_TABULATION;
        return (int) tp + tab + (int) fp + tab + (int) fn + tab + size + tab
                + df.format(precision) + tab + df.format(recall) + tab + df.format(f1);
    }

    /**
     * Reads a prediction file with two columns per line, split on
     * {@code StringTable.HORIZONTAL_TABULATION}.
     *
     * <p>{@link #run} treats column 0 as the predicted label and column 1 as its score.
     */
    private List<Double[]> readOutput(File file) throws IOException {
        List<Double[]> rows = new ArrayList<>();
        try (LineNumberReader reader = new LineNumberReader(new FileReader(file))) {
            String line;
            while ((line = reader.readLine()) != null) {
                String[] fields = line.split(StringTable.HORIZONTAL_TABULATION);
                rows.add(new Double[]{Double.valueOf(fields[0]), Double.valueOf(fields[1])});
            }
        }
        return rows;
    }

    /**
     * Writes a binary problem file: {@code 1} for rows labelled {@code str}, {@code 0} for all
     * others, followed by a space and the row's feature string.
     */
    private void writeProblem(List<Pair> list, String str, String str2) throws IOException {
        try (PrintWriter out = new PrintWriter(new FileWriter(str2))) {
            for (Pair row : list) {
                out.println((str.equals(row.c) ? "1" : "0") + " " + row.e);
            }
            ensureWritten(out, str2);
        }
    }

    /** Returns the distinct class labels occurring in {@code list}. */
    private Set<String> classes(List<Pair> list) {
        Set<String> labels = new HashSet<>();
        for (Pair row : list) {
            labels.add(row.c);
        }
        return labels;
    }

    /**
     * Reads a dataset: each line is split at its first space into label and features.
     *
     * <p>Blank lines are skipped. A line without a space is taken as a label with empty features.
     */
    private List<Pair> readDataset(File file) throws IOException {
        List<Pair> data = new ArrayList<>();
        try (LineNumberReader reader = new LineNumberReader(new FileReader(file))) {
            String line;
            while ((line = reader.readLine()) != null) {
                if (line.isEmpty()) {
                    continue;
                }
                int space = line.indexOf(' ');
                if (space < 0) {
                    data.add(new Pair(line, ""));
                } else {
                    data.add(new Pair(line.substring(0, space), line.substring(space + 1)));
                }
            }
        }
        return data;
    }

    /** Surfaces write errors that {@link PrintWriter} otherwise swallows silently. */
    private static void ensureWritten(PrintWriter out, Object target) throws IOException {
        if (out.checkError()) {
            throw new IOException("Error writing " + target);
        }
    }

    /**
     * Command-line entry point.
     *
     * <p>Arguments: train file, test file, output prefix, {@code -w1} weight. The log4j
     * configuration path is taken from the {@code log-config} system property and defaults to
     * {@code log-config.txt}.
     */
    public static void main(String[] strArr) throws Exception {
        if (strArr.length < 4) {
            System.err.println("Usage: java " + OVA.class.getName()
                    + " <train-file> <test-file> <output-prefix> <w1-weight>");
            System.exit(1);
        }
        PropertyConfigurator.configure(System.getProperty("log-config", "log-config.txt"));
        new OVA(new File(strArr[0]), new File(strArr[1]), strArr[2], Double.parseDouble(strArr[3]));
    }
}
