Profiling
- flax.linen.enable_named_call()
Enables named call wrapping for labelling profile traces.
When named call wrapping is enabled, all JAX ops executed in a Module are run under jax.named_scope. The Module class name then shows up around the operations belonging to that Module in the TensorBoard profiling UI, which simplifies the profiling process.
Note that jax.named_scope only works for compiled functions (e.g. using jax.jit or jax.pmap).
- flax.linen.override_named_call(enable=True)
Returns a context manager that enables/disables named call wrapping.
- Parameters
enable – If true, enables named call wrapping for labelling profile traces (see enable_named_call).
Summary
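flax.linen.enable_named_call() | Enables named call wrapping for labelling profile traces. |
flax.linen.disable_named_call() | Disables named call wrapping. |
flax.linen.override_named_call(enable=True) | Returns a context manager that enables/disables named call wrapping. |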