flax.linen.map_variables(target, mapped_collections=True, trans_in_fn=<function <lambda>>, trans_out_fn=<function <lambda>>, init=False, mutable=False, rngs=True, variables=True, methods=None)

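Map variables inside a module. The selected variable collections are transformed before target is applied (trans_in_fn) and, when they can change, after it is applied (trans_out_fn). This is useful, for example, for masking the weights of a module without modifying the module itself.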

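Example:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

class OneBitDense(nn.Module):
  @nn.compact  # required to define the submodule inline in __call__
  def __call__(self, x):
    def sign(x):
      return jax.tree_util.tree_map(jnp.sign, x)  # view each param as -1, 0 or +1
    MapDense = nn.map_variables(nn.Dense, "params", sign, init=True)
    return MapDense(4)(x)
```
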
Parameters

  • target (flax.linen.transforms.Target) – the Module class or function to be transformed.

  • mapped_collections (Union[bool, str, Collection[str], DenyList]) – the collection(s) to be transformed.

  • trans_in_fn (Callable[[...], Any]) – modifies the mapped variables before target is applied, producing the view of the variables that target sees.

  • trans_out_fn (Callable[[...], Any]) – modifies the variables in the view after target is applied, mapping them back for storage. It is only applied if init or mutable is True.

  • init (bool) – If True, variables are initialized before transformation.

  • mutable (bool) – If True, the mapped variable collections will be mutable.

  • rngs (Union[bool, str, Collection[str], DenyList]) – PRNGSequences added to the transformed scope (default: all).

  • variables (Union[bool, str, Collection[str], DenyList]) – Additional variable collections, besides those specified by target, added to the transformed scope (default: all).

  • methods – If target is a Module class, the methods of that Module for which the variables are mapped.

Returns

A wrapped version of target that maps the specified collections.

Return type