flax.linen.LogicallyPartitioned

class flax.linen.LogicallyPartitioned(value: Any, names: Tuple[Optional[str], ...], mesh: Optional[jax._src.mesh.Mesh] = None, rules: Optional[Sequence[Tuple[str, Union[str, Tuple[str], NoneType]]]] = None)
__init__(value, names, mesh=None, rules=None)

Methods

__init__(value, names[, mesh, rules])

add_axis(index, params)

Adds a new axis to the axis metadata.

get_partition_spec()

Returns the PartitionSpec for this partitioned value.

get_sharding(mesh)

Returns the NamedSharding for this partitioned value.

remove_axis(index, params)

Removes an axis from the axis metadata.

replace(**updates)

Returns a new object replacing the specified fields with new values.

replace_boxed(val)

Replaces the boxed value with the provided value.

unbox([apply_constraint])

Returns the wrapped value with the partitioning constraint applied.

Attributes

mesh

rules