package org.deeplearning4j.spark.impl.paramavg.stats;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;
import org.apache.spark.SparkContext;
import org.deeplearning4j.spark.api.stats.SparkTrainingStats;
import org.deeplearning4j.spark.stats.BaseEventStats;
import org.deeplearning4j.spark.stats.EventStats;
import org.deeplearning4j.spark.stats.ExampleCountEventStats;
import org.deeplearning4j.spark.stats.StatsUtils;
import org.deeplearning4j.spark.time.TimeSource;
import org.deeplearning4j.spark.time.TimeSourceProvider;

/* loaded from: input_file:org/deeplearning4j/spark/impl/paramavg/stats/ParameterAveragingTrainingWorkerStats.class */
/**
 * Timing statistics collected by a parameter-averaging training worker.
 *
 * <p>Three lists of {@link EventStats} are tracked, one per key returned by {@link #getKeySet()}:
 * the broadcast "get value" time, the model initialisation time and the individual fit times.
 * Instances are normally created through {@link ParameterAveragingTrainingWorkerStatsHelper} and
 * combined with {@link #addOtherTrainingStats(SparkTrainingStats)}.
 *
 * <p>The lists are held by reference (no defensive copies are made) and may be {@code null};
 * {@link #statsAsString()} prints {@code "-"} for a {@code null} list.
 */
public class ParameterAveragingTrainingWorkerStats implements SparkTrainingStats {
    /** Delimiter used when rendering and exporting the statistics. */
    public static final String DEFAULT_DELIMITER = ",";

    /** File name used by {@link #exportStatFiles(String, SparkContext)} for the broadcast statistics. */
    public static final String FILENAME_BROADCAST_GET_STATS = "parameterAveragingWorkerBroadcastGetValueTimeMs.txt";
    /** File name used by {@link #exportStatFiles(String, SparkContext)} for the initialisation statistics. */
    public static final String FILENAME_INIT_STATS = "parameterAveragingWorkerInitTimeMs.txt";
    /** File name used by {@link #exportStatFiles(String, SparkContext)} for the fit statistics. */
    public static final String FILENAME_FIT_STATS = "parameterAveragingWorkerFitTimesMs.txt";

    /** Key for the broadcast "get value" timings. */
    public static final String PARAMETER_AVERAGING_WORKER_BROADCAST_GET_VALUE_TIME_MS = "ParameterAveragingWorkerBroadcastGetValueTimeMs";
    /** Key for the model initialisation timings. */
    public static final String PARAMETER_AVERAGING_WORKER_INIT_TIME_MS = "ParameterAveragingWorkerInitTimeMs";
    /** Key for the fit timings. */
    public static final String PARAMETER_AVERAGING_WORKER_FIT_TIMES_MS = "ParameterAveragingWorkerFitTimesMs";

    /** All supported keys, in insertion order (broadcast, init, fit); unmodifiable. */
    private static final Set<String> COLUMN_NAMES = Collections.unmodifiableSet(new LinkedHashSet<>(Arrays.asList(
            PARAMETER_AVERAGING_WORKER_BROADCAST_GET_VALUE_TIME_MS,
            PARAMETER_AVERAGING_WORKER_INIT_TIME_MS,
            PARAMETER_AVERAGING_WORKER_FIT_TIMES_MS)));

    private List<EventStats> parameterAveragingWorkerBroadcastGetValueTimeMs;
    private List<EventStats> parameterAveragingWorkerInitTimeMs;
    private List<EventStats> parameterAveragingWorkerFitTimesMs;

    /**
     * Records timestamps while a worker runs and turns them into a
     * {@link ParameterAveragingTrainingWorkerStats} via {@link #build()}.
     *
     * <p>All timestamps are read from the {@link TimeSource} returned by
     * {@link TimeSourceProvider#getInstance()} at construction time.
     */
    public static class ParameterAveragingTrainingWorkerStatsHelper {
        private long broadcastStartTime;
        private long broadcastEndTime;
        private long initEndTime;
        private long lastFitStartTime;
        private final List<EventStats> fitTimes = new ArrayList<>();
        private final TimeSource timeSource = TimeSourceProvider.getInstance();

        /** Records the current time as the start of the broadcast "get value" phase. */
        public void logBroadcastGetValueStart() {
            this.broadcastStartTime = this.timeSource.currentTimeMillis();
        }

        /** Records the current time as the end of the broadcast "get value" phase. */
        public void logBroadcastGetValueEnd() {
            this.broadcastEndTime = this.timeSource.currentTimeMillis();
        }

