flax.linen.LogicallyPartitioned

class flax.linen.LogicallyPartitioned(value: Any, names: Tuple[Union[str, NoneType], ...])
__init__(value, names)

Methods

__init__(value, names)

add_axis(index, params)

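Adds a new axis to the axis metadata. The axis name stored under the PARTITION_NAME key of params is inserted into names at position index.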

get_partition_spec()

Returns the PartitionSpec for this partitioned value.

remove_axis(index, params)

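Removes the axis at position index from the axis metadata. The removed name must match the one stored under the PARTITION_NAME key of params.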

replace(**updates)

Returns a new object replacing the specified fields with new values.

replace_boxed(val)

Replaces the boxed value with the provided value.

unbox([apply_constraint])

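Returns the wrapped value with the partitioning constraint applied. The constraint is applied only when apply_constraint is true and a device mesh is defined; otherwise the raw value is returned.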

Attributes