flax.linen.tabulate

flax.linen.tabulate(module, rngs, method=None, mutable=True, depth=None, exclude_methods=())

Returns a function that creates a summary of the Module represented as a table.

This function accepts most of the same arguments as Module.init, except that it returns a function of the form (*args, **kwargs) -> str, where *args and **kwargs are passed to method (e.g. __call__) during the forward pass.

tabulate uses jax.eval_shape under the hood to run the forward computation without consuming any FLOPs or allocating memory.

Example:

import jax
import jax.numpy as jnp
import flax.linen as nn

class Foo(nn.Module):
    @nn.compact
    def __call__(self, x):
        h = nn.Dense(4)(x)
        return nn.Dense(2)(h)

x = jnp.ones((16, 9))
tabulate_fn = nn.tabulate(Foo(), jax.random.PRNGKey(0))

print(tabulate_fn(x))

This gives the following output:

                    Foo Summary
┏━━━━━━━━━┳━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━┓
┃ path    ┃ outputs       ┃ params               ┃
┡━━━━━━━━━╇━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━┩
│ Inputs  │ float32[16,9] │                      │
├─────────┼───────────────┼──────────────────────┤
│ Dense_0 │ float32[16,4] │ bias: float32[4]     │
│         │               │ kernel: float32[9,4] │
│         │               │                      │
│         │               │ 40 (160 B)           │
├─────────┼───────────────┼──────────────────────┤
│ Dense_1 │ float32[16,2] │ bias: float32[2]     │
│         │               │ kernel: float32[4,2] │
│         │               │                      │
│         │               │ 10 (40 B)            │
├─────────┼───────────────┼──────────────────────┤
│ Foo     │ float32[16,2] │                      │
├─────────┼───────────────┼──────────────────────┤
│         │         Total │ 50 (200 B)           │
└─────────┴───────────────┴──────────────────────┘

            Total Parameters: 50 (200 B)
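The counts in this table can be checked by hand. A Dense layer mapping in_features to out_features has an (in_features, out_features) kernel plus an out_features bias, and each float32 value occupies 4 bytes. The plain-Python sketch below reproduces the per-row and total figures; the helper name dense_param_count is illustrative only and not part of Flax.

```python
BYTES_PER_FLOAT32 = 4


def dense_param_count(in_features: int, out_features: int) -> int:
    """Hypothetical helper: kernel (in x out) plus bias (out)."""
    return in_features * out_features + out_features


dense_0 = dense_param_count(9, 4)  # 9*4 + 4 = 40
dense_1 = dense_param_count(4, 2)  # 4*2 + 2 = 10
total = dense_0 + dense_1          # 50

for name, count in [("Dense_0", dense_0), ("Dense_1", dense_1), ("Total", total)]:
    # Prints e.g. "Dense_0: 40 (160 B)", matching the params column.
    print(f"{name}: {count} ({count * BYTES_PER_FLOAT32} B)")
```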

Note: the order of rows in the table does not represent execution order; instead, it follows the order of keys in variables, which are sorted alphabetically.

Parameters
  • module – The module to tabulate.

  • rngs – The rngs for the variable collections, as passed to Module.init.

  • method – An optional method. If provided, tabulate applies this method; otherwise, it applies the __call__ method.

  • mutable – Can be bool, str, or list, and specifies which collections should be treated as mutable. bool: all or no collections are mutable. str: the name of a single mutable collection. list: a list of names of mutable collections. By default, all collections except ‘intermediates’ are mutable.

  • depth – Controls how many submodules deep the summary can go. By default it is None, which means there is no limit. If a submodule is not shown because of the depth limit, its parameter count and bytes are added to the row of its first shown ancestor, so that the sum of all rows always equals the total number of parameters of the Module.

  • exclude_methods – A sequence of strings that specifies which methods should be ignored. If a module calls a helper method from its main method, use this argument to exclude the helper method from the summary to avoid ambiguity.

Returns

A function that accepts the same *args and **kwargs as the forward pass (method) and returns a string with a tabular representation of the Modules.