flax.linen.LogicallyPartitioned
- class flax.linen.LogicallyPartitioned(value: Any, names: Tuple[Optional[str], ...], mesh: Optional[jax._src.mesh.Mesh] = None, rules: Optional[Sequence[Tuple[str, Union[str, Tuple[str], NoneType]]]] = None)
- __init__(value, names, mesh=None, rules=None)
Methods
- __init__(value, names[, mesh, rules])
- add_axis(index, params): Adds a new axis to the axis metadata.
- get_partition_spec(): Returns the PartitionSpec for this partitioned value.
- get_sharding(mesh): Returns the NamedSharding for this partitioned value.
- remove_axis(index, params): Removes an axis from the axis metadata.
- replace(**updates): Returns a new object replacing the specified fields with new values.
- replace_boxed(val): Replaces the boxed value with the provided value.
- unbox([apply_constraint]): Returns the wrapped value with the partitioning constraint applied.
Attributes
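- mesh: the optional mesh passed at construction (default None).
- rules: the optional logical axis rules, a sequence of (logical name, mesh axis or axes) pairs (default None).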