public final class JnaUtils
extends java.lang.Object
| Modifier and Type | Class and Description |
|---|---|
static class |
JnaUtils.NumpyMode
An enum of the possible numpy mode statuses.
|
| Modifier and Type | Field and Description |
|---|---|
static java.lang.String[] |
EMPTY_ARRAY |
static java.lang.String |
MXNET_THREAD_SAFE_PREDICTOR |
| Modifier and Type | Method and Description |
|---|---|
static void |
autogradBackward(ai.djl.ndarray.NDList array,
int retainGraph) |
static void |
autogradBackwardExecute(int numOutput,
ai.djl.ndarray.NDList array,
ai.djl.ndarray.NDArray outgrad,
int numVariables,
com.sun.jna.Pointer varHandles,
int retainGraph,
int createGraph,
int isTrain,
com.sun.jna.Pointer gradHandles,
com.sun.jna.Pointer gradSparseFormat) |
static com.sun.jna.Pointer |
autogradGetSymbol(ai.djl.ndarray.NDArray array) |
static boolean |
autogradIsRecording() |
static boolean |
autogradIsTraining() |
static void |
autogradMarkVariables(int numVar,
com.sun.jna.Pointer varHandles,
java.nio.IntBuffer reqsArray,
com.sun.jna.Pointer gradHandles) |
static boolean |
autogradSetIsRecording(boolean isRecording) |
static boolean |
autogradSetTraining(boolean isTraining) |
static MxNDArray[] |
cachedOpInvoke(MxNDManager manager,
com.sun.jna.Pointer cachedOpHandle,
MxNDArray[] inputs) |
static void |
checkCall(int ret) |
static CachedOp |
createCachedOp(MxSymbolBlock block,
MxNDManager manager)
Creates a CachedOp (including its cached op flags) for the given MxSymbolBlock.
|
static com.sun.jna.Pointer |
createNdArray(ai.djl.Device device,
ai.djl.ndarray.types.Shape shape,
ai.djl.ndarray.types.DataType dtype,
int size,
boolean delayedAlloc) |
static com.sun.jna.Pointer |
createSparseNdArray(ai.djl.ndarray.types.SparseFormat fmt,
ai.djl.Device device,
ai.djl.ndarray.types.Shape shape,
ai.djl.ndarray.types.DataType dtype,
ai.djl.ndarray.types.DataType[] auxDTypes,
ai.djl.ndarray.types.Shape[] auxShapes,
boolean delayedAlloc) |
static com.sun.jna.Pointer |
createSymbolFromFile(java.lang.String path) |
static void |
freeCachedOp(com.sun.jna.Pointer handle) |
static void |
freeNdArray(com.sun.jna.Pointer ndArray) |
static void |
freeSymbol(com.sun.jna.Pointer symbol) |
static java.util.Set<java.lang.String> |
getAllOpNames() |
static ai.djl.ndarray.types.DataType |
getDataType(com.sun.jna.Pointer ndArray) |
static ai.djl.Device |
getDevice(com.sun.jna.Pointer ndArray) |
static java.util.Set<java.lang.String> |
getFeatures() |
static int |
getGpuCount() |
static long[] |
getGpuMemory(ai.djl.Device device) |
static com.sun.jna.Pointer |
getGradient(com.sun.jna.Pointer handle) |
static java.util.Map<java.lang.String,FunctionInfo> |
getNdArrayFunctions() |
static ai.djl.ndarray.types.Shape |
getShape(com.sun.jna.Pointer ndArray) |
static ai.djl.ndarray.types.SparseFormat |
getStorageType(com.sun.jna.Pointer ndArray) |
static com.sun.jna.Pointer |
getSymbolInternals(com.sun.jna.Pointer symbol) |
static com.sun.jna.Pointer |
getSymbolOutput(com.sun.jna.Pointer symbol,
int index) |
static int |
getVersion() |
static ai.djl.util.PairList<com.sun.jna.Pointer,ai.djl.ndarray.types.SparseFormat> |
imperativeInvoke(com.sun.jna.Pointer function,
PointerArray inputs,
com.sun.jna.ptr.PointerByReference destRef,
ai.djl.util.PairList<java.lang.String,?> params) |
static java.util.List<java.util.List<ai.djl.ndarray.types.Shape>> |
inferShape(Symbol symbol,
ai.djl.util.PairList<java.lang.String,ai.djl.ndarray.types.Shape> args) |
static int |
isNumpyMode() |
static java.lang.String[] |
listSymbolArguments(com.sun.jna.Pointer symbol) |
static java.lang.String[] |
listSymbolAuxiliaryStates(com.sun.jna.Pointer symbol) |
static java.lang.String[] |
listSymbolNames(com.sun.jna.Pointer symbol) |
static java.lang.String[] |
listSymbolOutputs(com.sun.jna.Pointer symbol) |
static ai.djl.ndarray.NDList |
loadNdArray(MxNDManager manager,
java.nio.file.Path path,
ai.djl.Device device) |
static void |
ndArraySyncCopyFromNdArray(MxNDArray dest,
MxNDArray src,
int location) |
static FunctionInfo |
op(java.lang.String opName) |
static void |
parameterStoreClose(com.sun.jna.Pointer handle) |
static com.sun.jna.Pointer |
parameterStoreCreate(java.lang.String type) |
static void |
parameterStoreInit(com.sun.jna.Pointer handle,
int num,
java.lang.String[] keys,
ai.djl.ndarray.NDList vals) |
static void |
parameterStorePull(com.sun.jna.Pointer handle,
int num,
int[] keys,
ai.djl.ndarray.NDList vals,
int priority) |
static void |
parameterStorePull(com.sun.jna.Pointer handle,
int num,
java.lang.String[] keys,
ai.djl.ndarray.NDList vals,
int priority) |
static void |
parameterStorePush(com.sun.jna.Pointer handle,
int num,
java.lang.String[] keys,
ai.djl.ndarray.NDList vals,
int priority) |
static void |
parameterStoreSetUpdater(com.sun.jna.Pointer handle,
MxnetLibrary.MXKVStoreUpdater updater,
MxnetLibrary.MXKVStoreStrUpdater stringUpdater,
com.sun.jna.Pointer updaterHandle) |
static void |
parameterStoreSetUpdater(com.sun.jna.Pointer handle,
MxnetLibrary.MXKVStoreUpdater updater,
com.sun.jna.Pointer updaterHandle) |
static int |
randomSeed(int seed) |
static void |
setNumpyMode(JnaUtils.NumpyMode mode) |
static void |
syncCopyFromCPU(com.sun.jna.Pointer ndArray,
java.nio.Buffer data,
int len) |
static void |
syncCopyToCPU(com.sun.jna.Pointer ndArray,
com.sun.jna.Pointer data,
int len) |
static boolean |
useThreadSafePredictor() |
static void |
waitAll() |
static void |
waitToRead(com.sun.jna.Pointer ndArray) |
static void |
waitToWrite(com.sun.jna.Pointer ndArray) |
public static final java.lang.String[] EMPTY_ARRAY
public static final java.lang.String MXNET_THREAD_SAFE_PREDICTOR
public static int getVersion()
public static java.util.Set<java.lang.String> getAllOpNames()
public static java.util.Map<java.lang.String,FunctionInfo> getNdArrayFunctions()
public static FunctionInfo op(java.lang.String opName)
public static int getGpuCount()
public static long[] getGpuMemory(ai.djl.Device device)
public static java.util.Set<java.lang.String> getFeatures()
public static int randomSeed(int seed)
public static com.sun.jna.Pointer createNdArray(ai.djl.Device device,
ai.djl.ndarray.types.Shape shape,
ai.djl.ndarray.types.DataType dtype,
int size,
boolean delayedAlloc)
public static com.sun.jna.Pointer createSparseNdArray(ai.djl.ndarray.types.SparseFormat fmt,
ai.djl.Device device,
ai.djl.ndarray.types.Shape shape,
ai.djl.ndarray.types.DataType dtype,
ai.djl.ndarray.types.DataType[] auxDTypes,
ai.djl.ndarray.types.Shape[] auxShapes,
boolean delayedAlloc)
public static void ndArraySyncCopyFromNdArray(MxNDArray dest, MxNDArray src, int location)
public static ai.djl.ndarray.NDList loadNdArray(MxNDManager manager, java.nio.file.Path path, ai.djl.Device device)
public static void freeNdArray(com.sun.jna.Pointer ndArray)
public static void waitToRead(com.sun.jna.Pointer ndArray)
public static void waitToWrite(com.sun.jna.Pointer ndArray)
public static void waitAll()
public static void syncCopyToCPU(com.sun.jna.Pointer ndArray,
com.sun.jna.Pointer data,
int len)
public static void syncCopyFromCPU(com.sun.jna.Pointer ndArray,
java.nio.Buffer data,
int len)
public static ai.djl.util.PairList<com.sun.jna.Pointer,ai.djl.ndarray.types.SparseFormat> imperativeInvoke(com.sun.jna.Pointer function,
PointerArray inputs,
com.sun.jna.ptr.PointerByReference destRef,
ai.djl.util.PairList<java.lang.String,?> params)
public static ai.djl.ndarray.types.SparseFormat getStorageType(com.sun.jna.Pointer ndArray)
public static ai.djl.Device getDevice(com.sun.jna.Pointer ndArray)
public static ai.djl.ndarray.types.Shape getShape(com.sun.jna.Pointer ndArray)
public static ai.djl.ndarray.types.DataType getDataType(com.sun.jna.Pointer ndArray)
public static boolean autogradSetIsRecording(boolean isRecording)
public static boolean autogradSetTraining(boolean isTraining)
public static boolean autogradIsRecording()
public static boolean autogradIsTraining()
public static void autogradMarkVariables(int numVar,
com.sun.jna.Pointer varHandles,
java.nio.IntBuffer reqsArray,
com.sun.jna.Pointer gradHandles)
public static void autogradBackward(ai.djl.ndarray.NDList array,
int retainGraph)
public static void autogradBackwardExecute(int numOutput,
ai.djl.ndarray.NDList array,
ai.djl.ndarray.NDArray outgrad,
int numVariables,
com.sun.jna.Pointer varHandles,
int retainGraph,
int createGraph,
int isTrain,
com.sun.jna.Pointer gradHandles,
com.sun.jna.Pointer gradSparseFormat)
public static com.sun.jna.Pointer autogradGetSymbol(ai.djl.ndarray.NDArray array)
public static int isNumpyMode()
public static void setNumpyMode(JnaUtils.NumpyMode mode)
public static com.sun.jna.Pointer getGradient(com.sun.jna.Pointer handle)
public static com.sun.jna.Pointer parameterStoreCreate(java.lang.String type)
public static void parameterStoreClose(com.sun.jna.Pointer handle)
public static void parameterStoreInit(com.sun.jna.Pointer handle,
int num,
java.lang.String[] keys,
ai.djl.ndarray.NDList vals)
public static void parameterStorePush(com.sun.jna.Pointer handle,
int num,
java.lang.String[] keys,
ai.djl.ndarray.NDList vals,
int priority)
public static void parameterStorePull(com.sun.jna.Pointer handle,
int num,
int[] keys,
ai.djl.ndarray.NDList vals,
int priority)
public static void parameterStorePull(com.sun.jna.Pointer handle,
int num,
java.lang.String[] keys,
ai.djl.ndarray.NDList vals,
int priority)
public static void parameterStoreSetUpdater(com.sun.jna.Pointer handle,
MxnetLibrary.MXKVStoreUpdater updater,
MxnetLibrary.MXKVStoreStrUpdater stringUpdater,
com.sun.jna.Pointer updaterHandle)
public static void parameterStoreSetUpdater(com.sun.jna.Pointer handle,
MxnetLibrary.MXKVStoreUpdater updater,
com.sun.jna.Pointer updaterHandle)
public static com.sun.jna.Pointer getSymbolOutput(com.sun.jna.Pointer symbol,
int index)
public static java.lang.String[] listSymbolOutputs(com.sun.jna.Pointer symbol)
public static void freeSymbol(com.sun.jna.Pointer symbol)
public static java.lang.String[] listSymbolNames(com.sun.jna.Pointer symbol)
public static java.lang.String[] listSymbolArguments(com.sun.jna.Pointer symbol)
public static java.lang.String[] listSymbolAuxiliaryStates(com.sun.jna.Pointer symbol)
public static com.sun.jna.Pointer getSymbolInternals(com.sun.jna.Pointer symbol)
public static com.sun.jna.Pointer createSymbolFromFile(java.lang.String path)
public static java.util.List<java.util.List<ai.djl.ndarray.types.Shape>> inferShape(Symbol symbol, ai.djl.util.PairList<java.lang.String,ai.djl.ndarray.types.Shape> args)
public static CachedOp createCachedOp(MxSymbolBlock block, MxNDManager manager)
data_indices (e.g. [0, 2, 4]) labels the input locations; param_indices (e.g. [1, 3]) labels the parameter locations.
block - the MxSymbolBlock that is loaded in the backend
manager - the NDManager used to create NDArrays
public static void freeCachedOp(com.sun.jna.Pointer handle)
public static MxNDArray[] cachedOpInvoke(MxNDManager manager, com.sun.jna.Pointer cachedOpHandle, MxNDArray[] inputs)
public static boolean useThreadSafePredictor()
public static void checkCall(int ret)