package ai.djl.training.listener;

import ai.djl.Device;
import ai.djl.engine.Engine;
import ai.djl.metric.Metrics;
import ai.djl.training.Trainer;
import ai.djl.training.evaluator.Evaluator;
import ai.djl.training.listener.TrainingListener;
import ai.djl.training.loss.Loss;
import ai.djl.training.util.ProgressBar;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:ai/djl/training/listener/LoggingTrainingListener.class */
/**
 * A {@link TrainingListener} that reports training progress through SLF4J and console progress
 * bars.
 *
 * <p>It shows a progress bar for training and for validation batches, logs evaluator values at
 * the end of every epoch, and logs P50/P90 latencies of the recorded timing metrics when
 * training ends.
 */
public class LoggingTrainingListener implements TrainingListener {

    private static final Logger logger = LoggerFactory.getLogger(LoggingTrainingListener.class);

    /** Maximum number of evaluators shown in the per-batch training progress message. */
    private static final int PROGRESS_EVALUATOR_LIMIT = 2;

    // Timing metrics are recorded in nanoseconds; these convert them for display.
    private static final float NANOS_PER_MILLI = 1_000_000f;
    private static final float NANOS_PER_SECOND = 1_000_000_000f;

    private int numEpochs;
    private ProgressBar trainingProgressBar;
    private ProgressBar validateProgressBar;

    /**
     * Logs the epoch number and the train (and, if available, validation) evaluator values.
     *
     * @param trainer the trainer that finished an epoch
     */
    @Override // ai.djl.training.listener.TrainingListener
    public void onEpoch(Trainer trainer) {
        // Count the epoch first so the counter stays correct even if status logging fails.
        numEpochs++;
        logger.info("Epoch {} finished.", numEpochs);

        Metrics metrics = trainer.getMetrics();
        if (metrics == null) {
            return;
        }
        List<Evaluator> evaluators = trainer.getEvaluators();
        logger.info(
                "Train: {}",
                getEvaluatorsStatus(
                        metrics, evaluators, EvaluatorTrainingListener.TRAIN_EPOCH, Short.MAX_VALUE));

        // The presence of a validation-epoch loss metric tells us validation actually ran.
        String validateLossName =
                EvaluatorTrainingListener.metricName(
                        trainer.getLoss(), EvaluatorTrainingListener.VALIDATE_EPOCH);
        if (metrics.hasMetric(validateLossName)) {
            logger.info(
                    "Validate: {}",
                    getEvaluatorsStatus(
                            metrics,
                            evaluators,
                            EvaluatorTrainingListener.VALIDATE_EPOCH,
                            Short.MAX_VALUE));
        } else {
            logger.info("validation has not been run.");
        }
    }

    /**
     * Advances the training progress bar, creating it lazily on the first batch.
     *
     * @param trainer the active trainer
     * @param batchData the batch that was just trained
     */
    @Override // ai.djl.training.listener.TrainingListener
    public void onTrainingBatch(Trainer trainer, TrainingListener.BatchData batchData) {
        if (trainingProgressBar == null) {
            trainingProgressBar =
                    new ProgressBar("Training", batchData.getBatch().getProgressTotal());
        }
        trainingProgressBar.update(
                batchData.getBatch().getProgress(),
                getTrainingStatus(trainer, batchData.getBatch().getSize()));
    }

    /**
     * Builds the suffix shown next to the training progress bar: a few evaluator values plus the
     * throughput of the latest batch when a "train" timing is available.
     *
     * @param trainer the active trainer
     * @param batchSize number of items in the batch just processed
     * @return the status text, or an empty string when the trainer has no metrics
     */
    private String getTrainingStatus(Trainer trainer, int batchSize) {
        Metrics metrics = trainer.getMetrics();
        if (metrics == null) {
            return "";
        }
        StringBuilder sb =
                new StringBuilder(
                        getEvaluatorsStatus(
                                metrics,
                                trainer.getEvaluators(),
                                EvaluatorTrainingListener.TRAIN_PROGRESS,
                                PROGRESS_EVALUATOR_LIMIT));
        if (metrics.hasMetric("train")) {
            float seconds =
                    metrics.latestMetric("train").getValue().longValue() / NANOS_PER_SECOND;
            // A zero duration would report "Infinity items/sec"; omit the speed instead.
            if (seconds > 0) {
                sb.append(String.format(", speed: %.2f items/sec", batchSize / seconds));
            }
        }
        return sb.toString();
    }

