package org.deeplearning4j.spark.stats;

import java.awt.Color;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.commons.io.FilenameUtils;
import org.apache.spark.SparkContext;
import org.apache.spark.api.java.JavaSparkContext;
import org.deeplearning4j.spark.api.stats.SparkTrainingStats;
import org.deeplearning4j.spark.util.SparkUtils;
import org.deeplearning4j.ui.api.Component;
import org.deeplearning4j.ui.api.LengthUnit;
import org.deeplearning4j.ui.components.chart.ChartHistogram;
import org.deeplearning4j.ui.components.chart.ChartLine;
import org.deeplearning4j.ui.components.chart.ChartTimeline;
import org.deeplearning4j.ui.components.chart.style.StyleChart;
import org.deeplearning4j.ui.components.component.ComponentDiv;
import org.deeplearning4j.ui.components.component.style.StyleDiv;
import org.deeplearning4j.ui.components.text.ComponentText;
import org.deeplearning4j.ui.components.text.style.StyleText;
import org.deeplearning4j.ui.standalone.StaticPageUtil;
import scala.Tuple3;

/* loaded from: input_file:org/deeplearning4j/spark/stats/StatsUtils.class */
/**
 * Utilities for exporting {@link EventStats} collected during Spark training.
 *
 * <p>Stats can be written either as delimited text (one line per event, preceded by a header) or
 * as a standalone HTML page. The page contains a timeline of activities per
 * (machine, JVM, thread) plus, for every stats key, a line chart and a histogram of the recorded
 * values.
 */
public class StatsUtils {
    /** Default maximum time span, in milliseconds, covered by one timeline chart (20 minutes). */
    public static final long DEFAULT_MAX_TIMELINE_SIZE_MS = 1200000;

    /** Number of bins used for every histogram rendered on the HTML page. */
    private static final int HISTOGRAM_BIN_COUNT = 20;

    /** Orders {@link EventStats} by ascending start time. */
    public static class StartTimeComparator implements Comparator<EventStats> {
        @Override
        public int compare(EventStats first, EventStats second) {
            return Long.compare(first.getStartTime(), second.getStartTime());
        }
    }

    /**
     * Orders (machine ID, JVM ID, thread ID) tuples lexicographically.
     *
     * <p>Tuples are compared by machine ID first, then by JVM ID, then numerically by thread ID.
     */
    public static class TupleComparator implements Comparator<Tuple3<String, String, Long>> {
        private TupleComparator() {
        }

        @Override
        public int compare(Tuple3<String, String, Long> first, Tuple3<String, String, Long> second) {
            int byMachine = first._1().compareTo(second._1());
            if (byMachine != 0) {
                return byMachine;
            }
            int byJvm = first._2().compareTo(second._2());
            if (byJvm != 0) {
                return byJvm;
            }
            return Long.compare(first._3(), second._3());
        }
    }

    private StatsUtils() {
    }

    /**
     * Writes the given events as delimited text to {@code outputDirectory/filename}.
     *
     * @param list            events to export
     * @param outputDirectory directory to write into (joined with {@code filename} via
     *                        {@link FilenameUtils#concat})
     * @param filename        name of the output file
     * @param delimiter       column delimiter passed to each event's formatting methods
     * @param sparkContext    Spark context used to write the file
     * @throws IOException if writing fails
     */
    public static void exportStats(List<EventStats> list, String outputDirectory, String filename,
                                   String delimiter, SparkContext sparkContext) throws IOException {
        exportStats(list, FilenameUtils.concat(outputDirectory, filename), delimiter, sparkContext);
    }

    /**
     * Writes the given events as delimited text to {@code outputPath}.
     *
     * <p>The header line comes from the FIRST event's {@code getStringHeader}. It is followed by
     * one {@code asString} line per event. An empty list produces an empty file.
     *
     * @param list         events to export
     * @param outputPath   destination path
     * @param delimiter    column delimiter passed to each event's formatting methods
     * @param sparkContext Spark context used to write the file
     * @throws IOException if writing fails
     */
    public static void exportStats(List<EventStats> list, String outputPath, String delimiter,
                                   SparkContext sparkContext) throws IOException {
        StringBuilder sb = new StringBuilder();
        boolean first = true;
        for (EventStats event : list) {
            if (first) {
                sb.append(event.getStringHeader(delimiter)).append("\n");
                first = false;
            }
            sb.append(event.asString(delimiter)).append("\n");
        }
        SparkUtils.writeStringToFile(outputPath, sb.toString(), sparkContext);
    }

