package ai.djl.mxnet.engine;

import ai.djl.Device;
import ai.djl.Model;
import ai.djl.TrainingDivergedException;
import ai.djl.metric.Metrics;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDArrays;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Parameter;
import ai.djl.training.GradientCollector;
import ai.djl.training.LocalParameterServer;
import ai.djl.training.ParameterStore;
import ai.djl.training.Trainer;
import ai.djl.training.TrainingConfig;
import ai.djl.training.TrainingListener;
import ai.djl.training.dataset.Batch;
import ai.djl.training.loss.Loss;
import ai.djl.training.metrics.TrainingMetric;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:ai/djl/mxnet/engine/MxTrainer.class */
/**
 * MXNet implementation of {@link Trainer}.
 *
 * <p>The trainer owns a sub-manager of the model's {@link NDManager}, a {@link ParameterStore}
 * backed by a {@link LocalParameterServer}, and two parallel lists of metrics (training and
 * validation). The loss function is always the last entry of each metric list.
 */
public class MxTrainer implements Trainer {
    private static final Logger logger = LoggerFactory.getLogger(MxTrainer.class);

    private final MxModel model;
    private final MxNDManager manager;
    private Metrics metrics;
    private TrainingListener listener;
    private final Device[] devices;
    private final ParameterStore parameterStore;
    private final List<TrainingMetric> trainingMetrics;
    private final List<TrainingMetric> validateMetrics;
    private final Loss trainingLoss;
    private final Loss validationLoss;
    // Timestamp (System.nanoTime) taken at the end of the previous trainBatch call;
    // 0 means that no batch has completed yet.
    long batchBeginTime;
    // Set once checkGradients() has passed, so the check only runs on the first step().
    private boolean gradientsChecked;

    /**
     * Creates a trainer for the given model.
     *
     * @param mxModel the model to train
     * @param trainingConfig supplies the devices, loss function, training metrics and optimizer
     * @throws IllegalArgumentException if the configuration has no loss function
     */
    public MxTrainer(MxModel mxModel, TrainingConfig trainingConfig) {
        this.model = mxModel;
        this.manager = (MxNDManager) mxModel.getNDManager().newSubManager();
        this.devices = trainingConfig.getDevices();
        this.trainingLoss = trainingConfig.getLossFunction();
        if (this.trainingLoss == null) {
            throw new IllegalArgumentException("You must specify a loss for the trainer");
        }
        this.validationLoss = this.trainingLoss.duplicate();

        this.trainingMetrics = new ArrayList<>(trainingConfig.getTrainingMetrics());
        this.validateMetrics = new ArrayList<>();
        // Validation gets its own copy of every configured metric.
        for (TrainingMetric metric : this.trainingMetrics) {
            this.validateMetrics.add(metric.duplicate());
        }
        // The losses are appended last so they are updated and reported like any other metric.
        this.trainingMetrics.add(this.trainingLoss);
        this.validateMetrics.add(this.validationLoss);

        LocalParameterServer parameterServer =
                new LocalParameterServer(trainingConfig.getOptimizer());
        this.parameterStore = new ParameterStore(this.manager, false);
        this.parameterStore.setParameterServer(parameterServer, this.devices);
    }

    /**
     * Initializes the model's block for the given input shapes and then requests every
     * parameter's value from the parameter store once per configured device.
     *
     * @param shapeArr the input shapes passed to the block's initializer
     */
    public void initialize(Shape... shapeArr) {
        this.model.getBlock().initialize(this.model.getNDManager(), this.model.getDataType(), shapeArr);
        this.model.getBlock().getParameters().forEach(pair -> {
            for (Device device : this.devices) {
                this.parameterStore.getValue((Parameter) pair.getValue(), device);
            }
        });
    }

    /**
     * Creates a new gradient collector.
     *
     * @return a new {@link MxGradientCollector}
     */
    public GradientCollector newGradientCollector() {
        return new MxGradientCollector();
    }

    /**
     * Runs forward and backward passes for one batch, split across the configured devices.
     *
     * <p>For each split the time spent in backward and in the metric update is recorded under
     * {@code "backward"} and {@code "training-metrics"}. The {@code "train"} timing measures the
     * time elapsed since the previous call to this method finished.
     *
     * @param batch the batch to train on
     */
    public void trainBatch(Batch batch) {
        Batch[] splits = batch.split(this.devices, false);
        try (MxGradientCollector collector = new MxGradientCollector()) {
            for (Batch split : splits) {
                NDList data = split.getData();
                NDList labels = split.getLabels();
                NDList predictions = forward(data);

                long backwardBegin = System.nanoTime();
                collector.backward(this.trainingLoss.getLoss(labels, predictions));
                addMetric("backward", backwardBegin);

                long metricsBegin = System.nanoTime();
                updateTrainingMetrics(labels, predictions);
                addMetric("training-metrics", metricsBegin);
            }
            addMetric("train", this.batchBeginTime);
            this.batchBeginTime = System.nanoTime();
            if (this.listener != null) {
                this.listener.onTrainingBatch();
            }
        }
    }

