public class MxModel
extends ai.djl.BaseModel
MxModel is the MXNet implementation of Model.
MxModel contains all the methods in Model to load and process a model. In addition, it provides MXNet-specific functionality, such as getSymbol to obtain the symbolic graph and getParameters to obtain the parameter NDArrays.
| Modifier and Type | Method and Description |
|---|---|
| `void` | `cast(ai.djl.ndarray.types.DataType dataType)` |
| `void` | `close()` |
| `java.lang.String[]` | `getArtifactNames()` |
| `void` | `load(java.nio.file.Path modelPath, java.lang.String prefix, java.util.Map<java.lang.String,java.lang.Object> options)`<br>Loads the MXNet model from a specified location. |
| `<I,O> ai.djl.inference.Predictor<I,O>` | `newPredictor(ai.djl.translate.Translator<I,O> translator)` |
| `ai.djl.training.Trainer` | `newTrainer(ai.djl.training.TrainingConfig trainingConfig)` |
| `java.lang.String` | `toString()` |
Methods inherited from class ai.djl.BaseModel: describeInput, describeOutput, finalize, getArtifact, getArtifact, getArtifactAsStream, getBlock, getDataType, getName, getNDManager, getProperty, paramPathResolver, readParameters, save, setBlock, setDataType, setModelDir, setProperty

public void load(java.nio.file.Path modelPath,
java.lang.String prefix,
java.util.Map<java.lang.String,java.lang.Object> options)
throws java.io.IOException,
ai.djl.MalformedModelException
The MXNet engine looks for the {MODEL_NAME}-symbol.json and {MODEL_NAME}-{EPOCH}.params files in the specified directory. By default, it loads the parameter file with the latest epoch. However, users can explicitly specify the epoch to load:
Map<String, Object> options = new HashMap<>();
options.put("epoch", "3");
model.load(modelPath, "squeezenet", options);
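The file-selection rule described above (a symbol file plus one parameter file, with an explicit "epoch" option taking precedence over the latest epoch found on disk) can be sketched in plain Java. This is a hypothetical helper, not DJL's implementation: the class name MxArtifactSketch and its methods are illustrative, it works on a list of file names rather than a real directory, and it assumes the common MXNet convention of a zero-padded four-digit epoch such as squeezenet-0003.params. The real loader may differ in details such as error handling.

```java
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.OptionalInt;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

/**
 * Hypothetical sketch (not part of DJL) of how an MXNet-style loader could pick
 * {prefix}-symbol.json and {prefix}-{EPOCH}.params from a directory listing.
 */
final class MxArtifactSketch {

    private MxArtifactSketch() {}

    /** Returns the highest epoch among file names of the form {prefix}-{digits}.params. */
    static OptionalInt latestEpoch(List<String> fileNames, String prefix) {
        Pattern pattern = Pattern.compile(Pattern.quote(prefix) + "-(\\d+)\\.params");
        return fileNames.stream()
                .map(pattern::matcher)
                .filter(Matcher::matches) // full-name match; also populates group(1)
                .mapToInt(m -> Integer.parseInt(m.group(1)))
                .max();
    }

    /**
     * An explicit "epoch" option wins; its string form is parsed, so both "3"
     * and 3 are accepted. Otherwise the latest epoch in the listing is used.
     */
    static int resolveEpoch(List<String> fileNames, String prefix, Map<String, Object> options) {
        Object explicit = options == null ? null : options.get("epoch");
        if (explicit != null) {
            return Integer.parseInt(explicit.toString().trim());
        }
        return latestEpoch(fileNames, prefix)
                .orElseThrow(() -> new IllegalArgumentException(
                        "No " + prefix + "-{EPOCH}.params file found"));
    }

    /** Parameter file name, assuming a zero-padded four-digit epoch. */
    static String paramsFileName(String prefix, int epoch) {
        return String.format(Locale.ROOT, "%s-%04d.params", prefix, epoch);
    }

    /** Symbol (graph) file name. */
    static String symbolFileName(String prefix) {
        return prefix + "-symbol.json";
    }
}
```

Parsing the option through toString() is a deliberate choice in this sketch: because options is a Map<String, Object>, callers may reasonably pass either a String or an Integer for the epoch.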
Parameters:
modelPath - the directory of the model
prefix - the model file name or path prefix
options - load model options; see the documentation for the specific engine
Throws:
java.io.IOException - if an I/O error occurs while loading the model files
ai.djl.MalformedModelException

public ai.djl.training.Trainer newTrainer(ai.djl.training.TrainingConfig trainingConfig)
public <I,O> ai.djl.inference.Predictor<I,O> newPredictor(ai.djl.translate.Translator<I,O> translator)
public void cast(ai.djl.ndarray.types.DataType dataType)
public java.lang.String[] getArtifactNames()
public void close()
public java.lang.String toString()
Overrides: toString in class java.lang.Object