flax.linen.map_variables
- flax.linen.map_variables(target, mapped_collections=True, trans_in_fn=<function <lambda>>, trans_out_fn=<function <lambda>>, init=False, mutable=False, rngs=True, variables=True, methods=None)
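Map variables inside a module. The variables in the mapped collections are passed through trans_in_fn before target sees them, and through trans_out_fn after target has updated them.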
Example:
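import jax
import jax.numpy as jnp
import flax.linen as nn

class OneBitDense(nn.Module):
  @nn.compact
  def __call__(self, x):
    # Replace every parameter with its sign before nn.Dense uses it.
    def sign(x):
      return jax.tree_util.tree_map(jnp.sign, x)
    MapDense = nn.map_variables(nn.Dense, "params", sign, init=True)
    return MapDense(4)(x)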
- Parameters
target – the Module or function to be transformed.
mapped_collections – the collection(s) to be transformed.
trans_in_fn – creates a view of the target variables.
trans_out_fn – transforms the updated variables in the view after mutation.
init – If True, variables are initialized before transformation.
mutable – If True, the mapped variable collections will be mutable.
rngs – PRNGSequences added to the transformed scope (default: all).
variables – additional variable collections added to the transformed scope, besides those specified by target (default: all).
methods – If target is a Module, the methods of Module to map variables for.
- Returns
a wrapped version of target that will map the specified collections.