SPMD

Utilities for working with jit and partitioned models.

This module introduces logical_axis_rules, logical_to_mesh_axes, logical_to_mesh, and with_logical_constraint for applying jit sharding constraints in terms of “logical named axes” rather than jit’s default mesh axes.

Additionally, it defines the LogicallyPartitioned metadata wrapper and the initializer wrapper with_logical_partitioning, which introduce logical axis metadata into a model’s variables.

flax.linen.Partitioned(value, names, mesh=None)

Wrapper for partitioning metadata.

Partitioned is used to extend variables with partitioning information required for jax.experimental.pjit.

The easiest way to define Partitioned variables is by using the with_partitioning wrapper around the variable initializer.

Example:

import jax
import jax.numpy as jnp
import numpy as np
from jax import random
from jax.sharding import Mesh, NamedSharding, PartitionSpec
import flax.linen as nn

class MLP(nn.Module):
  hidden_size: int

  @nn.compact
  def __call__(self, x):
    ki = nn.linear.default_kernel_init
    h = nn.Dense(
        self.hidden_size,
        kernel_init=nn.with_partitioning(ki, ('data', 'model')))(x)
    h = nn.relu(h)
    return nn.Dense(
        x.shape[-1],
        kernel_init=nn.with_partitioning(ki, ('model', 'data')))(h)

# A 2D device mesh with axes named 'data' and 'model'.
mesh = Mesh(np.array(jax.devices()).reshape(-1, 1), ('data', 'model'))
x_sharding = NamedSharding(mesh, PartitionSpec('data', 'model'))

mlp = MLP(4096)
x = jax.device_put(jnp.ones((8 * 1024, 1024)), x_sharding)
# Use eval_shape to get the Partitioned instances for the variables.
# This way we can determine the shardings for the init variables
# before we call the init fn.
var_sharding = nn.get_sharding(
    jax.eval_shape(mlp.init, random.key(0), x), mesh)
init_fn = jax.jit(mlp.init, out_shardings=var_sharding)
variables = init_fn(random.key(0), x)
apply_fn = jax.jit(mlp.apply, out_shardings=x_sharding)
y = apply_fn(variables, x)

Partitioned values can gain additional axes when using transformations like nn.vmap and nn.scan. In this case, specify the name of the new axis with the metadata_params argument of nn.vmap/nn.scan:

class Model(nn.Module):
  @nn.compact
  def __call__(self, x):
    def body(mdl, c):
      c = MLP(4096)(c)
      return c, ()
    # split_rngs takes booleans: True gives each layer its own init.
    c, _ = nn.scan(
        body, variable_axes={"params": 0}, split_rngs={"params": True}, length=8,
        metadata_params={nn.meta.PARTITION_NAME: "layers"})(self, x)
    return c

flax.linen.with_partitioning(fn, names, mesh=None)

Wraps a function’s return value with Partitioned.

Example:

>>> import flax.linen as nn
>>> kernel_init = nn.with_partitioning(
...     nn.initializers.lecun_normal(), (None, "data"))
>>> partitioned_dense = nn.Dense(features=3, kernel_init=kernel_init)

Parameters
  • fn – The function to be wrapped. Typically this is an initializer.

  • names – The mesh axis names passed to Partitioned.

  • mesh – The mesh to use for the partitioning. If None, the global mesh resource is used if available.

Returns

A function wrapping fn that will return an instance of Partitioned.

flax.linen.get_partition_spec(tree)

Extracts a PartitionSpec tree from a PyTree containing Partitioned values.

flax.linen.get_sharding(tree, mesh)

Extracts a tree of jax.sharding objects from a PyTree containing Partitioned values and a mesh.

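flax.linen.LogicallyPartitioned(value: Any, names: Tuple[Optional[str], ...], mesh: Optional[jax.sharding.Mesh] = None, rules: Optional[Sequence[Tuple[str, Union[str, Tuple[str], None]]]] = None)

Partitioned variant whose names are logical axis names rather than mesh axis names. They are mapped to mesh axes through the logical axis rules: the rules field if set, otherwise the active global rules.
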
flax.linen.logical_axis_rules(rules)

Context manager for setting the logical to mesh axis bindings.

flax.linen.set_logical_axis_rules(rules)

Sets the global logical axis to mesh axis binding.

flax.linen.get_logical_axis_rules()

Returns the global logical axis to mesh axis binding.

flax.linen.logical_to_mesh_axes(array_dim_names, rules=None)

Computes the layout (a PartitionSpec) for an array.

The rules are in order of precedence, and consist of pairs: (ArrayDimensionName, MeshDimensionName), meaning that the given array dimension (if present and unused) should be sharded across the given mesh dimension (if present and unused).

The layout of an array is expressed as a tuple with one element per array dimension. Each element is either None or the name of a mesh dimension, meaning that this array dimension is sharded across that mesh dimension.

For example, given an array with:

array_dim_names = ('batch', 'length', 'heads', 'features')

and the layout rules are:

rules = (('batch', 'X'),
         ('features', 'X'),
         ('heads', 'Y'),
         ('batch', 'Z'))

then this function will return:

PartitionSpec('X', None, 'Y', None)

Parameters
  • array_dim_names – Tuple of array dimension names or None.

  • rules – Optional logical-to-mesh rules override. Defaults to the currently active rules, as set by the logical_axis_rules context manager or set_logical_axis_rules.

Returns

PartitionSpec for the array.

flax.linen.logical_to_mesh(tree, rules=None)

Applies logical_to_mesh_axes to pytrees of logical PartitionSpecs.

flax.linen.logical_to_mesh_sharding(tree, mesh, rules=None)

Converts pytrees of logical PartitionSpecs to shardings on the given mesh.

flax.linen.with_logical_constraint(x, logical_axis_resources, rules=None, mesh=None, fallback=RulesFallback.AXIS_IS_UNSHARDED)

Version of jax.lax.with_sharding_constraint that takes logical axis names, which are resolved to mesh axes with rules (or the active rules if None).

flax.linen.with_logical_partitioning(fn, names, mesh=None, rules=None)

Wraps a function’s return value with LogicallyPartitioned.

Example:

>>> import flax.linen as nn
>>> kernel_init = nn.with_logical_partitioning(
...     nn.initializers.lecun_normal(), (None, "data"))
>>> partitioned_dense = nn.Dense(features=3, kernel_init=kernel_init)

Parameters
  • fn – The function to be wrapped. Typically this is an initializer.

  • names – The logical axis names passed to LogicallyPartitioned.

  • mesh – The mesh to use for the partitioning. If None, the global mesh resource is used if available.

  • rules – Optional logical-to-mesh rules to use. If None, the global rules are used if available.

Returns

A function wrapping fn that will return an instance of LogicallyPartitioned.

Summary

  • Partitioned(value, names[, mesh]) – Wrapper for partitioning metadata.

  • with_partitioning(fn, names[, mesh]) – Wraps a function's return value with Partitioned.

  • get_partition_spec(tree) – Extracts a PartitionSpec tree from a PyTree containing Partitioned values.

  • get_sharding(tree, mesh) – Extracts a jax.sharding tree from a PyTree containing Partitioned values and a mesh.

  • LogicallyPartitioned(value, names[, mesh, rules]) – Partitioned variant whose names are logical axis names.

  • logical_axis_rules(rules) – Context manager for setting the logical-to-mesh axis bindings.

  • set_logical_axis_rules(rules) – Sets the global logical axis to mesh axis binding.

  • get_logical_axis_rules() – Returns the global logical axis to mesh axis binding.

  • logical_to_mesh_axes(array_dim_names[, rules]) – Computes the layout for an array.

  • logical_to_mesh(tree[, rules]) – Applies logical_to_mesh_axes to pytrees of logical PartitionSpecs.

  • logical_to_mesh_sharding(tree, mesh[, rules]) – Converts pytrees of logical PartitionSpecs to shardings.

  • with_logical_constraint(x, ...[, rules, ...]) – Version of jax.lax.with_sharding_constraint that uses logical axis names.

  • with_logical_partitioning(fn, names[, mesh, ...]) – Wraps a function's return value with LogicallyPartitioned.