flax.linen.cond(pred, true_fun, false_fun, mdl, *operands, variables=True, rngs=True)

Lifted version of jax.lax.cond.

The values returned by true_fun and false_fun must have the same Pytree structure, shapes, and dtypes. The variables created or updated inside the branches must also have the same structure. This constraint is violated when a variable or submodule is created in only one branch, because initializing it in just one branch makes the variable structures of the two branches differ.
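Because flax.linen.cond lifts jax.lax.cond, the same-structure rule comes from JAX itself. The sketch below uses plain jax.lax.cond (no Flax) to show two branches whose outputs agree in pytree structure, shape, and dtype, plus a hypothetical third branch that breaks the rule. The function names are illustrative only, and the exact error message may differ between JAX versions.

```python
import jax
import jax.numpy as jnp


def true_branch(x):
  # Returns a 2-tuple: (scalar sum, array with the input's shape).
  return jnp.sum(x), x * 2.0


def false_branch(x):
  # Same pytree structure (a 2-tuple), shapes, and dtypes as true_branch.
  return jnp.sum(-x), x * -2.0


def mismatched_branch(x):
  # Returns a bare scalar, so its pytree structure differs from true_branch.
  return jnp.sum(x)


@jax.jit
def select(pred, x):
  # Both branches are traced; pred decides which branch's result is returned.
  return jax.lax.cond(pred, true_branch, false_branch, x)


x = jnp.arange(3.0)
total, scaled = select(True, x)   # total == 3.0, scaled == [0., 2., 4.]
neg_total, _ = select(False, x)   # neg_total == -3.0

try:
  jax.jit(lambda p, v: jax.lax.cond(p, true_branch, mismatched_branch, v))(True, x)
  mismatch_rejected = False
except TypeError:
  # JAX rejects branches whose output structures differ while tracing.
  mismatch_rejected = True
```

The same check applies to the variables that flax.linen.cond threads through the branches, which is why both branches must create the same variables and submodules.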


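Example:

import flax.linen as nn

class CondExample(nn.Module):
  @nn.compact  # required to define variables and inline submodules in __call__
  def __call__(self, x, pred):
    # Counters in the 'state' collection, created before branching so both branches see them.
    self.variable('state', 'true_count', lambda: 0)
    self.variable('state', 'false_count', lambda: 0)
    def true_fn(mdl, x):
      mdl.variable('state', 'true_count').value += 1
      return nn.Dense(2, name='dense')(x)
    def false_fn(mdl, x):
      mdl.variable('state', 'false_count').value += 1
      # Same submodule name as in true_fn, so both branches have one parameter structure.
      return -nn.Dense(2, name='dense')(x)
    return nn.cond(pred, true_fn, false_fn, self, x)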

Parameters:

  • pred – Determines whether true_fun or false_fun is evaluated.

  • true_fun – The function evaluated when pred is True. The signature is (module, *operands) -> T.

  • false_fun – The function evaluated when pred is False. The signature is (module, *operands) -> T.

  • mdl – The Module passed as the first (module) argument to true_fun and false_fun.

  • *operands – The arguments passed to true_fun and false_fun after the module.

  • variables – The variable collections passed to the conditional branches (default: all).

  • rngs – The PRNG sequences passed to the conditional branches (default: all).


Returns:

The result of the evaluated branch (true_fun or false_fun).