flax.linen.map_variables
- flax.linen.map_variables(target, mapped_collections=True, trans_in_fn=<function <lambda>>, trans_out_fn=<function <lambda>>, init=False, mutable=False, rngs=True, variables=True, methods=None)
Map variables inside a module.
Example:
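```python
import jax
import jax.numpy as jnp
import flax.linen as nn

class OneBitDense(nn.Module):
  @nn.compact
  def __call__(self, x):
    # View every parameter of the wrapped Dense layer through its sign.
    def sign(x):
      return jax.tree_util.tree_map(jnp.sign, x)
    # init=True lets Dense create its params inside the mapped scope.
    MapDense = nn.map_variables(nn.Dense, "params", sign, init=True)
    return MapDense(4)(x)
```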
- Parameters
target (flax.linen.transforms.Target) – the Module class or function to be transformed.
mapped_collections (Union[bool, str, Collection[str], DenyList]) – the collection(s) to be transformed.
trans_in_fn (Callable[[...], Any]) – creates a view of the target variables.
trans_out_fn (Callable[[...], Any]) – transforms the updated variables in the view after mutation.
init (bool) – If True, variables are initialized before transformation.
mutable (bool) – If True, the mapped variable collections will be mutable.
rngs (Union[bool, str, Collection[str], DenyList]) – PRNGSequences added to the transformed scope (default: all).
variables (Union[bool, str, Collection[str], DenyList]) – Additional variable collections, besides those specified by mapped_collections, added to the transformed scope (default: all).
methods – If target is a Module class, the methods of the Module to map variables for.
- Returns
a wrapped version of target that will map the specified collections.
- Return type
flax.linen.transforms.Target