package org.campagnelab.dl.framework.performance;

import it.unimi.dsi.fastutil.objects.Object2ObjectAVLTreeMap;
import it.unimi.dsi.fastutil.objects.Object2ObjectMap;
import it.unimi.dsi.fastutil.objects.ObjectArrayList;
import it.unimi.dsi.fastutil.objects.ObjectIterator;
import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import java.io.Writer;
import java.util.Iterator;
import java.util.List;
import java.util.NoSuchElementException;

/* loaded from: input_file:org/campagnelab/dl/framework/performance/PerformanceLogger.class */
/**
 * Accumulates per-epoch performance records, grouped under a string prefix, and writes each
 * group to a tab-separated file named {@code <prefix>-perf-log.tsv} in the log directory.
 *
 * <p>Two logging styles are supported:
 * <ul>
 *   <li>named metrics: call {@link #definePerformances(Metric...)} once, then
 *       {@link #logMetrics(String, long, int, double...)} with one value per defined metric;</li>
 *   <li>score/AUC: call one of the {@code log(...)} overloads.</li>
 * </ul>
 */
public class PerformanceLogger {
    private static final String PERF_FILENAME_FORMAT = "%s-perf-log.tsv";

    /** Column names written for records logged through the score/AUC API. */
    private static final String SCORE_AUC_COLUMNS = "score\tAUC";

    private String directory;
    private String conditionId;
    private double[] bestPerformances;
    private boolean[] performanceLargeIsBest;
    private String[] performanceNames;
    /** Smallest score seen by the score/AUC API; starts at Float.MAX_VALUE widened to double. */
    private double bestScore = Float.MAX_VALUE;
    /** Largest AUC seen by the score/AUC API; -1 until an AUC has been logged. */
    private double bestAUC = -1.0d;
    /** Records per prefix; the AVL tree map keeps prefixes in sorted order. */
    private Object2ObjectMap<String, List<Performance>> log = new Object2ObjectAVLTreeMap<>();

    /** One logged measurement: either a (score, AUC) pair or a vector of named metric values. */
    public class Performance {
        long numExamplesUsed;
        int epoch;
        double[] performanceValues;
        String[] performanceNames;
        public double trainingScore;
        double score;
        double auc;

        /**
         * Creates a score/AUC record.
         *
         * @param numExamplesUsed number of examples used so far
         * @param epoch epoch the measurement belongs to
         * @param score score value
         * @param auc AUC value
         */
        public Performance(long numExamplesUsed, int epoch, double score, double auc) {
            this.numExamplesUsed = numExamplesUsed;
            this.epoch = epoch;
            this.score = score;
            this.auc = auc;
        }

        /**
         * Creates a named-metric record.
         *
         * @param numExamplesUsed number of examples used so far
         * @param epoch epoch the measurement belongs to
         * @param names metric names; the array is stored as-is, not copied
         * @param values metric values; the array is copied
         */
        public Performance(long numExamplesUsed, int epoch, String[] names, double... values) {
            this.numExamplesUsed = numExamplesUsed;
            this.epoch = epoch;
            this.performanceValues = values.clone();
            this.performanceNames = names;
        }

        /**
         * Formats the measured values as tab-separated {@code %f} fields.
         *
         * @return "score TAB auc" when this record has no metric names, otherwise the metric
         *         values joined by tabs
         */
        public String formatValues() {
            if (this.performanceNames == null) {
                return String.format("%f\t%f", this.score, this.auc);
            }
            StringBuilder formatted = new StringBuilder();
            for (int i = 0; i < this.performanceValues.length; i++) {
                if (i > 0) {
                    formatted.append('\t');
                }
                formatted.append(String.format("%f", this.performanceValues[i]));
            }
            return formatted.toString();
        }
    }

    /**
     * Declares the metrics that {@link #logMetrics(String, long, int, double...)} will receive,
     * in order, and resets the best value recorded for each of them.
     *
     * @param metricArr metrics to track; {@code name} and {@code largerIsBetter} are read from each
     */
    public void definePerformances(Metric... metricArr) {
        final int count = metricArr.length;
        this.performanceLargeIsBest = new boolean[count];
        this.performanceNames = new String[count];
        this.bestPerformances = new double[count];
        for (int i = 0; i < count; i++) {
            Metric metric = metricArr[i];
            this.performanceNames[i] = metric.name;
            this.performanceLargeIsBest[i] = metric.largerIsBetter;
            // Start from the worst possible value so the first logged value becomes the best.
            this.bestPerformances[i] = metric.largerIsBetter ? Double.NEGATIVE_INFINITY : Double.MAX_VALUE;
        }
    }

