public class MxSymbolBlock
extends ai.djl.nn.AbstractBlock
implements ai.djl.nn.SymbolBlock
MxSymbolBlock is the MXNet implementation of SymbolBlock.
You can create an MxSymbolBlock using Model.load(java.nio.file.Path, String).
| Constructor and Description |
|---|
| MxSymbolBlock(ai.djl.ndarray.NDManager manager, Symbol symbol) Constructs an MxSymbolBlock for a Symbol. |
| Modifier and Type | Method and Description |
|---|---|
| ai.djl.util.PairList<java.lang.String,ai.djl.ndarray.types.Shape> | describeInput() |
| ai.djl.ndarray.NDList | forward(ai.djl.training.ParameterStore parameterStore, ai.djl.ndarray.NDList inputs, boolean training, ai.djl.util.PairList<java.lang.String,java.lang.Object> params) |
| java.util.List<ai.djl.nn.Parameter> | getAllParameters() Returns the list of inputs and parameter NDArrays. |
| java.util.List<java.lang.String> | getLayerNames() Returns the names of the layers. |
| ai.djl.ndarray.types.Shape[] | getOutputShapes(ai.djl.ndarray.NDManager manager, ai.djl.ndarray.types.Shape[] inputShapes) |
| ai.djl.ndarray.types.Shape | getParameterShape(java.lang.String name, ai.djl.ndarray.types.Shape[] inputShapes) |
| Symbol | getSymbol() Returns the symbolic graph from the model. |
| void | loadParameters(ai.djl.ndarray.NDManager manager, java.io.DataInputStream is) |
| void | removeLastBlock() |
| void | saveParameters(java.io.DataOutputStream os) |
| void | setInputNames(java.util.List<java.lang.String> inputNames) Sets the names of the input data. |
Inherited methods: addChildBlock, addParameter, addParameter, addParameter, beforeInitialize, cast, clear, getChildren, getDirectParameters, getParameters, initialize, initializeChildBlocks, isInitialized, loadMetadata, readInputShapes, saveInputShapes, saveMetadata, setInitializer, setInitializer, toString

public MxSymbolBlock(ai.djl.ndarray.NDManager manager,
                     Symbol symbol)
Constructs an MxSymbolBlock for a Symbol.
You can create an MxSymbolBlock using Model.load(java.nio.file.Path, String).
Parameters:
manager - the manager to use for the block
symbol - the symbol containing the block's symbolic graph

public void setInputNames(java.util.List<java.lang.String> inputNames)
Parameters:
inputNames - the names of the input data

public java.util.List<ai.djl.nn.Parameter> getAllParameters()

public java.util.List<java.lang.String> getLayerNames()

public Symbol getSymbol()
Returns:
Symbol object

public ai.djl.util.PairList<java.lang.String,ai.djl.ndarray.types.Shape> describeInput()
Specified by:
describeInput in interface ai.djl.nn.Block
Overrides:
describeInput in class ai.djl.nn.AbstractBlock

public ai.djl.ndarray.NDList forward(ai.djl.training.ParameterStore parameterStore,
                                     ai.djl.ndarray.NDList inputs,
                                     boolean training,
                                     ai.djl.util.PairList<java.lang.String,java.lang.Object> params)
Specified by:
forward in interface ai.djl.nn.Block

public ai.djl.ndarray.types.Shape[] getOutputShapes(ai.djl.ndarray.NDManager manager,
                                                    ai.djl.ndarray.types.Shape[] inputShapes)
Specified by:
getOutputShapes in interface ai.djl.nn.Block

public void removeLastBlock()
Specified by:
removeLastBlock in interface ai.djl.nn.SymbolBlock

public ai.djl.ndarray.types.Shape getParameterShape(java.lang.String name,
                                                    ai.djl.ndarray.types.Shape[] inputShapes)
Specified by:
getParameterShape in interface ai.djl.nn.Block
Overrides:
getParameterShape in class ai.djl.nn.AbstractBlock

public void saveParameters(java.io.DataOutputStream os)
                    throws java.io.IOException
Specified by:
saveParameters in interface ai.djl.nn.Block
Overrides:
saveParameters in class ai.djl.nn.AbstractBlock
Throws:
java.io.IOException

public void loadParameters(ai.djl.ndarray.NDManager manager,
                           java.io.DataInputStream is)
                    throws java.io.IOException,
                           ai.djl.MalformedModelException
Specified by:
loadParameters in interface ai.djl.nn.Block
Overrides:
loadParameters in class ai.djl.nn.AbstractBlock
Throws:
java.io.IOException
ai.djl.MalformedModelException
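
The saveParameters and loadParameters signatures above pair a java.io.DataOutputStream with a java.io.DataInputStream. The following is a minimal sketch of that calling pattern using only the JDK. Everything in it is an assumption made for illustration: the class ParameterStreamSketch, its MAGIC marker, and its byte layout are hypothetical and are not the format MxSymbolBlock actually writes. The sketch also signals a bad header with a plain IOException, whereas the real loadParameters declares ai.djl.MalformedModelException.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.LinkedHashMap;
import java.util.Map;

// Hypothetical stand-in for a block's parameter storage; NOT the DJL/MXNet format.
class ParameterStreamSketch {

    // Arbitrary marker chosen for this sketch so that load() can reject foreign data.
    static final int MAGIC = 0x50415231;

    // Writes: marker, entry count, then for each entry a UTF name, a length, and the floats.
    static void save(Map<String, float[]> params, DataOutputStream os) throws IOException {
        os.writeInt(MAGIC);
        os.writeInt(params.size());
        for (Map.Entry<String, float[]> entry : params.entrySet()) {
            os.writeUTF(entry.getKey());
            os.writeInt(entry.getValue().length);
            for (float value : entry.getValue()) {
                os.writeFloat(value);
            }
        }
        os.flush();
    }

    // Reads the layout produced by save(); throws IOException on an unexpected marker.
    static Map<String, float[]> load(DataInputStream is) throws IOException {
        if (is.readInt() != MAGIC) {
            throw new IOException("Unexpected header; data was not written by save()");
        }
        int count = is.readInt();
        Map<String, float[]> params = new LinkedHashMap<>();
        for (int i = 0; i < count; i++) {
            String name = is.readUTF();
            float[] values = new float[is.readInt()];
            for (int j = 0; j < values.length; j++) {
                values[j] = is.readFloat();
            }
            params.put(name, values);
        }
        return params;
    }

    // Saves to an in-memory buffer and loads the result back.
    static Map<String, float[]> roundTrip(Map<String, float[]> params) {
        try {
            ByteArrayOutputStream buffer = new ByteArrayOutputStream();
            try (DataOutputStream os = new DataOutputStream(buffer)) {
                save(params, os);
            }
            try (DataInputStream is =
                    new DataInputStream(new ByteArrayInputStream(buffer.toByteArray()))) {
                return load(is);
            }
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }
}
```

With a real block, the same pattern would presumably wrap a file stream instead of an in-memory buffer, and the caller would handle both exceptions declared by loadParameters.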