Transformations

JAX transformations on Modules.

JAX functional transformations operate on pure functions. Flax extends these transformations to also operate on Modules, which have stateful variables and PRNG sequences. We refer to these extended versions as “lifted transformations”.

A lifted transformation can be applied to a Module class or a function that takes a Module instance as its first argument.

flax.linen.vmap(target, variable_axes=FrozenDict({}), split_rngs=FrozenDict({}), in_axes=0, out_axes=0, axis_size=None, axis_name=None, spmd_axis_name=None, metadata_params={}, methods=None)

A lifted version of jax.vmap.

See jax.vmap for the unlifted batch transform in JAX.

vmap can be used to add a batch axis to a Module. For example we could create a version of Dense with a batch axis that does not share parameters:

>>> import flax.linen as nn
>>> BatchDense = nn.vmap(
...     nn.Dense,
...     in_axes=0, out_axes=0,
...     variable_axes={'params': 0},
...     split_rngs={'params': True})

By using variable_axes={'params': 0}, we indicate that the parameters themselves are mapped over and therefore not shared along the mapped axis. Consequently, we also split the ‘params’ RNG, otherwise the parameters would be initialized identically along the mapped axis.

Similarly, vmap could be used to add a batch axis with parameter sharing:

>>> import flax.linen as nn
>>> BatchDense = nn.vmap(
...     nn.Dense,
...     in_axes=0, out_axes=0,
...     variable_axes={'params': None},
...     split_rngs={'params': False})

Here we use variable_axes={'params': None} to indicate the parameter variables are shared along the mapped axis. Consequently, the ‘params’ RNG must also be shared.

Parameters
  • target – a Module or a function taking a Module as its first argument.

  • variable_axes – the variable collections that are lifted into the batching transformation. Use None to indicate a broadcasted collection or an integer to map over an axis. For example, passing in variable_axes={'params': None} will indicate that the parameter variables should be shared along the mapped axis.

  • split_rngs – Split PRNG sequences will be different for each index of the batch dimension. Unsplit PRNGs will be broadcasted.

  • in_axes – Specifies the mapping of the input arguments (see jax.vmap).

  • out_axes – Specifies the mapping of the return value (see jax.vmap).

  • axis_size – Specifies the size of the batch axis. This only needs to be specified if it cannot be derived from the input arguments.

  • axis_name – Specifies a name for the batch axis. Can be used together with parallel reduction primitives (e.g. jax.lax.pmean, jax.lax.ppermute, etc.). Note that this is only used for pmap and shard_map. For SPMD jit, you do not need to synchronize manually; just make sure that the axes are correctly annotated and XLA:SPMD will insert the necessary collectives.

  • methods – If target is a Module, the methods of Module to vmap over.

  • spmd_axis_name – Axis name added to any pjit sharding constraints appearing in fn. See also google/flax.

  • metadata_params – arguments dict passed to AxisMetadata instances in the variable tree.

Returns

A batched/vectorized version of target, with the same arguments but with extra axes at positions indicated by in_axes, and the same return value, but with extra axes at positions indicated by out_axes.

flax.linen.scan(target, variable_axes=FrozenDict({}), variable_broadcast=False, variable_carry=False, split_rngs=FrozenDict({}), in_axes=0, out_axes=0, length=None, reverse=False, unroll=1, data_transform=None, metadata_params={}, methods=None, _split_transpose=False)

A lifted version of jax.lax.scan.

See jax.lax.scan for the unlifted scan in JAX.

To improve consistency with vmap, this version of scan uses in_axes and out_axes to determine which arguments are scanned over and along which axis.

scan distinguishes between 3 different types of values inside the loop:

  1. scan: a value that is iterated over in a loop. All scan values must have the same size in the axis they are scanned over. Scanned outputs will be stacked along the scan axis.

  2. carry: A carried value is updated at each loop iteration. It must have the same shape and dtype throughout the loop.

  3. broadcast: a value that is closed over by the loop. Broadcast variables are typically initialized inside the loop body but are independent of the loop variables.

The target should have the signature (module, carry, *xs) -> (carry, ys), where xs and ys are the scan values that go in and out of the loop.

Example:

>>> import flax.linen as nn
>>> import jax
>>> import jax.numpy as jnp
...
>>> class LSTM(nn.Module):
...   features: int
...
...   @nn.compact
...   def __call__(self, x):
...     ScanLSTM = nn.scan(
...       nn.LSTMCell, variable_broadcast="params",
...       split_rngs={"params": False}, in_axes=1, out_axes=1)
...
...     lstm = ScanLSTM(self.features)
...     input_shape =  x[:, 0].shape
...     carry = lstm.initialize_carry(jax.random.key(0), input_shape)
...     carry, x = lstm(carry, x)
...     return x
...
>>> x = jnp.ones((4, 12, 7))
>>> module = LSTM(features=32)
>>> y, variables = module.init_with_output(jax.random.key(0), x)

Note that when providing a function to nn.scan, the scanning happens over all arguments starting from the third argument, as specified by in_axes. The previous example could also be written using the functional form as:

>>> class LSTM(nn.Module):
...   features: int
...
...   @nn.compact
...   def __call__(self, x):
...
...     cell = nn.LSTMCell(self.features)
...     def body_fn(cell, carry, x):
...       carry, y = cell(carry, x)
...       return carry, y
...     scan = nn.scan(
...       body_fn, variable_broadcast="params",
...       split_rngs={"params": False}, in_axes=1, out_axes=1)
...
...     input_shape =  x[:, 0].shape
...     carry = cell.initialize_carry(
...       jax.random.key(0), input_shape)
...     carry, x = scan(cell, carry, x)
...     return x
...
>>> module = LSTM(features=32)
>>> variables = module.init(jax.random.key(0), jnp.ones((4, 12, 7)))

You can also use scan to reduce the compilation time of your JAX program by merging multiple layers into a single scan loop. This works when you have a sequence of identical layers that you want to apply iteratively to an input. For example:

>>> class ResidualMLPBlock(nn.Module):
...   @nn.compact
...   def __call__(self, x, _):
...     h = nn.Dense(features=2)(x)
...     h = nn.relu(h)
...     return x + h, None
...
>>> class ResidualMLP(nn.Module):
...   n_layers: int = 4
...
...   @nn.compact
...   def __call__(self, x):
...     ScanMLP = nn.scan(
...       ResidualMLPBlock, variable_axes={'params': 0},
...       variable_broadcast=False, split_rngs={'params': True},
...       length=self.n_layers)
...     x, _ = ScanMLP()(x, None)
...     return x
...
>>> model = ResidualMLP(n_layers=4)
>>> variables = model.init(jax.random.key(42), jnp.ones((1, 2)))

To reduce both compilation time and memory usage, you can use remat_scan(), which additionally checkpoints each layer in the scan loop.

Parameters
  • target – a Module or a function taking a Module as its first argument.

  • variable_axes – the variable collections that are scanned over.

  • variable_broadcast – Specifies the broadcasted variable collections. A broadcasted variable should not depend on any computation that cannot be lifted out of the loop. This is typically used to define shared parameters inside the fn.

  • variable_carry – Specifies the variable collections that are carried through the loop. Mutations to these variables are carried to the next iteration and will be preserved when the scan finishes.

  • split_rngs – Split PRNG sequences will be different for each loop iteration. If split is False the PRNGs will be the same across iterations.

  • in_axes – Specifies the axis to scan over for the arguments. Should be a prefix tree of the arguments. Use flax.core.broadcast to feed an entire input to each iteration of the scan body.

  • out_axes – Specifies the axis to scan over for the return value. Should be a prefix tree of the return value.

  • length – Specifies the number of loop iterations. This only needs to be specified if it cannot be derived from the scan arguments.

  • reverse – If true, scan from end to start in reverse order.

  • unroll – how many scan iterations to unroll within a single iteration of a loop (default: 1).

  • data_transform – optional function to transform raw functional-core variable and rng groups inside lifted scan body_fn, intended for inline SPMD annotations.

  • metadata_params – arguments dict passed to AxisMetadata instances in the variable tree.

  • methods – If target is a Module, the methods of Module to scan over.

  • _split_transpose – An experimental feature to split the transpose of a scan into a scan and a map, backed by an experimental JAX lax.scan() feature.

Returns

The scan function with the signature (module, carry, *xs) -> (carry, ys), where xs and ys are the scan values that go in and out of the loop.

flax.linen.jit(target, variables=True, rngs=True, static_argnums=(), static_argnames=(), donate_argnums=(), device=None, backend=None, methods=None)