    /**
     * Runs the model's block on the given input and records the elapsed time under
     * {@code "forward"}, whether or not the block throws.
     *
     * @param nDList the input to the block
     * @return the block's output
     */
    public NDList forward(NDList nDList) {
        long begin = System.nanoTime();
        try {
            return this.model.getBlock().forward(this.parameterStore, nDList);
        } finally {
            addMetric("forward", begin);
        }
    }

    /**
     * Runs a forward pass for one batch, split across the configured devices, and updates the
     * validation metrics. The total time is recorded under {@code "validate"}.
     *
     * @param batch the batch to validate on
     */
    public void validateBatch(Batch batch) {
        long begin = System.nanoTime();
        for (Batch split : batch.split(this.devices, false)) {
            updateValidationMetrics(split.getLabels(), forward(split.getData()));
        }
        addMetric("validate", begin);
        if (this.listener != null) {
            this.listener.onValidationBatch();
        }
    }

    /**
     * Updates all parameters through the parameter store and records the elapsed time under
     * {@code "step"}. Until the gradient sanity check has passed once, it is run first.
     *
     * @throws IllegalStateException if the gradient sanity check finds a zero gradient sum
     */
    public void step() {
        if (!this.gradientsChecked) {
            checkGradients();
        }
        long begin = System.nanoTime();
        this.parameterStore.updateAllParameters();
        addMetric("step", begin);
    }

    /**
     * Sets the sink that timings and metric values are reported to.
     *
     * @param metrics the metrics sink, or {@code null} to disable reporting
     */
    public void setMetrics(Metrics metrics) {
        this.metrics = metrics;
    }

    /**
     * Sets the listener notified after training batches, validation batches and metric resets.
     *
     * @param trainingListener the listener, or {@code null} for none
     */
    public void setTrainingListener(TrainingListener trainingListener) {
        this.listener = trainingListener;
    }

    /**
     * Updates and reports the training metrics with recording and training mode switched off,
     * then switches both back on.
     *
     * @throws TrainingDivergedException if the training loss value is NaN after the update
     */
    private void updateTrainingMetrics(NDList labels, NDList predictions) {
        MxGradientCollector.setRecording(false);
        MxGradientCollector.setTraining(false);
        for (TrainingMetric metric : this.trainingMetrics) {
            metric.update(labels, predictions);
        }
        // The loss is reported here, before the NaN check, and once more in the loop below
        // because trainingLoss is also a member of trainingMetrics.
        addMetric("train", this.trainingLoss);
        if (Float.isNaN(this.trainingLoss.getValue())) {
            throw new TrainingDivergedException("The Loss became NaN, try reduce learning rate,add clipGradient option to your optimizer, check input data and loss calculation.");
        }
        for (TrainingMetric metric : this.trainingMetrics) {
            addMetric("train", metric);
        }
        MxGradientCollector.setRecording(true);
        MxGradientCollector.setTraining(true);
    }

    /** Updates every validation metric and then reports each under the "validate" prefix. */
    private void updateValidationMetrics(NDList labels, NDList predictions) {
        for (TrainingMetric metric : this.validateMetrics) {
            metric.update(labels, predictions);
        }
        for (TrainingMetric metric : this.validateMetrics) {
            addMetric("validate", metric);
        }
    }

    /** Resets all training and validation metrics and notifies the listener of an epoch. */
    public void resetTrainingMetrics() {
        this.trainingMetrics.forEach(TrainingMetric::reset);
        this.validateMetrics.forEach(TrainingMetric::reset);
        if (this.listener != null) {
            this.listener.onEpoch();
        }
    }

    /**
     * Returns the loss used for training.
     *
     * @return the training loss
     */
    public Loss getLoss() {
        return this.trainingLoss;
    }

    /**
     * Returns the duplicate of the training loss that is used for validation.
     *
     * @return the validation loss
     */
    public Loss getValidationLoss() {
        return this.validationLoss;
    }

    /**
     * Returns the model being trained.
     *
     * @return the model
     */
    public Model getModel() {
        return this.model;
    }

    /**
     * Returns the metrics sink.
     *
     * @return the metrics sink, or {@code null} if none was set
     */
    public Metrics getMetrics() {
        return this.metrics;
    }

