Decorators
- flax.linen.compact(fun)
Marks the given Module method as compact, allowing submodules to be defined inline.
Methods wrapped in @compact can define submodules directly within the method.
For instance:
```python
@compact
def __call__(self, x, features):
  x = nn.Dense(features)(x)
  ...
```
At most one method in each Module may be wrapped with @compact.
- Parameters
fun – The Module method to mark as compact.
- Returns
The given function fun marked as compact.
- flax.linen.nowrap(fun)
Marks the given module method as a helper method that needn’t be wrapped.
Methods wrapped in @nowrap are private helper methods that needn’t be wrapped with the state handler or a separate named_call transform.
- This is needed in several concrete instances:
  - If you’re subclassing a method like Module.param and don’t want the overridden core function decorated with the state management wrapper.
  - If you want a method to be callable from an unbound Module (e.g., a function that constructs arguments and doesn’t depend on params/RNGs).
For instance:
```python
@nowrap
def _make_dense(self, num_features):
  return nn.Dense(num_features)

@compact
def __call__(self, x):
  # now safe to use constructor helper even if using named_call
  dense = self._make_dense(self.num_features)
  return dense(x)
```
- Parameters
fun – The Module method to mark as nowrap.
- Returns
The given function fun marked as nowrap.
Summary
| Decorator | Description |
|---|---|
| compact(fun) | Marks the given Module method as compact, allowing inlined submodules. |
| nowrap(fun) | Marks the given Module method as a helper method that needn't be wrapped. |