    /**
     * Joins the durations (ms) of the given events with {@code delimiter}.
     *
     * @param list      events whose durations to join
     * @param delimiter separator placed between consecutive durations (not after the last one)
     * @return the joined durations; empty for an empty list
     */
    public static String getDurationAsString(List<EventStats> list, String delimiter) {
        StringBuilder sb = new StringBuilder();
        boolean first = true;
        for (EventStats event : list) {
            if (!first) {
                sb.append(delimiter);
            }
            sb.append(event.getDurationMs());
            first = false;
        }
        return sb.toString();
    }

    /**
     * Renders the training stats as an HTML page at {@code path}.
     *
     * <p>Uses {@link #DEFAULT_MAX_TIMELINE_SIZE_MS} per timeline chart.
     */
    public static void exportStatsAsHtml(SparkTrainingStats sparkTrainingStats, String path,
                                         JavaSparkContext javaSparkContext) throws Exception {
        exportStatsAsHtml(sparkTrainingStats, path, javaSparkContext.sc());
    }

    /**
     * Renders the training stats as an HTML page at {@code path}.
     *
     * <p>Uses {@link #DEFAULT_MAX_TIMELINE_SIZE_MS} per timeline chart.
     */
    public static void exportStatsAsHtml(SparkTrainingStats sparkTrainingStats, String path,
                                         SparkContext sparkContext) throws Exception {
        exportStatsAsHtml(sparkTrainingStats, DEFAULT_MAX_TIMELINE_SIZE_MS, path, sparkContext);
    }

    /**
     * Renders the training stats as an HTML page at {@code path}.
     *
     * <p>The page contains, in order:
     * <ul>
     *   <li>a title;</li>
     *   <li>timeline chart(s) for the keys whose {@code defaultIncludeInPlots} is true, split
     *       into windows of at most {@code maxTimelineSizeMs} each, followed by a colour legend;</li>
     *   <li>for every key, a duration line chart plus histogram;</li>
     *   <li>for example-count and partition-count stats, an additional count chart.</li>
     * </ul>
     *
     * @param sparkTrainingStats stats to render
     * @param maxTimelineSizeMs  time span, in ms, covered by each timeline chart; must be positive
     * @param path               destination path of the HTML file
     * @param sparkContext       Spark context used to write the file
     * @throws IllegalArgumentException if {@code maxTimelineSizeMs <= 0}
     * @throws Exception                if writing fails
     */
    public static void exportStatsAsHtml(SparkTrainingStats sparkTrainingStats, long maxTimelineSizeMs,
                                         String path, SparkContext sparkContext) throws Exception {
        if (maxTimelineSizeMs <= 0) {
            // Zero would divide by zero; a negative value would produce nonsensical windows.
            throw new IllegalArgumentException("maxTimelineSizeMs must be > 0, got " + maxTimelineSizeMs);
        }
        Set<String> keySet = sparkTrainingStats.getKeySet();
        List<Component> components = new ArrayList<>();
        StyleChart chartStyle = new StyleChart.Builder().backgroundColor(Color.WHITE)
                .width(700.0d, LengthUnit.Px).height(400.0d, LengthUnit.Px).build();

        components.add(new ComponentDiv(
                new StyleDiv.Builder().height(40.0d, LengthUnit.Px).width(100.0d, LengthUnit.Percent).build(),
                new Component[] {new ComponentText("Deeplearning4j - Spark Training Analysis",
                        new StyleText.Builder().color(Color.BLACK).fontSize(20.0d).build())}));

        Set<String> timelineKeys = new HashSet<>();
        for (String key : keySet) {
            if (sparkTrainingStats.defaultIncludeInPlots(key)) {
                timelineKeys.add(key);
            }
        }
        Collections.addAll(components,
                getTrainingStatsTimelineChart(sparkTrainingStats, timelineKeys, maxTimelineSizeMs));

        for (String key : keySet) {
            List<EventStats> events = new ArrayList<EventStats>(sparkTrainingStats.getValue(key));
            Collections.sort(events, new StartTimeComparator());

            // x axis is the event's index in start-time order, not wall-clock time.
            double[] eventIndex = new double[events.size()];
            double[] durations = new double[events.size()];
            for (int i = 0; i < events.size(); i++) {
                eventIndex[i] = i;
                durations[i] = events.get(i).getDurationMs();
            }
            components.add(buildLineAndHistogram(key, "Duration", eventIndex, durations, chartStyle));

            if (events.isEmpty()) {
                continue;
            }
            EventStats firstEvent = events.get(0);
            if (firstEvent instanceof ExampleCountEventStats || firstEvent instanceof PartitionCountEventStats) {
                // The first event decides the kind; all events under one key are cast to that type.
                boolean isExampleCount = firstEvent instanceof ExampleCountEventStats;
                double[] counts = new double[events.size()];
                for (int i = 0; i < counts.length; i++) {
                    counts[i] = isExampleCount
                            ? ((ExampleCountEventStats) events.get(i)).getTotalExampleCount()
                            : ((PartitionCountEventStats) events.get(i)).getNumPartitions();
                }
                String title = key + " / " + (isExampleCount ? "Number of Examples" : "Number of Partitions");
                components.add(buildLineAndHistogram(title, isExampleCount ? "Examples" : "Partitions",
                        eventIndex, counts, chartStyle));
            }
        }
        SparkUtils.writeStringToFile(path, StaticPageUtil.renderHTML(components), sparkContext);
    }

