flax.linen.switch

flax.linen.switch(index, branches, mdl, *operands, variables=True, rngs=True)

Lifted version of jax.lax.switch: each branch receives the Module mdl, so it can use submodules, variables, and PRNG sequences.
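
For comparison, the unlifted jax.lax.switch can be exercised in plain JAX. The sketch below uses illustrative branch functions (not part of Flax) to show that it applies branches[index], clamps out-of-range indices to the last branch, and accepts a traced index under jax.jit; every branch must return the same shape and dtype.

```python
import jax
import jax.numpy as jnp

# Illustrative branches: each maps a (2,) float32 array to a (2,) float32 array.
branches = [
  lambda v: v + 1.0,
  lambda v: v * 2.0,
  lambda v: -v,
]

x = jnp.array([1.0, 2.0])

out0 = jax.lax.switch(0, branches, x)         # applies branches[0]
out1 = jax.lax.switch(1, branches, x)         # applies branches[1]
out_clamped = jax.lax.switch(7, branches, x)  # 7 is clamped to 2 (last branch)

# The index can also be a traced value inside jax.jit.
switch_jit = jax.jit(lambda i, v: jax.lax.switch(i, branches, v))
out_jit = switch_jit(jnp.int32(1), x)
```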

The values returned from the branches must have the same Pytree structure, shapes, and dtypes. The variables created or updated inside the branches must also have the same structure. Note that this constraint is violated when variables or submodules are created in only one branch, because the branches then produce different parameter structures.

Example:

>>> import flax.linen as nn

>>> class SwitchExample(nn.Module):
...   @nn.compact
...   def __call__(self, x, index):
...     self.variable('state', 'a_count', lambda: 0)
...     self.variable('state', 'b_count', lambda: 0)
...     self.variable('state', 'c_count', lambda: 0)
...     def a_fn(mdl, x):
...       mdl.variable('state', 'a_count').value += 1
...       return nn.Dense(2, name='dense')(x)
...     def b_fn(mdl, x):
...       mdl.variable('state', 'b_count').value += 1
...       return -nn.Dense(2, name='dense')(x)
...     def c_fn(mdl, x):
...       mdl.variable('state', 'c_count').value += 1
...       return nn.Dense(2, name='dense')(x)
...     return nn.switch(index, [a_fn, b_fn, c_fn], self, x)

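If you want each branch to have a different parameter structure, run all branches during initialization before calling switch. This creates every branch's variables up front, so all branches see the same complete variable structure inside switch: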

>>> class MultiHeadSwitchExample(nn.Module):
...   def setup(self) -> None:
...     self.heads = [
...       nn.Sequential([nn.Dense(10), nn.Dense(7), nn.Dense(5)]),
...       nn.Sequential([nn.Dense(11), nn.Dense(5)]),
...       nn.Dense(5),
...     ]
...
...   @nn.compact
...   def __call__(self, x, index):
...     def head_fn(i):
...       return lambda mdl, x: mdl.heads[i](x)
...     branches = [head_fn(i) for i in range(len(self.heads))]
...
...     # run all branches on init
...     if self.is_mutable_collection('params'):
...       for branch in branches:
...         _ = branch(self, x)
...
...     return nn.switch(index, branches, self, x)

Parameters

  • index – Integer scalar indicating which branch function to apply. As in jax.lax.switch, out-of-range values are clamped to [0, len(branches) - 1].

  • branches – Sequence of functions to be applied based on index. The signature of each function is (module, *operands) -> T.

  • mdl – The Module passed as the first argument to each branch function.

  • *operands – The arguments passed to the branches.

  • variables – The variable collections passed to the branches (default: all).

  • rngs – The PRNG sequences passed to the branches (default: all).

Returns

The result of the evaluated branch.