flax.linen.map_variables

flax.linen.map_variables(target, mapped_collections=True, trans_in_fn=<function <lambda>>, trans_out_fn=<function <lambda>>, init=False, mutable=False, rngs=True, variables=True, methods=None)

Map Variables inside a module.

map_variables can be used to transform the variables inside a module both before and after the module is applied. This is useful, among other things, for masking the weights of a module without having to modify the module itself.

Example:

>>> import jax
>>> import jax.numpy as jnp
>>> import flax.linen as nn
...
>>> class CausalDense(nn.Module):
...   '''A dense layer that masks the weights such that the output is
...   causal, i.e. output i only depends on input <= i.
...   '''
...   features: int
...
...   def apply_mask(self, variables):
...     return (jax.tree_util.tree_map(jnp.triu, variables)
...             if not self.is_initializing() else variables)
...
...   def setup(self):
...     # temporary class
...     _CausalDense = nn.map_variables(
...       nn.Dense, 'params', self.apply_mask, init=self.is_initializing())
...     self.dense = _CausalDense(features=self.features, use_bias=False)
...
...   def __call__(self, x):
...     return self.dense(x)
...
>>> module = CausalDense(features=5)
>>> variables = module.init(jax.random.key(0), jnp.ones((1, 5)))

Parameters
  • target – the module or function to be transformed.

  • mapped_collections – the collection(s) to be transformed.

  • trans_in_fn – modifies the variables before applying the module or function.

  • trans_out_fn – modifies the variables after applying the module or function; it is only applied if at least one of init and mutable is not False.

  • init – If True, variables are initialized before transformation.

  • mutable – If True, the mapped variable collections will be mutable.

  • rngs – PRNGSequences added to the transformed scope (default: all).

  • variables – Additional Variable collections added to the transformed scope, besides those specified by target (default: all).

  • methods – If target is a Module, the methods of the Module for which variables are mapped.

Returns

a wrapped version of target that will map the specified collections.