Lifted version of jax.jit.

Parameters
  • target – a Module or a function taking a Module as its first argument.

  • variables – The variable collections that are lifted. By default all collections are lifted.

  • rngs – The PRNG sequences that are lifted. By default all PRNG sequences are lifted.

  • static_argnums – An int or collection of ints specifying which positional arguments to treat as static (compile-time constant). Operations that only depend on static arguments will be constant-folded in Python (during tracing), and so the corresponding argument values can be any Python object. Static arguments should be hashable, meaning both __hash__ and __eq__ are implemented, and immutable. Calling the jitted function with different values for these constants will trigger recompilation. If the jitted function is called with fewer positional arguments than indicated by static_argnums then an error is raised. Arguments that are not arrays or containers thereof must be marked as static. Defaults to ().

  • static_argnames – An optional string or collection of strings specifying which named arguments to treat as static (compile-time constant). See the comment on static_argnums for details. If not provided but static_argnums is set, the default is based on calling inspect.signature(fun) to find corresponding named arguments.

  • donate_argnums – Specify which arguments are “donated” to the computation. It is safe to donate arguments if you no longer need them once the computation has finished. In some cases XLA can make use of donated buffers to reduce the amount of memory needed to perform a computation, for example recycling one of your input buffers to store a result. You should not reuse buffers that you donate to a computation, JAX will raise an error if you try to.

  • device – This is an experimental feature and the API is likely to change. Optional, the Device the jitted function will run on. (Available devices can be retrieved via jax.devices().) The default is inherited from XLA’s DeviceAssignment logic and is usually to use jax.devices()[0].

  • backend – a string representing the XLA backend: 'cpu', 'gpu', or 'tpu'.

  • methods – If target is a Module, the methods of Module to jit.

Returns

A wrapped version of target, set up for just-in-time compilation.

flax.linen.remat(target, variables=True, rngs=True, concrete=False, prevent_cse=True, static_argnums=(), policy=None, methods=None)

Lifted version of jax.checkpoint.

Checkpointing is a technique for reducing memory usage by recomputing activations during backpropagation. When training large models, it can be helpful to checkpoint parts of the model to trade off memory usage for additional computation.

Example:

>>> import jax
>>> import jax.numpy as jnp
>>> import flax.linen as nn
...
>>> class CheckpointedMLP(nn.Module):
...   @nn.checkpoint
...   @nn.compact
...   def __call__(self, x):
...     x = nn.Dense(128)(x)
...     x = nn.relu(x)
...     x = nn.Dense(1)(x)
...     return x
...
>>> model = CheckpointedMLP()
>>> variables = model.init(jax.random.key(0), jnp.ones((1, 16)))

This function is aliased to remat just like jax.remat.

Parameters
  • target – a Module or a function taking a Module as its first argument. Intermediate computations will be re-computed when computing gradients for the target.

  • variables – The variable collections that are lifted. By default all collections are lifted.

  • rngs – The PRNG sequences that are lifted. By default all PRNG sequences are lifted.

  • concrete – Optional, boolean indicating whether target may involve value-dependent Python control flow (default False). Support for such control flow is optional, and disabled by default, because in some edge-case compositions with jax.jit() it can lead to some extra computation.

  • prevent_cse – Optional, boolean indicating whether to prevent common subexpression elimination (CSE) optimizations in the HLO generated from differentiation. This CSE prevention has costs because it can foil other optimizations, and because it can incur high overheads on some backends, especially GPU. The default is True because otherwise, under a jit or pmap, CSE can defeat the purpose of this decorator. But in some settings, like when used inside a scan, this CSE prevention mechanism is unnecessary, in which case prevent_cse should be set to False.

  • static_argnums – Optional, int or sequence of ints, indicates which argument values on which to specialize for tracing and caching purposes. Specifying arguments as static can avoid ConcretizationTypeErrors when tracing, but at the cost of more retracing overheads.

  • policy – Experimental checkpoint policy, see jax.checkpoint.

  • methods – An optional list of method names that will be lifted. If methods is None (default), only the __call__ method will be lifted. If target is a function, methods is ignored.

Returns

A wrapped version of target. When computing gradients intermediate computations will be re-computed on the backward pass.