    /**
     * Advances the validation progress bar, creating it lazily on the first batch.
     *
     * @param trainer the active trainer
     * @param batchData the batch that was just validated
     */
    @Override // ai.djl.training.listener.TrainingListener
    public void onValidationBatch(Trainer trainer, TrainingListener.BatchData batchData) {
        if (validateProgressBar == null) {
            validateProgressBar =
                    new ProgressBar("Validating", batchData.getBatch().getProgressTotal());
        }
        validateProgressBar.update(batchData.getBatch().getProgress());
    }

    /**
     * Logs the devices used for training and the engine name, version and lookup time.
     *
     * @param trainer the trainer about to start
     */
    @Override // ai.djl.training.listener.TrainingListener
    public void onTrainingBegin(Trainer trainer) {
        List<Device> devices = trainer.getDevices();
        String devicesMsg;
        if (devices.size() == 1 && Device.Type.CPU.equals(devices.get(0).getDeviceType())) {
            devicesMsg = Device.cpu().toString();
        } else {
            devicesMsg = devices.size() + " GPUs";
        }
        logger.info("Training on: {}.", devicesMsg);

        // Take the start timestamp BEFORE obtaining the engine so the elapsed time actually
        // covers the engine lookup (two back-to-back nanoTime() calls always yield <= 0).
        long begin = System.nanoTime();
        Engine engine = Engine.getInstance();
        String engineName = engine.getEngineName();
        String version = engine.getVersion();
        float loadMillis = (System.nanoTime() - begin) / NANOS_PER_MILLI;
        logger.info(
                String.format(
                        "Load %s Engine Version %s in %.3f ms.", engineName, version, loadMillis));
    }

    /**
     * Logs P50/P90 latencies for each timing metric recorded during training.
     *
     * @param trainer the trainer that finished training
     */
    @Override // ai.djl.training.listener.TrainingListener
    public void onTrainingEnd(Trainer trainer) {
        Metrics metrics = trainer.getMetrics();
        if (metrics == null) {
            return;
        }
        logPercentiles(metrics, "train", "ms", NANOS_PER_MILLI);
        logPercentiles(metrics, "forward", "ms", NANOS_PER_MILLI);
        logPercentiles(metrics, "training-metrics", "ms", NANOS_PER_MILLI);
        logPercentiles(metrics, "backward", "ms", NANOS_PER_MILLI);
        logPercentiles(metrics, "step", "ms", NANOS_PER_MILLI);
        logPercentiles(metrics, "epoch", "s", NANOS_PER_SECOND);
    }

    /**
     * Logs "{name} P50: x unit, P90: y unit" for one timing metric, if it was recorded.
     *
     * <p>Every metric is guarded with {@code hasMetric} (previously only "train" was). A metric
     * that was never recorded is therefore skipped instead of being passed to {@code
     * percentile}. NOTE(review): this assumes {@code percentile} rejects unknown metric names;
     * confirm against the Metrics implementation.
     *
     * @param metrics the metrics collected during training
     * @param metricName the timing metric to summarize; values are in nanoseconds
     * @param unit the display unit label
     * @param nanosPerUnit divisor converting nanoseconds to {@code unit}
     */
    private static void logPercentiles(
            Metrics metrics, String metricName, String unit, float nanosPerUnit) {
        if (!metrics.hasMetric(metricName)) {
            return;
        }
        float p50 = metrics.percentile(metricName, 50).getValue().longValue() / nanosPerUnit;
        float p90 = metrics.percentile(metricName, 90).getValue().longValue() / nanosPerUnit;
        logger.info(
                String.format(
                        "%s P50: %.3f %s, P90: %.3f %s", metricName, p50, unit, p90, unit));
    }

    /**
     * Formats the latest value of each evaluator for the given stage as "name: value" pairs.
     *
     * <p>Evaluators without a recorded value for the stage are shown as "name: _". At most
     * {@code limit} evaluators are listed; if there are more, "..." is appended.
     *
     * @param metrics the metrics to read from
     * @param evaluators the evaluators to report
     * @param stage the stage suffix used to build each evaluator's metric name
     * @param limit the maximum number of evaluators to list
     * @return the comma-separated status string (empty if there are no evaluators)
     */
    private String getEvaluatorsStatus(
            Metrics metrics, List<Evaluator> evaluators, String stage, int limit) {
        List<String> parts = new ArrayList<>();
        int count = 0;
        for (Evaluator evaluator : evaluators) {
            count++;
            if (count > limit) {
                parts.add("...");
                break;
            }
            String metricName = EvaluatorTrainingListener.metricName(evaluator, stage);
            if (metrics.hasMetric(metricName)) {
                float value = metrics.latestMetric(metricName).getValue().floatValue();
                parts.add(String.format("%s: %.2f", evaluator.getName(), value));
            } else {
                parts.add(String.format("%s: _", evaluator.getName()));
            }
        }
        return String.join(", ", parts);
    }
}
