flax.linen.map_variables
- flax.linen.map_variables(target, mapped_collections=True, trans_in_fn=<function <lambda>>, trans_out_fn=<function <lambda>>, init=False, mutable=False, rngs=True, variables=True, methods=None)
Map Variables inside a module.
map_variables can be used to transform the variables inside a module both before and after the module is applied. This is useful, among other things, for masking the weights of a module without having to modify the module itself.
- Example
>>> import jax
>>> import jax.numpy as jnp
>>> import flax.linen as nn
...
>>> class CausalDense(nn.Module):
...   '''A dense layer that masks the weights such that the output is
...   causal, i.e. output i only depends on input <= i.
...   '''
...   features: int
...
...   def apply_mask(self, variables):
...     # Keep the weights as-is during init; mask them with triu otherwise.
...     return (jax.tree_util.tree_map(jnp.triu, variables)
...             if not self.is_initializing() else variables)
...
...   def setup(self):
...     # temporary class: nn.Dense whose 'params' pass through apply_mask
...     _CausalDense = nn.map_variables(
...       nn.Dense, 'params', self.apply_mask, init=self.is_initializing())
...     self.dense = _CausalDense(features=self.features, use_bias=False)
...
...   def __call__(self, x):
...     return self.dense(x)
...
>>> module = CausalDense(features=5)
>>> variables = module.init(jax.random.key(0), jnp.ones((1, 5)))
- Parameters
target – the module or function to be transformed.
mapped_collections – the collection(s) to be transformed.
trans_in_fn – modifies the variables before applying the module or function.
trans_out_fn – modifies the variables after applying the module or function; it is only applied if either init or mutable is not False.
init – If True, variables are initialized before transformation.
mutable – If True, the mapped variable collections will be mutable.
rngs – PRNGSequences added to the transformed scope (default: all).
variables – Additional Variable collections added to the transformed scope, besides those specified by target (default: all).
methods – If target is a Module, the methods of the Module for which variables are mapped.
- Returns
a wrapped version of target that will map the specified collections.