        /** Records the current time as the end of the initialisation phase. */
        public void logInitEnd() {
            this.initEndTime = this.timeSource.currentTimeMillis();
        }

        /** Records the current time as the start of the next fit operation. */
        public void logFitStart() {
            this.lastFitStartTime = this.timeSource.currentTimeMillis();
        }

        /**
         * Appends one fit event spanning from the last {@link #logFitStart()} call to now.
         *
         * @param i value forwarded as the third argument of {@link ExampleCountEventStats}
         *          (presumably the number of examples in the fit - confirm against that class)
         */
        public void logFitEnd(int i) {
            long now = this.timeSource.currentTimeMillis();
            this.fitTimes.add(new ExampleCountEventStats(this.lastFitStartTime, now - this.lastFitStartTime, i));
        }

        /**
         * Creates the statistics object from the timestamps recorded so far.
         *
         * <p>The broadcast event runs from broadcast start to broadcast end; the init event runs
         * from broadcast end to init end. The fit list is handed over by reference, so fits
         * logged on this helper after {@code build()} are visible in the returned object.
         *
         * @return a new stats instance with single-element broadcast and init lists
         */
        public ParameterAveragingTrainingWorkerStats build() {
            // Fresh mutable lists so the result can later absorb other workers' stats.
            List<EventStats> broadcastStats = new ArrayList<>();
            broadcastStats.add(new BaseEventStats(this.broadcastStartTime, this.broadcastEndTime - this.broadcastStartTime));

            List<EventStats> initStats = new ArrayList<>();
            initStats.add(new BaseEventStats(this.broadcastEndTime, this.initEndTime - this.broadcastEndTime));

            return new ParameterAveragingTrainingWorkerStats(broadcastStats, initStats, this.fitTimes);
        }
    }

    /**
     * Creates a stats object backed directly by the given lists (no copies are made).
     *
     * @param list  broadcast "get value" events
     * @param list2 initialisation events
     * @param list3 fit events
     */
    public ParameterAveragingTrainingWorkerStats(List<EventStats> list, List<EventStats> list2, List<EventStats> list3) {
        this.parameterAveragingWorkerBroadcastGetValueTimeMs = list;
        this.parameterAveragingWorkerInitTimeMs = list2;
        this.parameterAveragingWorkerFitTimesMs = list3;
    }

    /**
     * @return the unmodifiable set of the three supported keys, in the order broadcast, init, fit
     */
    @Override
    public Set<String> getKeySet() {
        return COLUMN_NAMES;
    }

    /**
     * Returns the event list stored under the given key.
     *
     * @param str one of the {@code PARAMETER_AVERAGING_WORKER_*} key constants
     * @return the backing list for that key (may be {@code null})
     * @throws IllegalArgumentException if the key is not one of the supported keys
     * @throws NullPointerException if {@code str} is {@code null}
     */
    @Override
    public List<EventStats> getValue(String str) {
        switch (str) {
            case PARAMETER_AVERAGING_WORKER_BROADCAST_GET_VALUE_TIME_MS:
                return this.parameterAveragingWorkerBroadcastGetValueTimeMs;
            case PARAMETER_AVERAGING_WORKER_INIT_TIME_MS:
                return this.parameterAveragingWorkerInitTimeMs;
            case PARAMETER_AVERAGING_WORKER_FIT_TIMES_MS:
                return this.parameterAveragingWorkerFitTimesMs;
            default:
                throw new IllegalArgumentException("Unknown key: \"" + str + "\"");
        }
    }

    /**
     * Returns the short display name for the given key.
     *
     * @param str one of the {@code PARAMETER_AVERAGING_WORKER_*} key constants
     * @return {@code "BroadcastGet"}, {@code "ModelInit"} or {@code "Fit"}
     * @throws IllegalArgumentException if the key is not one of the supported keys
     * @throws NullPointerException if {@code str} is {@code null}
     */
    @Override
    public String getShortNameForKey(String str) {
        switch (str) {
            case PARAMETER_AVERAGING_WORKER_BROADCAST_GET_VALUE_TIME_MS:
                return "BroadcastGet";
            case PARAMETER_AVERAGING_WORKER_INIT_TIME_MS:
                return "ModelInit";
            case PARAMETER_AVERAGING_WORKER_FIT_TIMES_MS:
                return "Fit";
            default:
                throw new IllegalArgumentException("Unknown key: \"" + str + "\"");
        }
    }