    /**
     * Finds the first training metric that is an instance of the given class.
     *
     * @param cls the metric class to look for
     * @param <T> the metric type
     * @return the first matching metric, or {@code null} if there is none
     */
    public final <T extends TrainingMetric> T getTrainingMetric(Class<T> cls) {
        return findMetric(this.trainingMetrics, cls);
    }

    /**
     * Finds the first validation metric that is an instance of the given class.
     *
     * @param cls the metric class to look for
     * @param <T> the metric type
     * @return the first matching metric, or {@code null} if there is none
     */
    public <T extends TrainingMetric> T getValidationMetric(Class<T> cls) {
        return findMetric(this.validateMetrics, cls);
    }

    /** Returns the first element of {@code candidates} that is an instance of {@code cls}, or null. */
    private static <T extends TrainingMetric> T findMetric(
            List<TrainingMetric> candidates, Class<T> cls) {
        for (TrainingMetric candidate : candidates) {
            if (cls.isInstance(candidate)) {
                return cls.cast(candidate);
            }
        }
        return null;
    }

    /**
     * Returns the trainer's own manager, a sub-manager of the model's manager.
     *
     * @return the trainer's manager
     */
    public NDManager getManager() {
        return this.manager;
    }

    /**
     * Sanity check run before the first parameter update: sums the gradients of every parameter
     * that requires a gradient (as seen on the first device) and fails if the total is zero.
     *
     * <p>NOTE(review): the check adds signed sums, so gradients that cancel out exactly would
     * also be reported as "all zeros" — confirm whether that is acceptable.
     *
     * @throws IllegalStateException if the total of all gradient sums is exactly zero
     */
    private void checkGradients() {
        List<NDArray> gradients = new ArrayList<>();
        for (Parameter parameter : this.model.getBlock().getParameters().values()) {
            if (parameter.requireGradient()) {
                gradients.add(this.parameterStore.getValue(parameter, this.devices[0]).getGradient());
            }
        }

        NDList perParameterSums =
                new NDList(gradients.stream().map(NDArray::sum).toArray(NDArray[]::new));
        NDArray stacked;
        try {
            stacked = NDArrays.stack(perParameterSums);
        } finally {
            perParameterSums.close();
        }

        float[] values;
        try {
            NDArray reduced = stacked.sum();
            try {
                values = reduced.toFloatArray();
            } finally {
                reduced.close();
            }
        } finally {
            stacked.close();
        }

        float total = 0.0f;
        for (float value : values) {
            total += value;
        }
        if (total == 0.0f) {
            throw new IllegalStateException("Gradient values are all zeros, please call gradientCollector.backward() onyour target NDArray (usually loss), before calling step() ");
        }
        this.gradientsChecked = true;
    }

    /**
     * Closes the trainer if its manager is still open when the object is finalized.
     *
     * @throws Throwable if closing or the superclass finalizer fails
     */
    @Override
    protected void finalize() throws Throwable {
        if (this.manager.isOpen()) {
            // NOTE(review): the warning is only emitted when debug logging is enabled, and the
            // message says "Model" although this is a trainer — confirm both are intended.
            if (logger.isDebugEnabled()) {
                logger.warn("Model was not closed explicitly: {}", getClass().getSimpleName());
            }
            close();
        }
        super.finalize();
    }

    /**
     * Synchronizes the parameter store and closes the trainer's manager. The manager is closed
     * even if the synchronization throws.
     */
    public void close() {
        try {
            this.parameterStore.sync();
        } finally {
            this.manager.close();
        }
    }

    /**
     * Records the time elapsed since {@code begin} under the given name.
     *
     * <p>Nothing is recorded when no metrics sink is set or when {@code begin} is 0, which marks
     * a timer that was never started. Negative values are accepted because
     * {@link System#nanoTime()} is allowed to return them.
     *
     * @param str the metric name
     * @param j the start timestamp from {@link System#nanoTime()}, or 0 if unset
     */
    private void addMetric(String str, long j) {
        if (this.metrics == null || j == 0L) {
            return;
        }
        this.metrics.addMetric(str, Long.valueOf(System.nanoTime() - j));
    }

    /**
     * Records the current value of a metric under {@code <prefix>_<metric name>}.
     *
     * @param str the prefix, e.g. "train" or "validate"
     * @param trainingMetric the metric whose name and value are reported
     */
    private void addMetric(String str, TrainingMetric trainingMetric) {
        if (this.metrics != null) {
            this.metrics.addMetric(str + '_' + trainingMetric.getName(), Float.valueOf(trainingMetric.getValue()));
        }
    }
}
