SPMD
Utilities for working with pjit and partitioned models.
This module introduces logical_axis_rules, logical_to_mesh_axes, logical_to_mesh, and with_logical_constraint for applying pjit sharding constraints in terms of “logical named axes” rather than pjit’s default mesh axes.
Additionally, it defines the LogicallyPartitioned metadata wrapper, as well as the initializer wrapper with_logical_partitioning for introducing logical axis metadata into a model’s variables.
- flax.linen.Partitioned(value, names, mesh=None)
Wrapper for partitioning metadata.
Partitioned is used to extend variables with partitioning information required for jax.experimental.pjit. The easiest way to define Partitioned variables is by using the with_partitioning wrapper around the variable initializer.
Example:
```python
import jax
import jax.numpy as jnp
import numpy as np
from jax import random
from jax.experimental.pjit import pjit
from jax.sharding import Mesh, PartitionSpec
import flax.linen as nn

class MLP(nn.Module):
  hidden_size: int

  @nn.compact
  def __call__(self, x):
    ki = nn.linear.default_kernel_init
    h = nn.Dense(
        self.hidden_size,
        kernel_init=nn.with_partitioning(ki, ('data', 'model')))(x)
    h = nn.relu(h)
    return nn.Dense(
        x.shape[-1],
        kernel_init=nn.with_partitioning(ki, ('model', 'data')))(h)

mlp = MLP(4096)
x = jnp.ones((8 * 1024, 1024))
# Use eval_shape to get the Partitioned instances for the variables.
# This way we can determine the PartitionSpecs for the init variables
# before we call the init fn.
var_spec = nn.get_partition_spec(
    jax.eval_shape(mlp.init, random.key(0), x))
# Device mesh with named axes "data" and "model" (all devices on "model").
mesh = Mesh(np.array(jax.devices()).reshape(1, -1), ('data', 'model'))
with mesh:
  init_fn = pjit(mlp.init, (None, PartitionSpec("data", "model")), var_spec)
  variables = init_fn(random.key(0), x)
  apply_fn = pjit(
      mlp.apply,
      (var_spec, PartitionSpec("data", "model")),
      PartitionSpec("data", "model"))
  apply_fn(variables, x)
```
Partitioned values can gain additional axes when using transformations like nn.vmap and nn.scan. In this case you can specify the name of the new axis with the metadata_params argument of vmap/scan:

```python
class Model(nn.Module):

  @nn.compact
  def __call__(self, x):
    def body(mdl, c):
      c = MLP(4096)(c)
      return c, ()
    c, _ = nn.scan(
        body, variable_axes={"params": 0}, split_rngs={"params": True}, length=8,
        metadata_params={nn.meta.PARTITION_NAME: "layers"})(self, x)
    return c
```
- flax.linen.with_partitioning(fn, names, mesh=None)
Wraps a function’s return value with Partitioned.
Example:

```python
import flax.linen as nn

features = 128  # example output size
kernel_init = nn.with_partitioning(
    nn.initializers.lecun_normal(), (None, "data"))
partitioned_dense = nn.Dense(features, kernel_init=kernel_init)
```
- Parameters
fn – The function to be wrapped. Typically this is an initializer.
names – The mesh axis names (one per array dimension, or None) passed to Partitioned.
mesh – The mesh to use for the partitioning. If None, the global mesh resource is used if available.
- Returns
A function wrapping fn that returns an instance of Partitioned.
- flax.linen.get_partition_spec(tree)
Extracts a PartitionSpec tree from a PyTree containing Partitioned values.
- flax.linen.get_sharding(tree, mesh)
Extracts a jax.sharding tree from a PyTree containing Partitioned values and a mesh.
- flax.linen.LogicallyPartitioned(value: Any, names: Tuple[Optional[str], ...], mesh: Optional[jax._src.mesh.Mesh] = None, rules: Optional[Sequence[Tuple[str, Union[str, Tuple[str], NoneType]]]] = None)
- flax.linen.logical_axis_rules(rules)
Context manager for setting the logical to mesh axis bindings.
- flax.linen.set_logical_axis_rules(rules)
Sets the global logical axis to mesh axis binding.
- flax.linen.logical_to_mesh_axes(array_dim_names, rules=None)
Computes the layout for an array.
The rules are in order of precedence, and consist of pairs: (ArrayDimensionName, MeshDimensionName), meaning that the given array dimension (if present and unused) should be sharded across the given mesh dimension (if present and unused).
The layout of an array is expressed as a tuple with one element per array dimension. Each element is either None or the name of a mesh dimension, meaning that this dimension of the array is sharded across that dimension of the mesh.
For example, given an array with

array_dim_names = ('batch', 'length', 'heads', 'features')

and the layout rules

rules = (('batch', 'X'), ('features', 'X'), ('heads', 'Y'), ('batch', 'Z'))

this function will return

PartitionSpec('X', None, 'Y', None)
- Parameters
array_dim_names – Tuple of array dimension names or None.
rules – Optional logical-to-mesh rules override. Defaults to the rules set in the current context by logical_axis_rules or set_logical_axis_rules.
- Returns
PartitionSpec for the parameter.
- flax.linen.logical_to_mesh(tree, rules=None)
Applies logical_to_mesh_axes to pytrees of logical PartitionSpecs.
- flax.linen.logical_to_mesh_sharding(tree, mesh, rules=None)
Converts pytrees of logical PartitionSpecs to shardings.
- flax.linen.with_logical_constraint(x, logical_axis_resources, rules=None, mesh=None, fallback=RulesFallback.AXIS_IS_UNSHARDED)
Version of pjit’s with_sharding_constraint that uses logical axis names.
- flax.linen.with_logical_partitioning(fn, names, mesh=None, rules=None)
Wraps a function’s return value with LogicallyPartitioned.
Example:

```python
import flax.linen as nn

features = 128  # example output size
kernel_init = nn.with_logical_partitioning(
    nn.initializers.lecun_normal(), (None, "data"))
partitioned_dense = nn.Dense(features, kernel_init=kernel_init)
```
- Parameters
fn – The function to be wrapped. Typically this is an initializer.
names – The logical axis names passed to LogicallyPartitioned.
mesh – The mesh to use for the partitioning. If None, the global mesh resource is used if available.
rules – Optional logical-to-mesh rules to use. If None, the global rules are used if available.
- Returns
A function wrapping fn that returns an instance of LogicallyPartitioned.
Summary
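| Name | Description |
|---|---|
| Partitioned | Wrapper for partitioning metadata. |
| with_partitioning | Wraps a function's return value with Partitioned. |
| get_partition_spec | Extracts a PartitionSpec tree from a PyTree containing Partitioned values. |
| get_sharding | Extracts a jax.sharding tree from a PyTree containing Partitioned values and a mesh. |
| LogicallyPartitioned | |
| logical_axis_rules | Context manager for setting the logical to mesh axis bindings. |
| set_logical_axis_rules | Sets the global logical axis to mesh axis binding. |
| get_logical_axis_rules | Returns the global logical axis to mesh axis binding. |
| logical_to_mesh_axes | Computes the layout for an array. |
| logical_to_mesh | Applies logical_to_mesh_axes to pytrees of logical PartitionSpecs. |
| logical_to_mesh_sharding | Converts pytrees of logical PartitionSpecs to shardings. |
| with_logical_constraint | Version of pjit's with_sharding_constraint that uses logical axis names. |
| with_logical_partitioning | Wraps a function's return value with LogicallyPartitioned. |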