    /**
     * Builds a full-width div holding a line chart of {@code y} against {@code x}, plus a histogram
     * of {@code y}.
     *
     * <p>When all y values are equal, the y axis is fixed to {@code [value - 1, value + 1]} so the
     * flat line stays visible, and no histogram is added. An empty {@code y} also gets no histogram.
     */
    private static Component buildLineAndHistogram(String title, String seriesName, double[] x, double[] y,
                                                   StyleChart chartStyle) {
        double min = Double.MAX_VALUE;
        double max = -Double.MAX_VALUE;
        for (double value : y) {
            min = Math.min(min, value);
            max = Math.max(max, value);
        }
        boolean constant = min == max;
        Component line = new ChartLine.Builder(title, chartStyle).addSeries(seriesName, x, y)
                .setYMin(constant ? Double.valueOf(min - 1.0d) : null)
                .setYMax(constant ? Double.valueOf(min + 1.0d) : null)
                .build();
        Component histogram = (!constant && y.length > 0)
                ? getHistogram(y, HISTOGRAM_BIN_COUNT, title, chartStyle)
                : null;
        Component[] children = histogram != null ? new Component[] {line, histogram} : new Component[] {line};
        return new ComponentDiv(new StyleDiv.Builder().width(100.0d, LengthUnit.Percent).build(), children);
    }

