flax.linen.with_partitioning

flax.linen.with_partitioning(fn, names, mesh=None)

Wraps a function so that its return value is boxed in a Partitioned instance carrying the given axis names.
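Conceptually, the wrapper calls fn and boxes its result. The following is a minimal pure-Python sketch of that idea, not Flax's implementation: `PartitionedSketch`, `with_partitioning_sketch`, and the toy `zeros_init` are hypothetical names introduced only for illustration, and the real Partitioned class additionally behaves as a JAX pytree and can apply sharding constraints when unboxed.

```python
import functools
from dataclasses import dataclass
from typing import Any, Callable, Optional, Tuple


@dataclass(frozen=True)
class PartitionedSketch:
    """Hypothetical stand-in for Partitioned: a value plus per-axis names."""
    value: Any
    names: Tuple[Optional[str], ...]
    mesh: Any = None


def with_partitioning_sketch(fn: Callable[..., Any], names, mesh=None):
    """Hypothetical sketch: return a wrapper that boxes fn's result."""
    @functools.wraps(fn)  # keep fn's __name__/__doc__ on the wrapper
    def wrapper(*args, **kwargs):
        return PartitionedSketch(fn(*args, **kwargs), tuple(names), mesh)
    return wrapper


def zeros_init(key, shape, dtype=float):
    """Toy initializer mimicking the (key, shape, dtype) signature; returns lists."""
    rows, cols = shape
    return [[dtype(0)] * cols for _ in range(rows)]


kernel_init = with_partitioning_sketch(zeros_init, (None, "data"))
boxed = kernel_init(None, (2, 3))
print(boxed.names)  # (None, 'data')
print(boxed.value)  # [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]
```

The wrapper leaves fn's arguments untouched, which is why a wrapped initializer can be passed anywhere the plain initializer was accepted.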

Example:

>>> import flax.linen as nn
>>> kernel_init = nn.with_partitioning(
...     nn.initializers.lecun_normal(), (None, "data"))
>>> partitioned_dense = nn.Dense(features=3, kernel_init=kernel_init)

Parameters
  • fn – The function to be wrapped. Typically this is an initializer.

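  • names – The axis names passed to Partitioned, one entry per dimension of the value returned by fn (for example (None, "data") in the example above); None leaves that dimension unpartitioned.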

  • mesh – The mesh to use for the partitioning. If None, the global mesh resource is used if available.

Returns

A function wrapping fn that will return an instance of Partitioned.