    /**
     * Returns the best score seen so far.
     *
     * @return when metrics are defined, the best value of the metric named "score", or NaN if no
     *         such metric exists; when no metrics are defined, the smallest score passed to the
     *         {@code log(...)} overloads (Float.MAX_VALUE if none was logged)
     */
    public double getBestScore() {
        if (this.performanceNames == null) {
            // Only the score/AUC API was used: fall back to the value it tracks.
            return this.bestScore;
        }
        for (int i = 0; i < this.performanceNames.length; i++) {
            if ("score".equals(this.performanceNames[i])) {
                return this.bestPerformances[i];
            }
        }
        return Double.NaN;
    }

    /**
     * @return the largest non-NaN AUC passed to {@link #log(String, long, int, double, double)},
     *         or -1 if none was logged
     */
    public double getBestAUC() {
        return this.bestAUC;
    }

    /**
     * Returns the best value recorded for a defined metric.
     *
     * @param metricName name of a metric given to {@link #definePerformances(Metric...)}
     * @return best value seen for that metric
     * @throws IllegalArgumentException if no metric with that name was defined
     */
    public double getBest(String metricName) {
        if (this.performanceNames != null) {
            for (int i = 0; i < this.performanceNames.length; i++) {
                String defined = this.performanceNames[i];
                if (defined != null && defined.equals(metricName)) {
                    return this.bestPerformances[i];
                }
            }
        }
        throw new IllegalArgumentException("The metric name was not defined: " + metricName);
    }

    /**
     * Creates a logger that writes into the given directory, creating it if necessary.
     *
     * @param directory path of the log directory
     * @throws IllegalArgumentException if the directory does not exist and cannot be created
     */
    public PerformanceLogger(String directory) {
        this.directory = directory;
        File dir = new File(directory);
        if (!dir.exists() && !dir.mkdirs()) {
            throw new IllegalArgumentException("Unable to create log directory at " + directory + " current directory is " + new File(".").getAbsolutePath());
        }
    }

    /**
     * Sets the condition identifier appended as a last column to every written row.
     *
     * @param conditionId identifier, or null to omit the column
     */
    public void setCondition(String conditionId) {
        this.conditionId = conditionId;
    }

    /** Discards all logged records. Best values are not reset. */
    public void clear() {
        this.log.clear();
    }

    /**
     * Logs one value per defined metric and updates the best value of each metric.
     * NaN values are recorded but never become a best value.
     *
     * @param prefix group the record is stored under
     * @param numExamplesUsed number of examples used so far
     * @param epoch epoch the values belong to
     * @param values metric values, in the order given to {@link #definePerformances(Metric...)}
     * @throws IllegalArgumentException if more values are given than metrics were defined
     */
    public void logMetrics(String prefix, long numExamplesUsed, int epoch, double... values) {
        final int defined = this.performanceNames == null ? 0 : this.performanceNames.length;
        if (values.length > defined) {
            // Fail before touching any state instead of half-updating the bests.
            throw new IllegalArgumentException("Received " + values.length + " metric values but only "
                    + defined + " metrics were defined with definePerformances()");
        }
        for (int i = 0; i < values.length; i++) {
            final double value = values[i];
            if (Double.isNaN(value)) {
                continue;
            }
            this.bestPerformances[i] = this.performanceLargeIsBest[i]
                    ? Math.max(this.bestPerformances[i], value)
                    : Math.min(this.bestPerformances[i], value);
        }
        entriesFor(prefix).add(new Performance(numExamplesUsed, epoch, this.performanceNames, values));
    }

    /**
     * Attaches a training score to the first record logged under the prefix for the given epoch.
     *
     * @param prefix group to search
     * @param epoch epoch of the record to update
     * @param trainingScore value to store
     * @throws NoSuchElementException if no record exists for that prefix and epoch
     */
    public void logTrainingScore(String prefix, int epoch, double trainingScore) {
        List<Performance> entries = this.log.get(prefix);
        if (entries != null) {
            for (Performance entry : entries) {
                if (entry.epoch == epoch) {
                    entry.trainingScore = trainingScore;
                    return;
                }
            }
        }
        throw new NoSuchElementException("No performance was logged for prefix '" + prefix + "' at epoch " + epoch);
    }

    /**
     * Prints the most recently logged record of a prefix to standard output as
     * "epoch TAB trainingScore TAB values".
     *
     * @param prefix group whose last record is printed
     * @throws NullPointerException if nothing was ever logged under the prefix
     * @throws IndexOutOfBoundsException if the list stored under the prefix is empty
     */
    public void show(String prefix) {
        List<Performance> entries = this.log.get(prefix);
        Performance last = entries.get(entries.size() - 1);
        System.out.printf("%d\t%f\t%s%n", last.epoch, last.trainingScore, last.formatValues());
    }

