flax.linen.map_variables

flax.linen.map_variables(target, mapped_collections=True, trans_in_fn=<function <lambda>>, trans_out_fn=<function <lambda>>, init=False, mutable=False, rngs=True, variables=True, methods=None)

Map Variables inside a module.

Example:

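import jax
import jax.numpy as jnp
import flax.linen as nn

class OneBitDense(nn.Module):
  @nn.compact
  def __call__(self, x):
    def sign(x):
      # Replace every variable in the mapped collection by its elementwise sign.
      return jax.tree_util.tree_map(jnp.sign, x)
    # Wrap nn.Dense so its "params" are seen through `sign`; init=True means
    # the variables are initialized before the transformation is applied.
    MapDense = nn.map_variables(nn.Dense, "params", sign, init=True)
    return MapDense(4)(x)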

Parameters
  • target – the Module class or function to be transformed.

  • mapped_collections – the collection(s) to be transformed.

  • trans_in_fn – creates a view of the target variables.

  • trans_out_fn – transforms the updated variables in the view after mutation.

  • init – If True, variables are initialized before transformation.

  • mutable – If True, the mapped variable collections will be mutable.

  • rngs – PRNGSequences added to the transformed scope (default: all).

  • variables – Additional Variable collections added to the transformed scope, besides those specified by target (default: all).

  • methods – If target is a Module, the methods of that Module for which variables are mapped.

Returns

a wrapped version of target that will map the specified collections.