Profiling

flax.linen.enable_named_call()

Enables named call wrapping for labelling profile traces.

When named call wrapping is enabled, all JAX ops executed in a Module are run under jax.named_scope. The Module's class name then appears around the operations belonging to that Module in the TensorBoard profiling UI, which makes profiles easier to read.

Note that jax.named_scope only works for compiled functions (e.g., functions transformed with jax.jit or jax.pmap).

flax.linen.disable_named_call()

Disables named call wrapping.

See enable_named_call.

flax.linen.override_named_call(enable=True)

Returns a context manager that enables/disables named call wrapping.

Parameters

enable – If True, enables named call wrapping for labelling profile traces (see enable_named_call); if False, disables it.