flax.linen.cond

flax.linen.cond(pred, true_fun, false_fun, mdl, *operands, variables=True, rngs=True)

Lifted version of jax.lax.cond.

The values returned by true_fun and false_fun must have the same Pytree structure, shapes, and dtypes. The variables created or updated inside the branches must also have the same structure. This constraint is violated when variables or submodules are created in only one branch, because initializing variables in just one branch makes the parameter structure differ between the two branches.

Example:

>>> import flax.linen as nn

>>> class CondExample(nn.Module):
...   @nn.compact
...   def __call__(self, x, pred):
...     self.variable('state', 'true_count', lambda: 0)
...     self.variable('state', 'false_count', lambda: 0)
...     def true_fn(mdl, x):
...       mdl.variable('state', 'true_count').value += 1
...       return nn.Dense(2, name='dense')(x)
...     def false_fn(mdl, x):
...       mdl.variable('state', 'false_count').value += 1
...       return -nn.Dense(2, name='dense')(x)
...     return nn.cond(pred, true_fn, false_fn, self, x)

Parameters
  • pred – Determines whether true_fun or false_fun is evaluated.

  • true_fun – The function evaluated when pred is True. The signature is (module, *operands) -> T.

  • false_fun – The function evaluated when pred is False. The signature is (module, *operands) -> T.

  • mdl – The Module instance passed as the first argument to true_fun and false_fun.

  • *operands – The arguments passed to true_fun and false_fun.

  • variables – The variable collections passed to the conditional branches (default: all).

  • rngs – The PRNG sequences passed to the conditional branches (default: all).

Returns

The result of the evaluated branch (true_fun or false_fun).