    /**
     * Builds the timeline charts and colour legend for the given keys.
     *
     * <p>Every distinct (machine, JVM, thread) becomes one lane, sorted by {@link TupleComparator}.
     * Machines and JVMs are labelled "PC n" / "JVM n". The overall time range is split into
     * {@code floor(span / maxTimelineSizeMs)} windows (at least one). Each event with a non-zero
     * duration goes into the window containing its start time. Events starting after the last
     * full window (the truncated remainder) are placed in the last chart.
     *
     * @param maxTimelineSizeMs window length in ms; callers must ensure it is positive
     * @return one chart per window, followed by the legend div
     */
    private static Component[] getTrainingStatsTimelineChart(SparkTrainingStats sparkTrainingStats,
                                                             Set<String> keys, long maxTimelineSizeMs) {
        Set<Tuple3<String, String, Long>> threadKeys = new HashSet<>();
        Set<String> machineIds = new HashSet<>();
        Set<String> jvmIds = new HashSet<>();
        long earliestStart = Long.MAX_VALUE;
        long latestEnd = Long.MIN_VALUE;
        boolean anyEvents = false;
        for (String key : keys) {
            for (EventStats event : sparkTrainingStats.getValue(key)) {
                machineIds.add(event.getMachineID());
                jvmIds.add(event.getJvmID());
                threadKeys.add(threadKey(event));
                earliestStart = Math.min(earliestStart, event.getStartTime());
                latestEnd = Math.max(latestEnd, event.getStartTime() + event.getDurationMs());
                anyEvents = true;
            }
        }
        Map<String, String> machineNames = shortNames(machineIds, "PC ");
        Map<String, String> jvmNames = shortNames(jvmIds, "JVM ");

        List<Tuple3<String, String, Long>> lanes = new ArrayList<>(threadKeys);
        Collections.sort(lanes, new TupleComparator());
        int numLanes = lanes.size();
        // O(1) lane lookup instead of List.indexOf per event.
        Map<Tuple3<String, String, Long>, Integer> laneIndex = new HashMap<>();
        for (int lane = 0; lane < numLanes; lane++) {
            laneIndex.put(lanes.get(lane), lane);
        }

        Color[] palette = getColors(keys.size());
        Map<String, Color> colorByKey = new HashMap<>();
        int colorIdx = 0;
        for (String key : keys) {
            colorByKey.put(key, palette[colorIdx++]);
        }
        ComponentDiv legend = buildLegend(sparkTrainingStats, keys, colorByKey);

        // Explicit 0 span when there are no events, rather than relying on sentinel overflow.
        long fullWindows = anyEvents ? (latestEnd - earliestStart) / maxTimelineSizeMs : 0L;
        int numCharts = fullWindows < 1 ? 1 : (int) Math.min(fullWindows, Integer.MAX_VALUE);

        List<List<List<ChartTimeline.TimelineEntry>>> entries = new ArrayList<>(numCharts);
        for (int chart = 0; chart < numCharts; chart++) {
            List<List<ChartTimeline.TimelineEntry>> perLane = new ArrayList<>(numLanes);
            for (int lane = 0; lane < numLanes; lane++) {
                perLane.add(new ArrayList<ChartTimeline.TimelineEntry>());
            }
            entries.add(perLane);
        }

        for (String key : keys) {
            for (EventStats event : sparkTrainingStats.getValue(key)) {
                if (event.getDurationMs() == 0) {
                    continue; // zero-length events are not drawn
                }
                long start = event.getStartTime();
                long end = start + event.getDurationMs();
                // Window w covers [earliestStart + w*max, earliestStart + (w+1)*max); anything beyond
                // the last full window goes into the last chart.
                long window = (start - earliestStart) / maxTimelineSizeMs;
                int chart = (window < 0 || window >= numCharts) ? numCharts - 1 : (int) window;
                int lane = laneIndex.get(threadKey(event));
                entries.get(chart).get(lane).add(new ChartTimeline.TimelineEntry(
                        sparkTrainingStats.getShortNameForKey(key), start, end, colorByKey.get(key)));
            }
        }

        Comparator<ChartTimeline.TimelineEntry> byStartTime = new Comparator<ChartTimeline.TimelineEntry>() {
            @Override
            public int compare(ChartTimeline.TimelineEntry a, ChartTimeline.TimelineEntry b) {
                return Long.compare(a.getStartTimeMs(), b.getStartTimeMs());
            }
        };
        for (List<List<ChartTimeline.TimelineEntry>> perLane : entries) {
            for (List<ChartTimeline.TimelineEntry> laneEntries : perLane) {
                Collections.sort(laneEntries, byStartTime);
            }
        }

        StyleChart timelineStyle = new StyleChart.Builder().width(1280.0d, LengthUnit.Px)
                .height((50 * numLanes) + 105, LengthUnit.Px)
                .margin(LengthUnit.Px, 60, 20, 200, 10).build();
        List<Component> result = new ArrayList<>(numCharts + 1);
        for (List<List<ChartTimeline.TimelineEntry>> perLane : entries) {
            ChartTimeline.Builder builder = new ChartTimeline.Builder("Timeline: Training Activities", timelineStyle);
            for (int lane = 0; lane < numLanes; lane++) {
                Tuple3<String, String, Long> thread = lanes.get(lane);
                String laneName = machineNames.get(thread._1()) + ", " + jvmNames.get(thread._2())
                        + ", Thread " + thread._3();
                builder.addLane(laneName, perLane.get(lane));
            }
            result.add(builder.build());
        }
        result.add(legend);
        return result.toArray(new Component[result.size()]);
    }

