Start Date: 2025-09-12
FLIP PR: #4844
FLIP 4844: Variable eager sharding#
Summary#
Simplify the creation of sharded NNX models. When a sharding annotation is provided, all nnx.Variable creation will require a mesh context and automatically be sharded as annotated.
See GSPMD Guide for a comprehensive guide on how to make sharded NNX models.
Motivation#
To create a sharded model, user should only need to do this:
mesh = jax.make_mesh(((2, 4)), ("data", "model"))
with jax.set_mesh(mesh):
model = YourModelWithShardingAnnotations()
Instead of the current boilerplate combo of nnx.jit, nnx.get_partition_spec, with_sharding_constraint and nnx.update:
@nnx.jit
def create_sharded_model():
model = YourModelWithShardingAnnotations() # Unsharded at this moment.
state = nnx.state(model) # The model's state, a pure pytree.
pspecs = nnx.get_partition_spec(state) # Strip out the annotations from state.
sharded_state = jax.lax.with_sharding_constraint(state, pspecs)
nnx.update(model, sharded_state) # The model is sharded now!
return model
mesh = jax.make_mesh(((2, 4)), ("data", "model"))
with jax.set_mesh(mesh):
sharded_model = create_sharded_model()
Backward compatibility#
User can turn off this feature in two ways:
Global config flag: Run
flax.config.update('flax_always_shard_variable', False)before running any NNX model initialization.Variable-specific flag: Create a specific variable with metadata
eager_sharding=False, such as:nnx.Param(..., eager_sharding=False).
Flexibility options#
For debugging in a CPU environment, make a dummy mesh to run the model:
mesh = jax.make_mesh(((1, 1, 1)), ('your', 'axes', 'names'))
with jax.set_mesh(mesh):
...
For JAX explicit mode, remove the out_sharding= annotation on the nnx.Variable.
Implementation#
When an nnx.Variable is created, check for the metadata out_sharding, and if present, check if under a valid global mesh context of was supplied with a valid mesh. If no, throw error; if yes, call jax.lax.with_sharding_constraint to apply sharding constraint on the value.
Note that this only works in auto sharding mode. User should use JAX-level APIs to annotate shardings for explicit mode.