package ai.djl.mxnet.engine;

import ai.djl.mxnet.jna.JnaUtils;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.training.GradientCollector;

/* loaded from: input_file:ai/djl/mxnet/engine/MxGradientCollector.class */
public class MxGradientCollector implements GradientCollector {

    /**
     * Creates a gradient collector and turns on autograd recording and training mode.
     *
     * <p>Intended to be used with try-with-resources so that {@link #close()} restores both flags.
     *
     * <p>If construction fails because training mode was already on, the recording flag enabled
     * by this constructor is switched back off before the exception is thrown. Without this, the
     * recording flag would stay on with no collector left to close it, and every later
     * construction attempt would fail.
     *
     * <p>NOTE(review): the decompiler reports this constructor was package-private in the original
     * build; it is kept {@code public} here so existing callers keep compiling.
     *
     * @throws IllegalStateException if autograd recording or training was already enabled
     */
    public MxGradientCollector() {
        // setRecording/setTraining return the previous value of the flag.
        boolean wasRecording = setRecording(true);
        if (wasRecording) {
            // Recording was already on before this call, so it belongs to someone else; leave it.
            throw new IllegalStateException("Autograd Recording is already set to True. Please create autograd using try with resource ");
        }
        boolean wasTraining = setTraining(true);
        if (wasTraining) {
            // Roll back the recording flag this constructor turned on, otherwise it would leak.
            setRecording(false);
            throw new IllegalStateException("Autograd Training is already set to True. Please create autograd using try with resource ");
        }
    }

    /**
     * Reports whether autograd recording is currently enabled.
     *
     * @return the current recording flag as reported by the native engine
     */
    public static boolean isRecording() {
        return JnaUtils.autogradIsRecording();
    }

    /**
     * Reports whether autograd training mode is currently enabled.
     *
     * @return the current training flag as reported by the native engine
     */
    public static boolean isTraining() {
        return JnaUtils.autogradIsTraining();
    }

    /**
     * Sets the autograd recording flag.
     *
     * @param z the new recording state
     * @return the previous recording state
     */
    public static boolean setRecording(boolean z) {
        return JnaUtils.autogradSetIsRecording(z);
    }

    /**
     * Sets the autograd training flag.
     *
     * @param z the new training state
     * @return the previous training state
     */
    public static boolean setTraining(boolean z) {
        return JnaUtils.autogradSetTraining(z);
    }

    /**
     * Wraps the autograd symbol recorded for {@code nDArray} in a {@link Symbol}.
     *
     * @param nDManager the manager that will own the symbol; must be an {@code MxNDManager}
     * @param nDArray the array whose recorded symbol is requested
     * @return the symbol for {@code nDArray}
     * @throws ClassCastException if {@code nDManager} is not an {@code MxNDManager}
     */
    public static Symbol getSymbol(NDManager nDManager, NDArray nDArray) {
        return new Symbol((MxNDManager) nDManager, JnaUtils.autogradGetSymbol(nDArray));
    }

    /**
     * Turns autograd recording and training mode off.
     *
     * <p>Training mode is reset even if resetting the recording flag throws.
     */
    public void close() {
        try {
            setRecording(false);
        } finally {
            setTraining(false);
        }
    }

    /**
     * Runs backpropagation from {@code nDArray}, passing a {@code false} flag (0) to the
     * native backward call.
     *
     * @param nDArray the array to backpropagate from
     */
    public void backward(NDArray nDArray) {
        backward(nDArray, false);
    }

    /**
     * Runs backpropagation from {@code array}.
     *
     * @param array the array to backpropagate from
     * @param retainGraph flag forwarded to the native call as 1/0; presumably asks the engine to
     *     keep the recorded graph after the pass (TODO confirm against JnaUtils.autogradBackward)
     */
    private void backward(NDArray array, boolean retainGraph) {
        JnaUtils.autogradBackward(new NDList(new NDArray[] {array}), retainGraph ? 1 : 0);
    }
}
