Class MPSCNNBatchNormalizationNode

  • All Implemented Interfaces:
    MPSNNTrainableNode, NSObject

    public class MPSCNNBatchNormalizationNode
    extends MPSNNFilterNode
    implements MPSNNTrainableNode
    A node representing batch normalization for inference or training.

    Batch normalization operates differently for inference and training. For inference, normalization uses a static statistical representation of the data saved during training. For training, this representation is continually evolving.

    In the low-level MPS batch normalization interface, training splits batch normalization into two steps: calculating the statistical representation of the input data, then normalizing once the statistics are known for the entire batch. These steps are MPSCNNBatchNormalizationStatistics and MPSCNNBatchNormalization, respectively.

    When this node appears in a graph and is not required to produce an MPSCNNBatchNormalizationState (that is, MPSCNNBatchNormalizationNode.resultState is not used within the graph), it operates in inference mode and new batch-only statistics are not calculated. When the state node is consumed, the node is assumed to be in training mode: new statistics are calculated, written to the MPSCNNBatchNormalizationState, and passed along to MPSCNNBatchNormalizationGradient and MPSCNNBatchNormalizationStatisticsGradient as necessary.

    This should allow you to construct an identical sequence of nodes for inference and training and expect the right thing to happen.
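    The two-step training flow and the fixed-statistics inference flow described above can be illustrated with plain arithmetic. The following is a minimal sketch in ordinary Java, not the MPS or Multi-OS Engine API: the class BatchNormSketch, its methods, the EPSILON value, and the "saved" statistics used for inference are all hypothetical names and numbers chosen for illustration, and learned scale and shift terms are omitted.

```java
import java.util.Arrays;

/** Plain-Java illustration only; none of these names belong to the MPS or MOE API. */
final class BatchNormSketch {
    // Small constant that keeps the division well defined (assumed value).
    static final double EPSILON = 1e-5;

    /** Step 1: compute {mean, variance} over one batch of values for a single channel. */
    static double[] statistics(double[] batch) {
        double mean = Arrays.stream(batch).average().orElse(0.0);
        double variance = Arrays.stream(batch)
                .map(x -> (x - mean) * (x - mean))
                .average()
                .orElse(0.0);
        return new double[] {mean, variance};
    }

    /** Step 2: normalize each value using statistics that are already known. */
    static double[] normalize(double[] values, double mean, double variance) {
        double scale = 1.0 / Math.sqrt(variance + EPSILON);
        return Arrays.stream(values).map(x -> (x - mean) * scale).toArray();
    }

    public static void main(String[] args) {
        double[] batch = {1.0, 2.0, 3.0, 4.0};

        // Training style: statistics are computed from the current batch first.
        double[] stats = statistics(batch);
        double[] trained = normalize(batch, stats[0], stats[1]);

        // Inference style: statistics are fixed values saved earlier (hypothetical numbers).
        double[] inferred = normalize(batch, 2.0, 1.0);

        System.out.println(Arrays.toString(trained));
        System.out.println(Arrays.toString(inferred));
    }
}
```

    In the sketch, the training path calls statistics before normalize, mirroring the split between MPSCNNBatchNormalizationStatistics and MPSCNNBatchNormalization, while the inference path skips the first step and reuses stored values.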
    • Constructor Detail

      • MPSCNNBatchNormalizationNode

        protected MPSCNNBatchNormalizationNode​(org.moe.natj.general.Pointer peer)
    • Method Detail

      • accessInstanceVariablesDirectly

        public static boolean accessInstanceVariablesDirectly()
      • allocWithZone

        public static java.lang.Object allocWithZone​(org.moe.natj.general.ptr.VoidPtr zone)
      • automaticallyNotifiesObserversForKey

        public static boolean automaticallyNotifiesObserversForKey​(java.lang.String key)
      • cancelPreviousPerformRequestsWithTarget

        public static void cancelPreviousPerformRequestsWithTarget​(java.lang.Object aTarget)
      • cancelPreviousPerformRequestsWithTargetSelectorObject

        public static void cancelPreviousPerformRequestsWithTargetSelectorObject​(java.lang.Object aTarget,
                                                                                 org.moe.natj.objc.SEL aSelector,
                                                                                 java.lang.Object anArgument)
      • classFallbacksForKeyedArchiver

        public static NSArray<java.lang.String> classFallbacksForKeyedArchiver()
      • classForKeyedUnarchiver

        public static org.moe.natj.objc.Class classForKeyedUnarchiver()
      • debugDescription_static

        public static java.lang.String debugDescription_static()
      • description_static

        public static java.lang.String description_static()
      • flags

        public long flags()
        Options controlling how batch normalization is calculated. Default: MPSCNNBatchNormalizationFlagsDefault
      • hash_static

        public static long hash_static()
      • instanceMethodSignatureForSelector

        public static NSMethodSignature instanceMethodSignatureForSelector​(org.moe.natj.objc.SEL aSelector)
      • instancesRespondToSelector

        public static boolean instancesRespondToSelector​(org.moe.natj.objc.SEL aSelector)
      • isSubclassOfClass

        public static boolean isSubclassOfClass​(org.moe.natj.objc.Class aClass)
      • keyPathsForValuesAffectingValueForKey

        public static NSSet<java.lang.String> keyPathsForValuesAffectingValueForKey​(java.lang.String key)
      • new_objc

        public static java.lang.Object new_objc()
      • resolveClassMethod

        public static boolean resolveClassMethod​(org.moe.natj.objc.SEL sel)
      • resolveInstanceMethod

        public static boolean resolveInstanceMethod​(org.moe.natj.objc.SEL sel)
      • setFlags

        public void setFlags​(long value)
        Options controlling how batch normalization is calculated. Default: MPSCNNBatchNormalizationFlagsDefault
      • setTrainingStyle

        public void setTrainingStyle​(long value)
        Description copied from interface: MPSNNTrainableNode
        Configure whether and when neural network training parameters are updated. Default: MPSNNTrainingStyleUpdateDeviceGPU
        Specified by:
        setTrainingStyle in interface MPSNNTrainableNode
      • setVersion_static

        public static void setVersion_static​(long aVersion)
      • superclass_static

        public static org.moe.natj.objc.Class superclass_static()
      • trainingStyle

        public long trainingStyle()
        Description copied from interface: MPSNNTrainableNode
        Configure whether and when neural network training parameters are updated. Default: MPSNNTrainingStyleUpdateDeviceGPU
        Specified by:
        trainingStyle in interface MPSNNTrainableNode
      • version_static

        public static long version_static()