    /** Builds the (machine ID, JVM ID, thread ID) key identifying the lane of {@code event}. */
    private static Tuple3<String, String, Long> threadKey(EventStats event) {
        return new Tuple3<String, String, Long>(event.getMachineID(), event.getJvmID(),
                Long.valueOf(event.getThreadID()));
    }

    /** Maps each ID to {@code prefix + n}, numbering in the set's iteration order starting at 0. */
    private static Map<String, String> shortNames(Set<String> ids, String prefix) {
        Map<String, String> names = new HashMap<>();
        int n = 0;
        for (String id : ids) {
            names.put(id, prefix + n++);
        }
        return names;
    }

    /** Builds the colour legend: one coloured, left-floated cell per key ("short name - key"). */
    private static ComponentDiv buildLegend(SparkTrainingStats sparkTrainingStats, Set<String> keys,
                                            Map<String, Color> colorByKey) {
        List<Component> items = new ArrayList<>();
        for (String key : keys) {
            items.add(new ComponentDiv(
                    new StyleDiv.Builder().backgroundColor(colorByKey.get(key)).width(33.3d, LengthUnit.Percent)
                            .height(25.0d, LengthUnit.Px).floatValue(StyleDiv.FloatValue.left).build(),
                    new Component[] {new ComponentText(sparkTrainingStats.getShortNameForKey(key) + " - " + key,
                            new StyleText.Builder().fontSize(11.0d).build())}));
        }
        return new ComponentDiv(new StyleDiv.Builder().width(100.0d, LengthUnit.Percent).build(), items);
    }

    /**
     * Returns {@code count} colours with evenly spaced hues.
     *
     * <p>Even indices use a muted saturation/brightness and odd indices a vivid one, so that
     * neighbouring colours are easier to tell apart.
     */
    private static Color[] getColors(int count) {
        Color[] colors = new Color[count];
        float hueStep = (float) (count <= 1 ? 1.0d : 1.0d / (count + 1));
        for (int idx = 0; idx < count; idx++) {
            float hue = hueStep * idx;
            colors[idx] = (idx % 2 == 0)
                    ? Color.getHSBColor(hue, 0.4f, 0.75f)
                    : Color.getHSBColor(hue, 1.0f, 1.0f);
        }
        return colors;
    }

    /**
     * Builds a histogram of {@code values} with {@code numBins} equal-width bins spanning
     * {@code [min, max]}.
     *
     * <p>Bins are half-open {@code [lo, hi)}, except the last, which is closed so that the maximum
     * value is counted.
     *
     * @return the histogram, or {@code null} if {@code values} is empty or all values are equal
     */
    private static Component getHistogram(double[] values, int numBins, String title, StyleChart chartStyle) {
        double min = Double.MAX_VALUE;
        double max = -Double.MAX_VALUE;
        for (double value : values) {
            min = Math.min(min, value);
            max = Math.max(max, value);
        }
        if (values.length == 0 || min == max) {
            return null;
        }
        double[] edges = new double[numBins + 1];
        int[] counts = new int[numBins];
        double binWidth = (max - min) / numBins;
        for (int i = 0; i < numBins; i++) {
            edges[i] = min + i * binWidth;
        }
        // Pin the final edge to max exactly: min + numBins * binWidth can round BELOW max, which
        // previously made the maximum value(s) fall outside every bin and be silently dropped.
        edges[numBins] = max;

        int last = numBins - 1;
        for (double value : values) {
            for (int bin = 0; bin < numBins; bin++) {
                boolean inHalfOpen = value >= edges[bin] && value < edges[bin + 1];
                boolean isMaxInLastBin = bin == last && value == edges[numBins];
                if (inHalfOpen || isMaxInLastBin) {
                    counts[bin]++;
                    break;
                }
            }
        }

        ChartHistogram.Builder builder = new ChartHistogram.Builder(title, chartStyle);
        for (int bin = 0; bin < numBins; bin++) {
            builder.addBin(edges[bin], edges[bin + 1], counts[bin]);
        }
        return builder.build();
    }
}
