public class TfSymbolBlock
extends java.lang.Object
implements ai.djl.nn.SymbolBlock
| Constructor and Description |
|---|
| `TfSymbolBlock(org.tensorflow.SavedModelBundle bundle)` |
| Modifier and Type | Method and Description |
|---|---|
| `void` | `cast(ai.djl.ndarray.types.DataType dataType)` |
| `void` | `clear()` |
| `ai.djl.util.PairList<java.lang.String,ai.djl.ndarray.types.Shape>` | `describeInput()` |
| `ai.djl.ndarray.NDList` | `forward(ai.djl.training.ParameterStore parameterStore, ai.djl.ndarray.NDList inputs, boolean training, ai.djl.util.PairList<java.lang.String,java.lang.Object> params)` |
| `ai.djl.nn.BlockList` | `getChildren()` |
| `ai.djl.nn.ParameterList` | `getDirectParameters()` |
| `ai.djl.ndarray.types.Shape[]` | `getOutputShapes(ai.djl.ndarray.NDManager manager, ai.djl.ndarray.types.Shape[] inputShapes)` |
| `ai.djl.nn.ParameterList` | `getParameters()` |
| `ai.djl.ndarray.types.Shape` | `getParameterShape(java.lang.String name, ai.djl.ndarray.types.Shape[] inputShapes)` |
| `ai.djl.ndarray.types.Shape[]` | `initialize(ai.djl.ndarray.NDManager manager, ai.djl.ndarray.types.DataType dataType, ai.djl.ndarray.types.Shape... inputShapes)` |
| `boolean` | `isInitialized()` |
| `void` | `loadParameters(ai.djl.ndarray.NDManager manager, java.io.DataInputStream is)` |
| `void` | `removeLastBlock()` |
| `void` | `saveParameters(java.io.DataOutputStream os)` |
| `void` | `setInitializer(ai.djl.training.initializer.Initializer initializer)` |
| `void` | `setInitializer(ai.djl.training.initializer.Initializer initializer, java.lang.String paramName)` |
`public void removeLastBlock()`

Specified by: `removeLastBlock` in interface `ai.djl.nn.SymbolBlock`

`public ai.djl.ndarray.NDList forward(ai.djl.training.ParameterStore parameterStore, ai.djl.ndarray.NDList inputs, boolean training, ai.djl.util.PairList<java.lang.String,java.lang.Object> params)`

Specified by: `forward` in interface `ai.djl.nn.Block`

`public void setInitializer(ai.djl.training.initializer.Initializer initializer)`

Specified by: `setInitializer` in interface `ai.djl.nn.Block`

`public void setInitializer(ai.djl.training.initializer.Initializer initializer, java.lang.String paramName)`

Specified by: `setInitializer` in interface `ai.djl.nn.Block`

`public ai.djl.ndarray.types.Shape[] initialize(ai.djl.ndarray.NDManager manager, ai.djl.ndarray.types.DataType dataType, ai.djl.ndarray.types.Shape... inputShapes)`

Specified by: `initialize` in interface `ai.djl.nn.Block`

`public boolean isInitialized()`

Specified by: `isInitialized` in interface `ai.djl.nn.Block`

`public void cast(ai.djl.ndarray.types.DataType dataType)`

Specified by: `cast` in interface `ai.djl.nn.Block`

`public void clear()`

Specified by: `clear` in interface `ai.djl.nn.Block`

`public ai.djl.util.PairList<java.lang.String,ai.djl.ndarray.types.Shape> describeInput()`

Specified by: `describeInput` in interface `ai.djl.nn.Block`

`public ai.djl.nn.BlockList getChildren()`

Specified by: `getChildren` in interface `ai.djl.nn.Block`

`public ai.djl.nn.ParameterList getDirectParameters()`

Specified by: `getDirectParameters` in interface `ai.djl.nn.Block`

`public ai.djl.nn.ParameterList getParameters()`

Specified by: `getParameters` in interface `ai.djl.nn.Block`

`public ai.djl.ndarray.types.Shape getParameterShape(java.lang.String name, ai.djl.ndarray.types.Shape[] inputShapes)`

Specified by: `getParameterShape` in interface `ai.djl.nn.Block`

`public ai.djl.ndarray.types.Shape[] getOutputShapes(ai.djl.ndarray.NDManager manager, ai.djl.ndarray.types.Shape[] inputShapes)`

Specified by: `getOutputShapes` in interface `ai.djl.nn.Block`

`public void saveParameters(java.io.DataOutputStream os)`

Specified by: `saveParameters` in interface `ai.djl.nn.Block`

`public void loadParameters(ai.djl.ndarray.NDManager manager, java.io.DataInputStream is)`

Specified by: `loadParameters` in interface `ai.djl.nn.Block`