    /**
     * Reports whether the given key is included in plots by default.
     *
     * @param str one of the {@code PARAMETER_AVERAGING_WORKER_*} key constants
     * @return {@code true} for every supported key
     * @throws IllegalArgumentException if the key is not one of the supported keys
     * @throws NullPointerException if {@code str} is {@code null}
     */
    @Override
    public boolean defaultIncludeInPlots(String str) {
        switch (str) {
            case PARAMETER_AVERAGING_WORKER_BROADCAST_GET_VALUE_TIME_MS:
            case PARAMETER_AVERAGING_WORKER_INIT_TIME_MS:
            case PARAMETER_AVERAGING_WORKER_FIT_TIMES_MS:
                return true;
            default:
                throw new IllegalArgumentException("Unknown key: \"" + str + "\"");
        }
    }

    /**
     * Appends the events of another worker-stats object to this one.
     *
     * <p>A {@code null} list on the other side is skipped; a {@code null} list on this side is
     * replaced by a new list holding the other side's events.
     *
     * @param sparkTrainingStats the stats to merge in; must be a
     *                           {@code ParameterAveragingTrainingWorkerStats}
     * @throws IllegalArgumentException if the argument is {@code null} or of a different type
     */
    @Override
    public void addOtherTrainingStats(SparkTrainingStats sparkTrainingStats) {
        if (!(sparkTrainingStats instanceof ParameterAveragingTrainingWorkerStats)) {
            throw new IllegalArgumentException("Cannot merge ParameterAveragingTrainingWorkerStats with " + (sparkTrainingStats != null ? sparkTrainingStats.getClass() : null));
        }
        ParameterAveragingTrainingWorkerStats other = (ParameterAveragingTrainingWorkerStats) sparkTrainingStats;
        this.parameterAveragingWorkerBroadcastGetValueTimeMs =
                merge(this.parameterAveragingWorkerBroadcastGetValueTimeMs, other.parameterAveragingWorkerBroadcastGetValueTimeMs);
        this.parameterAveragingWorkerInitTimeMs =
                merge(this.parameterAveragingWorkerInitTimeMs, other.parameterAveragingWorkerInitTimeMs);
        this.parameterAveragingWorkerFitTimesMs =
                merge(this.parameterAveragingWorkerFitTimesMs, other.parameterAveragingWorkerFitTimesMs);
    }

    /** Adds {@code addition} to {@code target} in place, tolerating {@code null} on either side. */
    private static List<EventStats> merge(List<EventStats> target, List<EventStats> addition) {
        if (addition == null) {
            return target;
        }
        if (target == null) {
            // Copy rather than alias so later merges into this object do not mutate the other one.
            return new ArrayList<>(addition);
        }
        target.addAll(addition);
        return target;
    }

    /**
     * @return always {@code null}; this stats object has no nested stats
     */
    @Override
    public SparkTrainingStats getNestedTrainingStats() {
        return null;
    }

    /**
     * Renders the three statistics as text, one line per key.
     *
     * <p>Each line starts with the key formatted through
     * {@link SparkTrainingStats#DEFAULT_PRINT_FORMAT}, followed by either {@code "-"} (list is
     * {@code null}) or the result of {@link StatsUtils#getDurationAsString} for the list.
     *
     * @return the multi-line summary, terminated by a newline
     */
    @Override
    public String statsAsString() {
        StringBuilder sb = new StringBuilder();
        appendStat(sb, PARAMETER_AVERAGING_WORKER_BROADCAST_GET_VALUE_TIME_MS, this.parameterAveragingWorkerBroadcastGetValueTimeMs);
        appendStat(sb, PARAMETER_AVERAGING_WORKER_INIT_TIME_MS, this.parameterAveragingWorkerInitTimeMs);
        appendStat(sb, PARAMETER_AVERAGING_WORKER_FIT_TIMES_MS, this.parameterAveragingWorkerFitTimesMs);
        return sb.toString();
    }

    /** Appends one "key: durations" line for {@link #statsAsString()}. */
    private static void appendStat(StringBuilder sb, String key, List<EventStats> stats) {
        sb.append(String.format(SparkTrainingStats.DEFAULT_PRINT_FORMAT, key));
        if (stats == null) {
            sb.append("-\n");
        } else {
            sb.append(StatsUtils.getDurationAsString(stats, DEFAULT_DELIMITER)).append("\n");
        }
    }

