flax.linen.Partitioned
- class flax.linen.Partitioned(value, names, mesh=None)
Wrapper for partitioning metadata.

Partitioned is used to extend variables with partitioning information required for jax.experimental.pjit.

The easiest way to define Partitioned variables is by using the with_partitioning wrapper around the variable initializer.

Example:

    import jax
    import jax.numpy as jnp
    from jax import random
    from jax.experimental.pjit import pjit
    from jax.sharding import PartitionSpec
    import flax.linen as nn

    class MLP(nn.Module):
      hidden_size: int

      @nn.compact
      def __call__(self, x):
        ki = nn.linear.default_kernel_init
        h = nn.Dense(
            self.hidden_size,
            kernel_init=nn.with_partitioning(ki, ('data', 'model')))(x)
        h = nn.relu(h)
        return nn.Dense(
            x.shape[-1],
            kernel_init=nn.with_partitioning(ki, ('model', 'data')))(h)

    mlp = MLP(4096)
    x = jnp.ones((8 * 1024, 1024))
    # Use eval_shape to get the Partitioned instances for the variables.
    # This way we can determine the PartitionSpecs for the init variables
    # before we call the init fn.
    var_spec = nn.get_partition_spec(
        jax.eval_shape(mlp.init, random.key(0), x))
    # `mesh` is assumed to be an existing jax.sharding.Mesh with axes
    # named 'data' and 'model'; it is not defined in this snippet.
    init_fn = mesh(pjit(mlp.init,
                        (None, PartitionSpec("data", "model")), var_spec))
    variables = init_fn(random.key(0), x)
    apply_fn = mesh(pjit(
        mlp.apply,
        (var_spec, PartitionSpec("data", "model")),
        PartitionSpec("data", "model")))
    apply_fn(variables, x)

Partitioned values can gain additional axes when using transformations like nn.vmap and nn.scan. In this case you can specify the name of the new axis with the metadata_params args in vmap/scan:

    class Model(nn.Module):
      @nn.compact
      def __call__(self, x):
        def body(mdl, c):
          c = MLP(4096)(c)
          return c, ()
        c, _ = nn.scan(
            body, variable_axes={"params": 0}, split_rngs={"params": 0},
            length=8,
            metadata_params={nn.meta.PARTITION_NAME: "layers"})(self, x)
        return c

- __init__(value, names, mesh=None)
Methods
__init__(value, names[, mesh])
add_axis(index, params): Adds a new axis to the axis metadata.
get_partition_spec(): Returns the PartitionSpec for this partitioned value.
get_sharding(mesh): Returns the NamedSharding for this partitioned value.
remove_axis(index, params): Removes an axis from the axis metadata.
replace(**updates): Returns a new object replacing the specified fields with new values.
replace_boxed(val): Replaces the boxed value with the provided value.
unbox([apply_constraint]): Returns the wrapped value with the partitioning applied as a sharding constraint.
Attributes
mesh
value
names