flax.linen.Partitioned

class flax.linen.Partitioned(value, names, mesh=None)

Wrapper for partitioning metadata.

Partitioned is used to extend variables with partitioning information required for jax.experimental.pjit.

The easiest way to define Partitioned variables is to wrap the variable initializer with nn.with_partitioning.

Example:

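import jax
import jax.numpy as jnp
import numpy as np
import flax.linen as nn
from jax import random
from jax.experimental.pjit import pjit
from jax.sharding import Mesh, PartitionSpec

class MLP(nn.Module):
  hidden_size: int

  @nn.compact
  def __call__(self, x):
    ki = nn.linear.default_kernel_init
    h = nn.Dense(
        self.hidden_size,
        kernel_init=nn.with_partitioning(ki, ('data', 'model')))(x)
    h = nn.relu(h)
    return nn.Dense(
        x.shape[-1],
        kernel_init=nn.with_partitioning(ki, ('model', 'data')))(h)

# A 2D device mesh whose axis names match the names used above
# (all devices are placed on 'data'; 'model' has size 1).
mesh = Mesh(np.array(jax.devices()).reshape(-1, 1), ('data', 'model'))
mlp = MLP(4096)
x = jnp.ones((8 * 1024, 1024))
# Use eval_shape to get the Partitioned instances for the variables.
# This way we can determine the PartitionSpecs for the init variables
# before we call the init fn.
var_spec = nn.get_partition_spec(
    jax.eval_shape(mlp.init, random.key(0), x))
# PartitionSpecs passed to pjit are resolved against the active mesh.
with mesh:
  init_fn = pjit(mlp.init,
                 in_shardings=(None, PartitionSpec('data', 'model')),
                 out_shardings=var_spec)
  variables = init_fn(random.key(0), x)
  apply_fn = pjit(
      mlp.apply,
      in_shardings=(var_spec, PartitionSpec('data', 'model')),
      out_shardings=PartitionSpec('data', 'model'))
  y = apply_fn(variables, x)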

Partitioned values can gain additional axes under transformations such as nn.vmap and nn.scan. In that case, you can name the new axis with the metadata_params argument of nn.vmap or nn.scan:

class Model(nn.Module):
  @nn.compact
  def __call__(self, x):
    def body(mdl, c):
      c = MLP(4096)(c)
      return c, ()
    # split_rngs takes booleans: True gives each of the 8 layers its own init key.
    c, _ = nn.scan(
        body, variable_axes={"params": 0}, split_rngs={"params": True}, length=8,
        metadata_params={nn.meta.PARTITION_NAME: "layers"})(self, x)
    return c

__init__(value, names, mesh=None)

Methods

__init__(value, names[, mesh])

add_axis(index, params)

Adds a new axis to the axis metadata.

get_partition_spec()

Returns the PartitionSpec for this partitioned value.

get_sharding(mesh)

Returns the NamedSharding for this partitioned value.

remove_axis(index, params)

Removes an axis from the axis metadata.

replace(**updates)

Returns a new object replacing the specified fields with new values.

replace_boxed(val)

Replaces the boxed value with the provided value.

unbox([apply_constraint])

Returns the wrapped value with the partitioning applied as a sharding constraint.

Attributes

mesh

value

names