    /**
     * Exports each of the three event lists through {@link StatsUtils#exportStats}, using the
     * {@code FILENAME_*} constants and {@link #DEFAULT_DELIMITER}.
     *
     * @param str          forwarded to {@code StatsUtils.exportStats} (presumably the output
     *                     directory - confirm against that method)
     * @param sparkContext Spark context forwarded to {@code StatsUtils.exportStats}
     * @throws IOException if an export fails
     */
    @Override
    public void exportStatFiles(String str, SparkContext sparkContext) throws IOException {
        StatsUtils.exportStats(this.parameterAveragingWorkerBroadcastGetValueTimeMs, str, FILENAME_BROADCAST_GET_STATS, DEFAULT_DELIMITER, sparkContext);
        StatsUtils.exportStats(this.parameterAveragingWorkerInitTimeMs, str, FILENAME_INIT_STATS, DEFAULT_DELIMITER, sparkContext);
        StatsUtils.exportStats(this.parameterAveragingWorkerFitTimesMs, str, FILENAME_FIT_STATS, DEFAULT_DELIMITER, sparkContext);
    }

    /** @return the backing list of broadcast "get value" events (may be {@code null}) */
    public List<EventStats> getParameterAveragingWorkerBroadcastGetValueTimeMs() {
        return this.parameterAveragingWorkerBroadcastGetValueTimeMs;
    }

    /** @return the backing list of initialisation events (may be {@code null}) */
    public List<EventStats> getParameterAveragingWorkerInitTimeMs() {
        return this.parameterAveragingWorkerInitTimeMs;
    }

    /** @return the backing list of fit events (may be {@code null}) */
    public List<EventStats> getParameterAveragingWorkerFitTimesMs() {
        return this.parameterAveragingWorkerFitTimesMs;
    }

    /** @param list the new broadcast "get value" event list, stored by reference */
    public void setParameterAveragingWorkerBroadcastGetValueTimeMs(List<EventStats> list) {
        this.parameterAveragingWorkerBroadcastGetValueTimeMs = list;
    }

    /** @param list the new initialisation event list, stored by reference */
    public void setParameterAveragingWorkerInitTimeMs(List<EventStats> list) {
        this.parameterAveragingWorkerInitTimeMs = list;
    }

    /** @param list the new fit event list, stored by reference */
    public void setParameterAveragingWorkerFitTimesMs(List<EventStats> list) {
        this.parameterAveragingWorkerFitTimesMs = list;
    }

    /**
     * Two instances are equal when all three event lists are equal (or both {@code null}) and
     * the other object accepts this one through {@link #canEqual(Object)}.
     */
    @Override
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof ParameterAveragingTrainingWorkerStats)) {
            return false;
        }
        ParameterAveragingTrainingWorkerStats other = (ParameterAveragingTrainingWorkerStats) obj;
        if (!other.canEqual(this)) {
            return false;
        }
        return nullSafeEquals(getParameterAveragingWorkerBroadcastGetValueTimeMs(), other.getParameterAveragingWorkerBroadcastGetValueTimeMs())
                && nullSafeEquals(getParameterAveragingWorkerInitTimeMs(), other.getParameterAveragingWorkerInitTimeMs())
                && nullSafeEquals(getParameterAveragingWorkerFitTimesMs(), other.getParameterAveragingWorkerFitTimesMs());
    }

    /**
     * Hook that lets subclasses refuse equality with instances of other types.
     *
     * @param obj the object that is being compared to this one
     * @return {@code true} if {@code obj} is a {@code ParameterAveragingTrainingWorkerStats}
     */
    protected boolean canEqual(Object obj) {
        return obj instanceof ParameterAveragingTrainingWorkerStats;
    }

    /** Hash over the three event lists; consistent with {@link #equals(Object)}. */
    @Override
    public int hashCode() {
        int result = 1;
        result = result * 59 + nullSafeHash(getParameterAveragingWorkerBroadcastGetValueTimeMs());
        result = result * 59 + nullSafeHash(getParameterAveragingWorkerInitTimeMs());
        result = result * 59 + nullSafeHash(getParameterAveragingWorkerFitTimesMs());
        return result;
    }

    /** {@code equals} that treats two {@code null} lists as equal. */
    private static boolean nullSafeEquals(List<EventStats> a, List<EventStats> b) {
        return a == null ? b == null : a.equals(b);
    }

    /** Hash of a list, with the fixed value 43 standing in for {@code null}. */
    private static int nullSafeHash(List<EventStats> list) {
        return list == null ? 43 : list.hashCode();
    }

    @Override
    public String toString() {
        return "ParameterAveragingTrainingWorkerStats(parameterAveragingWorkerBroadcastGetValueTimeMs=" + getParameterAveragingWorkerBroadcastGetValueTimeMs() + ", parameterAveragingWorkerInitTimeMs=" + getParameterAveragingWorkerInitTimeMs() + ", parameterAveragingWorkerFitTimesMs=" + getParameterAveragingWorkerFitTimesMs() + ")";
    }
}