flax.linen.remat_scan(target, lengths=(), policy=None, variable_broadcast=False, variable_carry=False, variable_axes=FrozenDict({True: 0}), split_rngs=FrozenDict({True: True}))

Combines remat and scan for memory efficiency and constant time compilation.

remat_scan allows for constant compile times and sublinear memory usage with respect to model depth, at a small constant penalty. This is typically beneficial for very deep models.

Example:

>>> import flax.linen as nn

>>> class BigModel(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     DenseStack = nn.remat_scan(nn.Dense, lengths=(10, 10))
...     # 100x dense with O(sqrt(N)) memory for gradient computation
...     return DenseStack(8, name="dense_stack")(x)

Parameters
  • target – a Module or a function taking a Module as its first argument.

  • lengths – number of loop iterations at the given level. The total number of iterations is n = prod(lengths). Each loop is rematerialized, so the memory consumption is proportional to n^(1 / d), where d = len(lengths). Minimal memory consumption requires tuning the lengths such that the same amount of memory is consumed at each level of the nested loop.

  • policy – Experimental checkpoint policy, see jax.checkpoint.

  • variable_broadcast – Specifies the broadcasted variable collections. A broadcasted variable should not depend on any computation that cannot be lifted out of the loop. This is typically used to define shared parameters inside the fn.

  • variable_carry – Specifies the variable collections that are carried through the loop. Mutations to these variables are carried to the next iteration and will be preserved when the scan finishes.

  • variable_axes – the variable collections that are scanned over. Defaults to {True: 0}.

  • split_rngs – Split PRNG sequences will be different for each loop iteration. If split is False the PRNGs will be the same across iterations. Defaults to {True: True}.

Returns

A wrapped version of target that repeats itself prod(lengths) times.

flax.linen.map_variables(target, mapped_collections=True, trans_in_fn=<function <lambda>>, trans_out_fn=<function <lambda>>, init=False, mutable=False, rngs=True, variables=True, methods=None)

Map Variables inside a module.

map_variables can be used to transform the variables inside a module both before and after the module is applied. This is useful among other things for masking the weights of a module without having to modify the module itself.

Example:

>>> import jax
>>> import jax.numpy as jnp
>>> import flax.linen as nn
...
>>> class CausalDense(nn.Module):
...   '''A dense layer that masks the weights such that the output is
...   causal, i.e. output i only depends on input <= i.
...   '''
...   features: int
...
...   def apply_mask(self, variables):
...     return (jax.tree_util.tree_map(jnp.triu, variables)
...             if not self.is_initializing() else variables)
...
...   def setup(self):
...     # temporary class
...     _CausalDense = nn.map_variables(
...       nn.Dense, 'params', self.apply_mask, init=self.is_initializing())
...     self.dense = _CausalDense(features=self.features, use_bias=False)
...
...   def __call__(self, x):
...     return self.dense(x)
...
>>> module = CausalDense(features=5)
>>> variables = module.init(jax.random.key(0), jnp.ones((1, 5)))

Parameters
  • target – the module or function to be transformed.

  • mapped_collections – the collection(s) to be transformed.

  • trans_in_fn – modifies the variables before applying the module or function.

  • trans_out_fn – modifies the variables after applying the module or function. It is only applied if either init or mutable is not False.

  • init – If True, variables are initialized before transformation.

  • mutable – If True, the mapped variable collections will be mutable.

  • rngs – PRNGSequences added to the transformed scope (default: all).

  • variables – Additional Variable collections added to the transformed scope, besides those specified by target (default: all).

  • methods – If target is a Module, the methods of Module to map variables for.

Returns

a wrapped version of target that will map the specified collections.

flax.linen.jvp(fn, mdl, primals, tangents, variable_tangents, variables=True, rngs=True)

A lifted version of jax.jvp.

See jax.jvp for the unlifted Jacobian-vector product (forward gradient).

Note that no tangents are returned for variables. When variable tangents are required their value should be returned explicitly by fn using Module.variables:

>>> import flax.linen as nn
>>> import jax
>>> import jax.numpy as jnp

>>> class LearnScale(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     p = self.param('test', nn.initializers.zeros_init(), ())
...     return p * x

>>> class Foo(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     scale = LearnScale()
...     vars_t = jax.tree_util.tree_map(jnp.ones_like,
...                                     scale.variables.get('params', {}))
...     _, out_t = nn.jvp(
...         lambda mdl, x: mdl(x), scale, (x,), (jnp.zeros_like(x),),
...         variable_tangents={'params': vars_t})
...     return out_t

Example using the functional core API, where fn receives a scope instead of a Module (lift refers to flax.core.lift):

>>> import jax
>>> from flax.core import lift

>>> def learn_scale(scope, x):
...   p = scope.param('scale', nn.initializers.zeros_init(), ())
...   return p * x

>>> def f(scope, x):
...   vars_t = jax.tree_util.tree_map(jnp.ones_like, scope.variables().get('params', {}))
...   x, out_t = lift.jvp(
...       learn_scale, scope, (x,), (jnp.zeros_like(x),),
...       variable_tangents={'params': vars_t})
...   return out_t

Parameters
  • fn – Function to be differentiated. Its arguments should be arrays, scalars, or standard Python containers of arrays or scalars. It should return an array, scalar, or standard Python container of arrays or scalars. It will receive the scope and primals as arguments.

  • mdl – The module of which the variables will be differentiated.

  • primals – The primal values at which the Jacobian of fn should be evaluated. Should be either a tuple or a list of arguments, and its length should be equal to the number of positional parameters of fn.

  • tangents – The tangent vector for which the Jacobian-vector product should be evaluated. Should be either a tuple or a list of tangents, with the same tree structure and array shapes as primals.

  • variable_tangents – A dict or PyTree of dicts with the same structure as scopes. Each entry in the dict specifies the tangents for a variable collection. Not specifying a collection in variable_tangents is equivalent to passing a zero vector as the tangent.

  • variables – other variables collections that are available in fn but do not receive a tangent.

  • rngs – the prngs that are available inside fn.

Returns

A (primals_out, tangents_out) pair, where primals_out is fn(*primals), and tangents_out is the Jacobian-vector product of fn evaluated at primals with tangents. The tangents_out value has the same Python tree structure and shapes as primals_out.

flax.linen.vjp(fn, mdl, *primals, has_aux=False, reduce_axes=(), vjp_variables='params', variables=True, rngs=True, multi_scope=False)

A lifted version of jax.vjp.

See jax.vjp for the unlifted vector-Jacobian product (backward gradient).

Note that a gradient is returned for all variables in the collections specified by vjp_variables. However, the backward function only expects a cotangent for the return value of fn. If variables require a cotangent as well they can be returned from fn using Module.variables.

Example:

>>> import flax.linen as nn
>>> import jax.numpy as jnp

>>> class LearnScale(nn.Module):
...   @nn.compact
...   def __call__(self, x, y):
...     p = self.param('scale', nn.initializers.zeros_init(), ())
...     return p * x * y

>>> class Foo(nn.Module):
...   @nn.compact
...   def __call__(self, x, y):
...     z, bwd = nn.vjp(lambda mdl, x, y: mdl(x, y), LearnScale(), x, y)
...     params_grad, x_grad, y_grad = bwd(jnp.ones(z.shape))
...     return z, params_grad, x_grad, y_grad

Parameters
  • fn – Function to be differentiated. Its arguments should be arrays, scalars, or standard Python containers of arrays or scalars. It should return an array, scalar, or standard Python container of arrays or scalars. It will receive the scope and primals as arguments.

  • mdl – The module of which the variables will be differentiated.

  • *primals – A sequence of primal values at which the Jacobian of fn should be evaluated. The length of primals should be equal to the number of positional parameters to fn. Each primal value should be a tuple of arrays, scalar, or standard Python containers thereof.

  • has_aux – Optional, bool. Indicates whether fn returns a pair where the first element is considered the output of the mathematical function to be differentiated and the second element is auxiliary data. Default False.

  • reduce_axes – Optional, tuple of axis names. If an axis is listed here, and fn implicitly broadcasts a value over that axis, the backward pass will perform a psum of the corresponding gradient. Otherwise, the VJP will be per-example over named axes. For example, if 'batch' is a named batch axis, vjp(f, *args, reduce_axes=('batch',)) will create a VJP function that sums over the batch while vjp(f, *args) will create a per-example VJP.

  • vjp_variables – The vjpfun will return a cotangent vector for all variable collections specified by this filter.

  • variables – other variables collections that are available inside fn but do not receive a cotangent.

  • rngs – the prngs that are available inside fn.

  • multi_scope – for Modules containing multiple scopes from outside modules passed in, allow for variable gradients to be returned for multiple scopes instead of erroring.

Returns

If has_aux is False, returns a (primals_out, vjpfun) pair, where primals_out is fn(*primals). vjpfun is a function from a cotangent vector with the same shape as primals_out to a tuple of cotangent vectors with the same shape as primals, representing the vector-Jacobian product of fn evaluated at primals. If has_aux is True, returns a (primals_out, vjpfun, aux) tuple where aux is the auxiliary data returned by fn.

flax.linen.custom_vjp(fn, forward_fn, backward_fn, grad_vars='params', nondiff_argnums=())

Lifted version of jax.custom_vjp.

forward_fn and backward_fn together define a custom vjp for fn. The original fn will run in case a vjp (backward gradient) is not computed.

The forward_fn receives the same arguments as fn but is expected to return a tuple containing the output of fn(mdl, *args) and the residuals that are passed to backward_fn.

The backward_fn receives the nondiff arguments, residuals, and the output tangents. It should return a tuple containing the variable and input tangents.

Note that the vjp function returned by nn.vjp can be passed as residual and used in the backward_fn. The scope is unavailable during the backward pass. If the module is required in backward_fn, a snapshot of the variables can be taken and returned as a residual in the forward_fn.

Example:

>>> import flax.linen as nn
>>> import jax, jax.numpy as jnp

>>> class Foo(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     def f(mdl, x):
...       return mdl(x)
...
...     def fwd(mdl, x):
...       return nn.vjp(f, mdl, x)
...
...     def bwd(vjp_fn, y_t):
...       params_t, *inputs_t = vjp_fn(y_t)
...       params_t = jax.tree_util.tree_map(jnp.sign, params_t)
...       return (params_t, *inputs_t)
...
...     sign_grad = nn.custom_vjp(
...         f, forward_fn=fwd, backward_fn=bwd)
...     return sign_grad(nn.Dense(1), x).reshape(())

>>> x = jnp.ones((2,))
>>> variables = Foo().init(jax.random.key(0), x)
>>> grad = jax.grad(Foo().apply)(variables, x)

Parameters
  • fn – The function to define a custom_vjp for.

  • forward_fn – A function with the same arguments as fn returning a tuple with the original output and the residuals that will be passed to backward_fn.

  • backward_fn – arguments are passed as (*nondiff_args, residuals, tangents). The function should return a tuple containing the tangents for the variable in the collections specified by grad_vars and the input arguments (except the module and nondiff args).

  • grad_vars – The collections for which a vjp will be computed (default: “params”).

  • nondiff_argnums – arguments for which no vjp is computed.

Returns

A function with the same signature as fn with the custom vjp.

flax.linen.while_loop(cond_fn, body_fn, mdl, init, carry_variables=False, broadcast_variables=True, split_rngs=FrozenDict({}))

Lifted version of jax.lax.while_loop.

The lifted scope is passed to cond_fn and body_fn. Broadcasted variables are immutable. The carry variables are mutable but cannot change shape and dtype. This also means you cannot initialize variables inside the body. Consider calling body_fn once manually before calling while_loop if variable initialization is required.

Example:

>>> import flax.linen as nn
>>> import jax, jax.numpy as jnp

>>> class WhileLoopExample(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     def cond_fn(mdl, c):
...       return mdl.variables['state']['acc'] < 10
...     def body_fn(mdl, c):
...       acc = mdl.variable('state', 'acc', lambda: jnp.array(0))
...       acc.value += 1
...       y = nn.Dense(c.shape[-1])(c)
...       return y
...     c = x
...     if self.is_mutable_collection('params'):
...       return body_fn(self, c)
...     else:
...       return nn.while_loop(cond_fn, body_fn, self, c,
...                             carry_variables='state')

>>> k = jax.random.key(0)
>>> x = jnp.ones((2, 2))
>>> initial_vars = WhileLoopExample().init(k, x)
>>> result, state = WhileLoopExample().apply(initial_vars, x, mutable=['state'])

Parameters
  • cond_fn – Should return True as long as the loop should continue.

  • body_fn – The body of the while loop.

  • mdl – The Module which should be lifted into the loop.

  • init – The initial state passed to the loop.

  • carry_variables – collections that are carried through the loop and are therefore mutable (default: none).

  • broadcast_variables – collections that are closed over and are therefore read-only (default: all collections).

  • split_rngs – Split PRNG sequences will be different for each loop iteration. If split is False the PRNGs will be the same across iterations.

Returns

The final state after executing the while loop.

flax.linen.cond(pred, true_fun, false_fun, mdl, *operands, variables=True, rngs=True)

Lifted version of jax.lax.cond.

The returned values from true_fun and false_fun must have the same Pytree structure, shapes, and dtypes. The variables created or updated inside the branches must also have the same structure. Note that this constraint is violated when creating variables or submodules in only one branch, because initializing variables in just one branch causes the parameter structure to be different.

Example:

>>> import flax.linen as nn

>>> class CondExample(nn.Module):
...   @nn.compact
...   def __call__(self, x, pred):
...     self.variable('state', 'true_count', lambda: 0)
...     self.variable('state', 'false_count', lambda: 0)
...     def true_fn(mdl, x):
...       mdl.variable('state', 'true_count').value += 1
...       return nn.Dense(2, name='dense')(x)
...     def false_fn(mdl, x):
...       mdl.variable('state', 'false_count').value += 1
...       return -nn.Dense(2, name='dense')(x)
...     return nn.cond(pred, true_fn, false_fn, self, x)

Parameters
  • pred – determines if true_fun or false_fun is evaluated.

  • true_fun – The function evaluated when pred is True. The signature is (module, *operands) -> T.

  • false_fun – The function evaluated when pred is False. The signature is (module, *operands) -> T.

  • mdl – A Module target to pass.

  • *operands – The arguments passed to true_fun and false_fun

  • variables – The variable collections passed to the conditional branches (default: all)

  • rngs – The PRNG sequences passed to the conditionals (default: all)

Returns

The result of the evaluated branch (true_fun or false_fun).

flax.linen.switch(index, branches, mdl, *operands, variables=True, rngs=True)

Lifted version of jax.lax.switch.

The returned values from branches must have the same Pytree structure, shapes, and dtypes. The variables created or updated inside the branches must also have the same structure. Note that this constraint is violated when creating variables or submodules in only one branch, because initializing variables in just one branch causes the parameter structure to be different.

Example:

>>> import flax.linen as nn

>>> class SwitchExample(nn.Module):
...   @nn.compact
...   def __call__(self, x, index):
...     self.variable('state', 'a_count', lambda: 0)
...     self.variable('state', 'b_count', lambda: 0)
...     self.variable('state', 'c_count', lambda: 0)
...     def a_fn(mdl, x):
...       mdl.variable('state', 'a_count').value += 1
...       return nn.Dense(2, name='dense')(x)
...     def b_fn(mdl, x):
...       mdl.variable('state', 'b_count').value += 1
...       return -nn.Dense(2, name='dense')(x)
...     def c_fn(mdl, x):
...       mdl.variable('state', 'c_count').value += 1
...       return nn.Dense(2, name='dense')(x)
...     return nn.switch(index, [a_fn, b_fn, c_fn], self, x)

If you want to have a different parameter structure for each branch you should run all branches on initialization before calling switch:

>>> class MultiHeadSwitchExample(nn.Module):
...   def setup(self) -> None:
...     self.heads = [
...       nn.Sequential([nn.Dense(10), nn.Dense(7), nn.Dense(5)]),
...       nn.Sequential([nn.Dense(11), nn.Dense(5)]),
...       nn.Dense(5),
...     ]
...
...   @nn.compact
...   def __call__(self, x, index):
...     def head_fn(i):
...       return lambda mdl, x: mdl.heads[i](x)
...     branches = [head_fn(i) for i in range(len(self.heads))]
...
...     # run all branches on init
...     if self.is_mutable_collection('params'):
...       for branch in branches:
...         _ = branch(self, x)
...
...     return nn.switch(index, branches, self, x)

Parameters
  • index – Integer scalar type, indicating which branch function to apply.

  • branches – Sequence of functions to be applied based on index. The signature of each function is (module, *operands) -> T.

  • mdl – A Module target to pass.

  • *operands – The arguments passed to the branches.

  • variables – The variable collections passed to the conditional branches (default: all)

  • rngs – The PRNG sequences passed to the conditionals (default: all)

Returns

The result of the evaluated branch.