flax.linen.while_loop

flax.linen.while_loop(cond_fn, body_fn, mdl, init, carry_variables=False, broadcast_variables=True, split_rngs=FrozenDict({}))

Lifted version of jax.lax.while_loop.
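For comparison, here is a minimal sketch of the unlifted primitive, jax.lax.while_loop, which takes a condition, a body, and an initial carry. The function names and the loop itself are illustrative only and do not come from the Flax documentation.

```python
import jax
import jax.numpy as jnp


def cond_fn(carry):
  # Keep looping while the counter is below 5.
  i, _ = carry
  return i < 5


def body_fn(carry):
  # Advance the counter and accumulate its previous value.
  i, total = carry
  return i + 1, total + i


# The carry is a pytree; its structure, shapes and dtypes stay fixed.
final_i, final_total = jax.lax.while_loop(
    cond_fn, body_fn, (jnp.array(0), jnp.array(0)))
```

The lifted version has the same shape, except that cond_fn and body_fn additionally receive the lifted Module as their first argument.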

The lifted scope is passed to cond_fn and body_fn. Broadcast variables are immutable. Carry variables are mutable but cannot change shape or dtype, which also means that variables cannot be initialized inside the body. If variable initialization is required, consider calling body_fn once manually before calling while_loop.

Example:

import flax.linen as nn
import jax.numpy as jnp
from jax import random

class WhileLoopExample(nn.Module):
  @nn.compact
  def __call__(self, x):
    def cond_fn(mdl, c):
      return mdl.variables['state']['acc'] < 10
    def body_fn(mdl, c):
      acc = mdl.variable('state', 'acc', lambda: jnp.array(0))
      acc.value += 1
      y = nn.Dense(c.shape[-1])(c)
      return y
    c = x
    if self.is_mutable_collection('params'):
      return body_fn(self, c)
    else:
      return nn.while_loop(cond_fn, body_fn, self, c,
                           carry_variables='state')

k = random.PRNGKey(0)
x = jnp.ones((2, 2))
initial_vars = WhileLoopExample().init(k, x)
result, state = WhileLoopExample().apply(initial_vars, x, mutable=['state'])

Parameters
  • cond_fn – Should return True as long as the loop should continue.

  • body_fn – The body of the while loop.

  • mdl – The Module which should be lifted into the loop.

  • init – The initial state passed to the loop.

  • carry_variables – Collections that are carried through the loop and are therefore mutable (default: none).

  • broadcast_variables – Collections that are closed over and are therefore read-only (default: all collections).

  • split_rngs – Maps PRNG sequence names to booleans. If True, the PRNG sequence is split so that it differs on each loop iteration. If False, the same PRNGs are used across iterations.

Returns

The final state after executing the while loop.