flax.linen.with_partitioning
- flax.linen.with_partitioning(fn, names, mesh=None)
Wraps a function’s return value with Partitioned.
Example:
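```python
import flax.linen as nn

# lecun_normal is an initializer factory; call it to get the initializer itself.
kernel_init = nn.with_partitioning(
    nn.initializers.lecun_normal(), (None, "data"))
partitioned_dense = nn.Dense(features=3, kernel_init=kernel_init)
```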
- Parameters
fn – The function to be wrapped. Typically this is an initializer.
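names – The logical axis names passed to Partitioned, one entry per dimension of the value returned by fn (None leaves that dimension unpartitioned).
mesh – The mesh to use for the partitioning. If None, the global mesh resource is used if available.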
- Returns
A function wrapping fn that will return an instance of Partitioned.