flax.linen.logical_to_mesh_axes
- flax.linen.logical_to_mesh_axes(array_dim_names, rules=None)
Compute layout for an array.
The rules are listed in order of precedence. Each rule is a pair (ArrayDimensionName, MeshDimensionName), meaning that the given array dimension (if present and not yet assigned) should be sharded across the given mesh dimension (if present and not yet used).
The layout of an array is expressed as a tuple with one element for each dimension of the array. Each element is either None or the name of a mesh dimension, meaning that this dimension of the array is sharded across that dimension of the mesh.
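- For example, given an array with
array_dim_names = ('batch', 'length', 'heads', 'features')
- and the layout rules
rules = (('batch', 'X'), ('features', 'X'), ('heads', 'Y'), ('batch', 'Z'))
this function returns
PartitionSpec('X', None, 'Y', None)
Here 'length' has no rule, 'features' stays unsharded because 'X' is already used by 'batch', and the ('batch', 'Z') rule is skipped because 'batch' is already sharded.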
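The precedence behaviour described above can be sketched in plain Python. This is an illustrative reimplementation under stated assumptions, not the library's actual code: the helper name `resolve_layout` is hypothetical, it returns a plain tuple instead of a PartitionSpec, and it assumes every rule names a single mesh dimension as a string.

```python
def resolve_layout(array_dim_names, rules):
    """Hypothetical sketch: map array dimension names to mesh dimension names."""
    # Start with every array dimension unsharded.
    layout = [None] * len(array_dim_names)
    # Mesh dimensions that have already been assigned to some array dimension.
    used_mesh_dims = set()

    # Rules are visited in order, so earlier rules take precedence.
    for array_dim, mesh_dim in rules:
        if array_dim not in array_dim_names:
            continue  # The rule refers to a dimension this array does not have.
        pos = array_dim_names.index(array_dim)
        # Apply the rule only if both the array dimension and the
        # mesh dimension are still unused.
        if layout[pos] is None and mesh_dim not in used_mesh_dims:
            layout[pos] = mesh_dim
            used_mesh_dims.add(mesh_dim)

    return tuple(layout)


array_dim_names = ("batch", "length", "heads", "features")
rules = (("batch", "X"), ("features", "X"), ("heads", "Y"), ("batch", "Z"))
layout = resolve_layout(array_dim_names, rules)
print(layout)
```

With the example inputs, the sketch yields the same per-dimension assignment as the PartitionSpec shown above: ('X', None, 'Y', None).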
- Parameters
array_dim_names – Tuple of array dimension names or None.
rules – Optional override for the logical-to-mesh rules. Defaults to the rules defined in the dynamic context set by the axis_rules function.
- Returns
PartitionSpec for the parameter.