package ai.djl.pytorch.engine;

import ai.djl.ndarray.NDArray;
import ai.djl.pytorch.jni.JniUtils;
import ai.djl.training.GradientCollector;

/* loaded from: input_file:ai/djl/pytorch/engine/PtGradientCollector.class */
/**
 * {@link GradientCollector} implementation for the PyTorch engine.
 *
 * <p>The backward pass itself is delegated to {@link JniUtils#backward}; this class only
 * prepares the arguments for that call.
 */
public class PtGradientCollector implements GradientCollector {

    /**
     * Runs a backward pass starting from the given array.
     *
     * <p>An all-ones array with the same shape and data type as {@code nDArray} is created through
     * the array's own manager, moved to the array's device, and passed along as the second
     * argument of the native backward call. Both boolean flags of that call are {@code false}.
     *
     * @param nDArray the array to run backward on; it is cast to {@link PtNDArray} further down,
     *     so it must be a PyTorch engine array
     */
    public void backward(NDArray nDArray) {
        NDArray onesLikeTarget =
                nDArray.getManager().ones(nDArray.getShape(), nDArray.getDataType());
        // NOTE(review): the second argument of toDevice is presumably a "force copy" flag;
        // confirm against NDArray#toDevice.
        NDArray seedGradient = onesLikeTarget.toDevice(nDArray.getDevice(), false);
        backward(nDArray, seedGradient, false, false);
    }

    /**
     * Forwards a backward request to the native layer.
     *
     * <p>NOTE(review): the two boolean parameter names are an assumption based on the usual
     * PyTorch backward arguments (retain graph / create graph); confirm against
     * {@code JniUtils.backward}.
     *
     * @param target the array to run backward on; must be a {@link PtNDArray}
     * @param grad the array handed to the native call as second argument; must be a {@link
     *     PtNDArray}
     * @param keepGraph first boolean flag, passed through unchanged
     * @param createGraph second boolean flag, passed through unchanged
     */
    private void backward(NDArray target, NDArray grad, boolean keepGraph, boolean createGraph) {
        PtNDArray ptTarget = (PtNDArray) target;
        PtNDArray ptGrad = (PtNDArray) grad;
        JniUtils.backward(ptTarget, ptGrad, keepGraph, createGraph);
    }

    /** Does nothing: this collector holds no state of its own to release. */
    public void close() {
        // Intentionally empty.
    }
}
