Inspection


flax.linen.tabulate(module, rngs, depth=None, show_repeated=False, mutable=DenyList(deny='intermediates'), console_kwargs=None, table_kwargs=mappingproxy({}), column_kwargs=mappingproxy({}), compute_flops=False, compute_vjp_flops=False, **kwargs)

Returns a function that creates a summary of the Module represented as a table.

This function accepts most of the same arguments as Module.init and calls it internally. Instead of returning variables, it returns a function of the form (*args, **kwargs) -> str, where *args and **kwargs are passed to method (e.g. __call__) during the forward pass.

tabulate uses jax.eval_shape under the hood to run the forward computation without consuming any FLOPs or allocating memory.

Additional arguments for the console can be passed through console_kwargs, for example {'width': 120}. For the full list of accepted arguments, see: https://rich.readthedocs.io/en/stable/reference/console.html#rich.console.Console

Example:

>>> import flax.linen as nn
>>> import jax, jax.numpy as jnp

>>> class Foo(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     h = nn.Dense(4)(x)
...     return nn.Dense(2)(h)

>>> x = jnp.ones((16, 9))
>>> tabulate_fn = nn.tabulate(
...     Foo(), jax.random.key(0), compute_flops=True, compute_vjp_flops=True)

>>> # print(tabulate_fn(x))

This gives the following output:

                                       Foo Summary
┏━━━━━━━━━┳━━━━━━━━┳━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┳━━━━━━━┳━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┓
┃ path    ┃ module ┃ inputs        ┃ outputs       ┃ flops ┃ vjp_flops ┃ params          ┃
┡━━━━━━━━━╇━━━━━━━━╇━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━╇━━━━━━━╇━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━┩
│         │ Foo    │ float32[16,9] │ float32[16,2] │ 1504  │ 4460      │                 │
├─────────┼────────┼───────────────┼───────────────┼───────┼───────────┼─────────────────┤
│ Dense_0 │ Dense  │ float32[16,9] │ float32[16,4] │ 1216  │ 3620      │ bias:           │
│         │        │               │               │       │           │ float32[4]      │
│         │        │               │               │       │           │ kernel:         │
│         │        │               │               │       │           │ float32[9,4]    │
│         │        │               │               │       │           │                 │
│         │        │               │               │       │           │ 40 (160 B)      │
├─────────┼────────┼───────────────┼───────────────┼───────┼───────────┼─────────────────┤
│ Dense_1 │ Dense  │ float32[16,4] │ float32[16,2] │ 288   │ 840       │ bias:           │
│         │        │               │               │       │           │ float32[2]      │
│         │        │               │               │       │           │ kernel:         │
│         │        │               │               │       │           │ float32[4,2]    │
│         │        │               │               │       │           │                 │
│         │        │               │               │       │           │ 10 (40 B)       │
├─────────┼────────┼───────────────┼───────────────┼───────┼───────────┼─────────────────┤
│         │        │               │               │       │     Total │ 50 (200 B)      │
└─────────┴────────┴───────────────┴───────────────┴───────┴───────────┴─────────────────┘

                               Total Parameters: 50 (200 B)
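The params and flops figures in the table above can be checked by hand. The sketch below is plain Python, not how tabulate computes them: the helper names are hypothetical, and the FLOP convention (2 FLOPs per multiply-accumulate in the matmul plus 1 FLOP per bias addition) is an assumption that happens to reproduce the flops column for this example. Parameter bytes assume 4 bytes per float32 value.

```python
# Hypothetical helpers for sanity-checking the Foo summary table.
BYTES_PER_FLOAT32 = 4


def dense_param_count(in_features: int, out_features: int) -> int:
    """A Dense layer holds a kernel of shape [in, out] and a bias of shape [out]."""
    return in_features * out_features + out_features


def dense_forward_flops(batch: int, in_features: int, out_features: int) -> int:
    """Assumed convention: 2*b*i*o for the matmul plus b*o for the bias add."""
    return 2 * batch * in_features * out_features + batch * out_features


# (in_features, out_features) of each Dense layer in Foo; inputs are [16, 9].
layers = {"Dense_0": (9, 4), "Dense_1": (4, 2)}
batch = 16

total_params = 0
total_flops = 0
for name, (i, o) in layers.items():
    p = dense_param_count(i, o)
    f = dense_forward_flops(batch, i, o)
    total_params += p
    total_flops += f
    print(f"{name}: {p} ({p * BYTES_PER_FLOAT32} B), flops={f}")

print(f"Total: {total_params} ({total_params * BYTES_PER_FLOAT32} B), flops={total_flops}")
# Dense_0: 40 (160 B), flops=1216
# Dense_1: 10 (40 B), flops=288
# Total: 50 (200 B), flops=1504
```

The vjp_flops column is not reproduced here, since its exact accounting is not described in this section.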

Note: the order of rows in the table does not reflect execution order; instead, it follows the order of keys in variables, which are sorted alphabetically.

Note: vjp_flops is 0 if the module is not differentiable.

Parameters
  • module – The module to tabulate.

  • rngs – The rngs for the variable collections as passed to Module.init.

  • depth – controls how many levels of submodules the summary can descend into. By default it is None, which means no limit. If a submodule is not shown because of the depth limit, its parameter count and bytes are added to the row of its first shown ancestor, so that the sum of all rows always equals the total number of parameters of the Module.

  • show_repeated – If True, repeated calls to the same module are shown in the table; otherwise only the first call is shown. Default is False.

  • mutable – Can be bool, str, or list. Specifies which collections should be treated as mutable. bool: all or no collections are mutable; str: the name of a single mutable collection; list: a list of names of mutable collections. By default, all collections except 'intermediates' are mutable.

  • console_kwargs – An optional dictionary with additional keyword arguments that are passed to rich.console.Console when rendering the table. Default arguments are {'force_terminal': True, 'force_jupyter': False}.

  • table_kwargs – An optional dictionary with additional keyword arguments that are passed to rich.table.Table constructor.

  • column_kwargs – An optional dictionary with additional keyword arguments that are passed to rich.table.Table.add_column when adding columns to the table.

  • compute_flops – whether to include a flops column in the table listing the estimated FLOPs cost of each module's forward pass. This does not incur actual on-device computation, compilation, or memory allocation, but it still introduces overhead for large modules (e.g. an extra 20 seconds for a Stable Diffusion UNet, whereas tabulation would otherwise finish in 5 seconds).

  • compute_vjp_flops – whether to include a vjp_flops column in the table listing the estimated FLOPs cost of each module's backward pass. This introduces a compute overhead of roughly 2-3x that of compute_flops.

  • **kwargs – Additional arguments passed to Module.init.

Returns

A function that accepts the same *args and **kwargs as the forward pass (method) and returns a string with a tabular representation of the Module.