    /**
     * Logs a score/AUC record and updates the best score (smaller is better) and best AUC
     * (larger is better). NaN values are recorded but do not affect the best values.
     *
     * @param prefix group the record is stored under
     * @param numExamplesUsed number of examples used so far
     * @param epoch epoch the values belong to
     * @param score score value
     * @param auc AUC value
     */
    public void log(String prefix, long numExamplesUsed, int epoch, double score, double auc) {
        // Math.min/max return NaN if either argument is NaN, which would poison the best forever.
        if (!Double.isNaN(score)) {
            this.bestScore = Math.min(this.bestScore, score);
        }
        if (!Double.isNaN(auc)) {
            this.bestAUC = Math.max(this.bestAUC, auc);
        }
        entriesFor(prefix).add(new Performance(numExamplesUsed, epoch, score, auc));
    }

    /**
     * Returns the largest epoch number logged under a prefix. Note that this is the latest
     * epoch recorded, not the epoch at which a metric reached its best value.
     *
     * @param prefix group to inspect
     * @return the largest epoch logged, or -1 if nothing was logged under the prefix
     */
    public int getBestEpoch(String prefix) {
        List<Performance> entries = this.log.get(prefix);
        if (entries == null) {
            return -1;
        }
        int latest = -1;
        for (Performance entry : entries) {
            latest = Math.max(latest, entry.epoch);
        }
        return latest;
    }

    /**
     * Logs a score-only record; the AUC is stored as -1.
     *
     * @param prefix group the record is stored under
     * @param numExamplesUsed number of examples used so far
     * @param epoch epoch the score belongs to
     * @param score score value
     */
    public void log(String prefix, long numExamplesUsed, int epoch, double score) {
        // Delegating stores the record under the prefix even when it is the first one, and
        // keeps the best score up to date. An AUC of -1 never raises the best AUC.
        log(prefix, numExamplesUsed, epoch, score, -1.0d);
    }

    /**
     * Writes one file per logged prefix.
     *
     * @throws IOException if a file cannot be written
     */
    public void write() throws IOException {
        for (String prefix : this.log.keySet()) {
            write(prefix);
        }
    }

    /**
     * Writes the records of one prefix to {@code <directory>/<prefix>-perf-log.tsv}, replacing
     * any existing file. The header line is written even when the prefix has no records.
     *
     * @param prefix group to write
     * @throws IOException if the file cannot be written
     */
    public void write(String prefix) throws IOException {
        final String path = this.directory + "/" + String.format(PERF_FILENAME_FORMAT, prefix);
        try (FileWriter writer = new FileWriter(path)) {
            writeHeaders(writer);
            List<Performance> entries = this.log.get(prefix);
            if (entries == null) {
                return;
            }
            for (Performance entry : entries) {
                writer.write(String.format("%d\t%d\t%f\t%s", entry.numExamplesUsed, entry.epoch, entry.trainingScore, entry.formatValues()));
                if (this.conditionId != null) {
                    writer.write("\t" + this.conditionId);
                }
                writer.write("\n");
            }
        }
    }

    /**
     * Writes the header line. Its columns match the rows produced by {@link #write(String)}:
     * numExamplesUsed, epoch, trainingScore, the metric columns, and condition when one is set.
     */
    private void writeHeaders(Writer writer) throws IOException {
        // Use the bare metric names here: getMetricHeader() adds its own epoch/trainingScore
        // prefix, which would duplicate two columns and misalign the header with the rows.
        writer.write("numExamplesUsed\tepoch\ttrainingScore\t" + metricColumnNames());
        if (this.conditionId != null) {
            writer.write("\tcondition");
        }
        writer.write("\n");
    }

    /**
     * Returns a tab-separated header for the measured values.
     *
     * @return "score TAB AUC" when no metrics are defined; otherwise
     *         "epoch TAB trainingScore TAB" followed by the metric names joined by tabs
     */
    public String getMetricHeader() {
        if (this.performanceNames == null) {
            return SCORE_AUC_COLUMNS;
        }
        return "epoch\ttrainingScore\t" + String.join("\t", this.performanceNames);
    }

    /** Returns only the value column names: the defined metric names, or score/AUC if none. */
    private String metricColumnNames() {
        if (this.performanceNames == null) {
            return SCORE_AUC_COLUMNS;
        }
        return String.join("\t", this.performanceNames);
    }

    /** Returns the record list stored under the prefix, creating and registering it if absent. */
    private List<Performance> entriesFor(String prefix) {
        List<Performance> entries = this.log.get(prefix);
        if (entries == null) {
            entries = new ObjectArrayList<>();
            this.log.put(prefix, entries);
        }
        return entries;
    }
}
