transforms#

flax.nnx.grad(f=<flax.typing.Missing object>, *, argnums=0, has_aux=False, holomorphic=False, allow_int=False, reduce_axes=())[source]#

Object-aware version of jax.grad that can handle Modules / graph nodes as arguments.

Example:

>>> from flax import nnx
>>> import jax
>>> import jax.numpy as jnp
...
>>> m = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
>>> x = jnp.ones((1, 2))
>>> y = jnp.ones((1, 3))
...
>>> loss_fn = lambda m, x, y: jnp.mean((m(x) - y) ** 2)
>>> grad_fn = nnx.grad(loss_fn)
...
>>> grads = grad_fn(m, x, y)
>>> jax.tree.map(jnp.shape, grads)
State({
  'bias': VariableState(
    type=Param,
    value=(3,)
  ),
  'kernel': VariableState(
    type=Param,
    value=(2, 3)
  )
})

By default, NNX objects are differentiated with respect to all their nnx.Param Variables. You can specify which substates are differentiable by passing a DiffState object to the argnums argument. For example, if you want to differentiate only the kernel attribute of the Linear class, you can use the PathContains filter:

>>> m = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
...
>>> kernel_attribute = nnx.PathContains('kernel')
>>> diff_state = nnx.DiffState(0, kernel_attribute)
...
>>> loss_fn = lambda m, x, y: jnp.mean((m(x) - y) ** 2)
>>> grad_fn = nnx.grad(loss_fn, argnums=diff_state)
...
>>> grads = grad_fn(m, x, y)
>>> jax.tree.map(jnp.shape, grads)
State({
  'kernel': VariableState(
    type=Param,
    value=(2, 3)
  )
})

For more information on how to create custom filters, see the Using Filters guide.

Parameters
  • fun – Function to be differentiated. Its arguments at positions specified by argnums should be arrays, scalars, graph nodes or standard Python containers. Argument arrays in the positions specified by argnums must be of inexact (i.e., floating-point or complex) type. It should return a scalar (which includes arrays with shape () but not arrays with shape (1,) etc.)

  • argnums – Optional, integer or sequence of integers. Specifies which positional argument(s) to differentiate with respect to (default 0).

  • has_aux – Optional, bool. Indicates whether fun returns a pair where the first element is considered the output of the mathematical function to be differentiated and the second element is auxiliary data. Default False (see the example after this parameter list).

  • holomorphic – Optional, bool. Indicates whether fun is promised to be holomorphic. If True, inputs and outputs must be complex. Default False.

  • allow_int – Optional, bool. Whether to allow differentiating with respect to integer valued inputs. The gradient of an integer input will have a trivial vector-space dtype (float0). Default False.
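
For example, when has_aux=True, fun returns a (output, aux) pair and the gradient function returns (grads, aux). A minimal sketch continuing the doctest above, returning the model predictions as the auxiliary value:

loss_fn = lambda m, x, y: (jnp.mean((m(x) - y) ** 2), m(x))
grad_fn = nnx.grad(loss_fn, has_aux=True)
grads, preds = grad_fn(m, x, y)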

flax.nnx.jit(fun=<class 'flax.typing.Missing'>, *, in_shardings=None, out_shardings=None, static_argnums=None, static_argnames=None, donate_argnums=None, donate_argnames=None, keep_unused=False, device=None, backend=None, inline=False, abstracted_axes=None)[source]#

Lifted version of jax.jit that can handle Modules / graph nodes as arguments.
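
For example, a jitted function can take Modules directly as arguments; a minimal sketch:

import jax.numpy as jnp
from flax import nnx

model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))

@nnx.jit
def forward(model, x):
  # `model` is a graph node; nnx.jit takes care of splitting and merging its state
  return model(x)

y = forward(model, jnp.ones((4, 2)))
assert y.shape == (4, 3)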

Parameters
  • fun

    Function to be jitted. fun should be a pure function, as side-effects may only be executed once.

    The arguments and return value of fun should be arrays, scalars, or (nested) standard Python containers (tuple/list/dict) thereof. Positional arguments indicated by static_argnums can be anything at all, provided they are hashable and have an equality operation defined. Static arguments are included as part of a compilation cache key, which is why hash and equality operators must be defined.

    JAX keeps a weak reference to fun for use as a compilation cache key, so the object fun must be weakly-referenceable. Most Callable objects will already satisfy this requirement.

  • in_shardings

    Pytree of structure matching that of arguments to fun, with all actual arguments replaced by resource assignment specifications. It is also valid to specify a pytree prefix (e.g. one value in place of a whole subtree), in which case the leaves get broadcast to all values in that subtree.

    The in_shardings argument is optional. JAX will infer the shardings from the input jax.Arrays and defaults to replicating the input if the sharding cannot be inferred.

    The valid resource assignment specifications are:
    • Sharding, which will decide how the value will be partitioned. With this, using a mesh context manager is not required.

    • None, which will give JAX the freedom to choose whatever sharding it wants. For in_shardings, JAX will mark it as replicated, but this behavior can change in the future. For out_shardings, we will rely on the XLA GSPMD partitioner to determine the output shardings.

    The size of every dimension has to be a multiple of the total number of resources assigned to it. This is similar to pjit’s in_shardings.

  • out_shardings

    Like in_shardings, but specifies resource assignment for function outputs. This is similar to pjit’s out_shardings.

    The out_shardings argument is optional. If not specified, jax.jit() will use GSPMD’s sharding propagation to figure out what the sharding of the output(s) should be.

  • static_argnums

    An optional int or collection of ints that specify which positional arguments to treat as static (compile-time constant). Operations that only depend on static arguments will be constant-folded in Python (during tracing), and so the corresponding argument values can be any Python object.

    Static arguments should be hashable, meaning both __hash__ and __eq__ are implemented, and immutable. Calling the jitted function with different values for these constants will trigger recompilation. Arguments that are not arrays or containers thereof must be marked as static.

    If neither static_argnums nor static_argnames is provided, no arguments are treated as static. If static_argnums is not provided but static_argnames is, or vice versa, JAX uses inspect.signature(fun) to find any positional arguments that correspond to static_argnames (or vice versa). If both static_argnums and static_argnames are provided, inspect.signature is not used, and only actual parameters listed in either static_argnums or static_argnames will be treated as static (see the sketch after the Returns block below).

  • static_argnames – An optional string or collection of strings specifying which named arguments to treat as static (compile-time constant). See the comment on static_argnums for details. If not provided but static_argnums is set, the default is based on calling inspect.signature(fun) to find corresponding named arguments.

  • donate_argnums

    Specify which positional argument buffers are “donated” to the computation. It is safe to donate argument buffers if you no longer need them once the computation has finished. In some cases XLA can make use of donated buffers to reduce the amount of memory needed to perform a computation, for example recycling one of your input buffers to store a result. You should not reuse buffers that you donate to a computation; JAX will raise an error if you try to. By default, no argument buffers are donated.

    If neither donate_argnums nor donate_argnames is provided, no arguments are donated. If donate_argnums is not provided but donate_argnames is, or vice versa, JAX uses inspect.signature(fun) to find any positional arguments that correspond to donate_argnames (or vice versa). If both donate_argnums and donate_argnames are provided, inspect.signature is not used, and only actual parameters listed in either donate_argnums or donate_argnames will be donated.

    For more details on buffer donation see the FAQ.

  • donate_argnames – An optional string or collection of strings specifying which named arguments are donated to the computation. See the comment on donate_argnums for details. If not provided but donate_argnums is set, the default is based on calling inspect.signature(fun) to find corresponding named arguments.

  • keep_unused – If False (the default), arguments that JAX determines to be unused by fun may be dropped from resulting compiled XLA executables. Such arguments will not be transferred to the device nor provided to the underlying executable. If True, unused arguments will not be pruned.

  • device – This is an experimental feature and the API is likely to change. Optional, the Device the jitted function will run on. (Available devices can be retrieved via jax.devices().) The default is inherited from XLA’s DeviceAssignment logic and is usually to use jax.devices()[0].

  • backend – This is an experimental feature and the API is likely to change. Optional, a string representing the XLA backend: 'cpu', 'gpu', or 'tpu'.

  • inline – Specify whether this function should be inlined into enclosing jaxprs (rather than being represented as an application of the xla_call primitive with its own subjaxpr). Default False.

Returns

A wrapped version of fun, set up for just-in-time compilation.
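
To illustrate the static arguments described above: marking an argument as static lets it be used as a plain Python value at trace time, at the cost of recompilation whenever its value changes. A minimal sketch (the layer shape and step count are arbitrary):

import jax.numpy as jnp
from flax import nnx

model = nnx.Linear(3, 3, rngs=nnx.Rngs(0))

@nnx.jit(static_argnums=(2,))
def repeat_forward(model, x, n_steps: int):
  # `n_steps` is a compile-time constant, so this Python loop is unrolled during tracing;
  # calling with a different `n_steps` triggers recompilation.
  for _ in range(n_steps):
    x = model(x)
  return x

y = repeat_forward(model, jnp.ones((4, 3)), 3)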

flax.nnx.shard_map(f=<class 'flax.typing.Missing'>, *, mesh, in_specs, out_specs, check_rep=True, auto=frozenset({}))[source]#

Lifted version of jax.experimental.shard_map.shard_map that can handle Modules / graph nodes as arguments.

Simple data parallel example:

import jax
import jax.numpy as jnp
from flax import nnx
from jax.sharding import PartitionSpec as P

mesh = jax.sharding.Mesh(jax.local_devices(), ('data',))

m = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
x = jnp.ones((32, 2))

@nnx.shard_map(
  mesh=mesh, in_specs=(P(None), P('data')), out_specs=P('data')
)
def f(m, x):
  return m(x)

y = f(m, x)

jax.debug.visualize_array_sharding(y)

Notice that here we simply used some PartitionSpec to define the spec for the whole model and data. This works for simple cases, but if we need to assign a different PartitionSpec to different parts of the model, we need to use StateSharding and create some filters that allow us to target specific parts of the model. Here’s an example of how to do tensor parallelism for a simple MLP block using StateSharding and filters:

mesh = jax.sharding.Mesh(jax.local_devices(), ('model',))

class MLP(nnx.Module):
  def __init__(self, din, dhidden, dout, *, rngs: nnx.Rngs):
    self.linear1 = nnx.Linear(din, dhidden, use_bias=False, rngs=rngs)
    self.linear2 = nnx.Linear(dhidden, dout, use_bias=False, rngs=rngs)

  def __call__(self, x):
    return self.linear2(jax.nn.relu(self.linear1(x)))

m = MLP(2, 64, 3, rngs=nnx.Rngs(0))
x = jnp.ones((32, 2))

def path_ends_with(*path_suffix): # custom filter
  return lambda path, value: path[-len(path_suffix):] == path_suffix

model_spec = nnx.StateSharding({
  path_ends_with('linear1', 'kernel'): P(None, 'model'),
  path_ends_with('linear2', 'kernel'): P('model', None),
})

@nnx.shard_map(mesh=mesh, in_specs=(model_spec, P(None)), out_specs=P(None))
def f(m, x):
  y = m(x)
  return jax.lax.psum(y, 'model')

y = f(m, x)

jax.debug.visualize_array_sharding(m.linear1.kernel.value)
jax.debug.visualize_array_sharding(m.linear2.kernel.value)

Alternatively, a State object containing the exact PartitionSpec for each state can be passed to StateSharding:

mesh = jax.sharding.Mesh(jax.local_devices(), ('model',))

class MLP(nnx.Module):
  def __init__(self, din, dhidden, dout, *, rngs: nnx.Rngs):
    self.linear1 = nnx.Linear(din, dhidden, use_bias=False, rngs=rngs)
    self.linear2 = nnx.Linear(dhidden, dout, use_bias=False, rngs=rngs)

  def __call__(self, x):
    return self.linear2(jax.nn.relu(self.linear1(x)))

m = MLP(2, 64, 3, rngs=nnx.Rngs(0))
x = jnp.ones((32, 2))

model_spec = nnx.State(
  {
    'linear1': {'kernel': P(None, 'model')},
    'linear2': {'kernel': P('model', None)},
  }
)

@nnx.shard_map(
  mesh=mesh,
  in_specs=(nnx.StateSharding(model_spec), P(None)),
  out_specs=P(None),
)
def f(m, x):
  y = m(x)
  return jax.lax.psum(y, 'model')

y = f(m, x)

jax.debug.visualize_array_sharding(m.linear1.kernel.value)
jax.debug.visualize_array_sharding(m.linear2.kernel.value)

Here model_spec was created manually, but you can also use nnx.get_partition_spec to create it for you automatically (see the Scale up on multiple devices guide).
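
A minimal sketch of that flow, assuming the kernels are annotated with nnx.with_partitioning so that nnx.get_partition_spec can read the sharding metadata (ShardedMLP is a hypothetical variant of the MLP above):

class ShardedMLP(nnx.Module):
  def __init__(self, din, dhidden, dout, *, rngs: nnx.Rngs):
    init = nnx.initializers.lecun_normal()
    # attach sharding metadata to each kernel
    self.linear1 = nnx.Linear(
      din, dhidden, use_bias=False,
      kernel_init=nnx.with_partitioning(init, (None, 'model')), rngs=rngs)
    self.linear2 = nnx.Linear(
      dhidden, dout, use_bias=False,
      kernel_init=nnx.with_partitioning(init, ('model', None)), rngs=rngs)

  def __call__(self, x):
    return self.linear2(jax.nn.relu(self.linear1(x)))

m = ShardedMLP(2, 64, 3, rngs=nnx.Rngs(0))
state_spec = nnx.get_partition_spec(nnx.state(m))  # State of PartitionSpecs mirroring the model state
model_spec = nnx.StateSharding(state_spec)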

Parameters
  • f – callable to be mapped. Each application of f, or “instance” of f, takes as input a shard of the mapped-over arguments and produces a shard of the output.

  • mesh – a jax.sharding.Mesh representing the array of devices over which to shard the data and on which to execute instances of f. The names of the Mesh can be used in collective communication operations in f. This is typically created by a utility function like jax.experimental.mesh_utils.create_device_mesh().

  • in_specs – a pytree with jax.sharding.PartitionSpec or nnx.StateSharding (mapping substates to PartitionSpecs) instances as leaves, with a tree structure that is a tree prefix of the args tuple to be mapped over. Similar to jax.sharding.NamedSharding, each PartitionSpec represents how the corresponding argument (or subtree of arguments) should be sharded along the named axes of mesh. In each PartitionSpec, mentioning a mesh axis name at a position expresses sharding the corresponding argument array axis along that positional axis; not mentioning an axis name expresses replication. If an argument, or argument subtree, has a corresponding spec of None, that argument is not sharded.

  • out_specs – a pytree with jax.sharding.PartitionSpec or nnx.StateSharding (mapping substates to PartitionSpecs) instances as leaves, with a tree structure that is a tree prefix of the output of f. Each PartitionSpec represents how the corresponding output shards should be concatenated. In each PartitionSpec, mentioning a mesh axis name at a position expresses concatenation of that mesh axis’s shards along the corresponding positional axis. Not mentioning a mesh axis name expresses a promise that the output values are equal along that mesh axis, and that rather than concatenating, only a single value should be produced.

  • check_rep – If True (default) enable additional validity checks and automatic differentiation optimizations. The validity checks concern whether any mesh axis names not mentioned in out_specs are consistent with how the outputs of f are replicated. Must be set False if using a Pallas kernel in f.

  • auto – (experimental) an optional set of axis names from mesh over which we do not shard the data or map the function, but rather we allow the compiler to control sharding. These names cannot be used in in_specs, out_specs, or in communication collectives in f.

Returns

A callable that applies the input function f across data sharded according to the mesh and in_specs.

flax.nnx.remat(f=<flax.typing.Missing object>, *, prevent_cse=True, static_argnums=(), policy=None)[source]#
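
nnx.remat is the lifted counterpart of jax.checkpoint (jax.remat): intermediate activations inside the wrapped function are recomputed during the backward pass instead of being stored. A minimal sketch, assuming the default arguments:

import jax
import jax.numpy as jnp
from flax import nnx

model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))

@nnx.remat
def forward(model, x):
  # activations computed inside `forward` are rematerialized on the backward pass
  return jax.nn.relu(model(x))

loss_fn = lambda m, x: jnp.mean(forward(m, x) ** 2)
grads = nnx.grad(loss_fn)(model, jnp.ones((4, 2)))
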
flax.nnx.scan(f=<class 'flax.typing.Missing'>, *, length=None, reverse=False, unroll=1, _split_transpose=False, in_axes=(<class 'flax.nnx.transforms.iteration.Carry'>, 0), out_axes=(<class 'flax.nnx.transforms.iteration.Carry'>, 0), transform_metadata=FrozenDict({}))[source]#
flax.nnx.value_and_grad(f=<class 'flax.typing.Missing'>, *, argnums=0, has_aux=False, holomorphic=False, allow_int=False, reduce_axes=())[source]#
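
nnx.value_and_grad mirrors jax.value_and_grad while accepting Modules / graph nodes, returning the function value together with the gradients. A minimal sketch reusing the loss from the nnx.grad example above:

import jax.numpy as jnp
from flax import nnx

m = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
x, y = jnp.ones((1, 2)), jnp.ones((1, 3))

loss_fn = lambda m, x, y: jnp.mean((m(x) - y) ** 2)
loss, grads = nnx.value_and_grad(loss_fn)(m, x, y)  # scalar loss plus a State of gradients
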
flax.nnx.vmap(f=<class 'flax.typing.Missing'>, *, in_axes=0, out_axes=0, axis_name=None, axis_size=None, spmd_axis_name=None, transform_metadata=FrozenDict({}))[source]#

Reference-aware version of jax.vmap.

Parameters
  • f – Function to be mapped over additional axes.

  • in_axes – An integer, None, or sequence of values specifying which input array axes to map over (see jax.vmap). In addition to integers and None, StateAxes can be used to control how graph nodes like Modules are vectorized by specifying the axes to be applied to substates of the graph node given a Filter.

  • out_axes – An integer, None, or pytree indicating where the mapped axis should appear in the output (see jax.vmap).

  • axis_name – Optional, a hashable Python object used to identify the mapped axis so that parallel collectives can be applied.

  • axis_size – Optional, an integer indicating the size of the axis to be mapped. If not provided, the mapped axis size is inferred from arguments.

Returns

Batched/vectorized version of f with arguments that correspond to those of f, but with extra array axes at positions indicated by in_axes, and a return value that corresponds to that of f, but with extra array axes at positions indicated by out_axes.

Example:

>>> from flax import nnx
>>> import jax
>>> from jax import random, numpy as jnp
...
>>> model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
>>> x = jnp.ones((5, 2))
...
>>> @nnx.vmap(in_axes=(None, 0), out_axes=0)
... def forward(model, x):
...   return model(x)
...
>>> y = forward(model, x)
>>> y.shape
(5, 3)
>>> class LinearEnsemble(nnx.Module):
...   def __init__(self, num, rngs):
...     self.w = nnx.Param(jax.random.uniform(rngs(), (num, 2, 3)))
...
>>> model = LinearEnsemble(5, rngs=nnx.Rngs(0))
>>> x = jnp.ones((2,))
...
>>> @nnx.vmap(in_axes=(0, None), out_axes=0)
... def forward(model, x):
...   return jnp.dot(x, model.w.value)
...
>>> y = forward(model, x)
>>> y.shape
(5, 3)

To control how graph node substates are vectorized, StateAxes can be passed to in_axes and out_axes to specify the axes to be applied to each substate given a filter. This can be used, for example, to share parameters between ensemble members while keeping separate batch statistics and dropout random state. In the following example, the Param substate is mapped over axis 0 while the BatchStat substate is broadcast:

>>> class Foo(nnx.Module):
...   def __init__(self):
...     self.a = nnx.Param(jnp.arange(4))
...     self.b = nnx.BatchStat(jnp.arange(4))
...
>>> state_axes = nnx.StateAxes({nnx.Param: 0, nnx.BatchStat: None})
>>> @nnx.vmap(in_axes=(state_axes,), out_axes=0)
... def mul(foo):
...   return foo.a * foo.b
...
>>> foo = Foo()
>>> y = mul(foo)
>>> y
Array([[0, 0, 0, 0],
       [0, 1, 2, 3],
       [0, 2, 4, 6],
       [0, 3, 6, 9]], dtype=int32)
flax.nnx.eval_shape(f, *args, **kwargs)[source]#

A “lifted” version of jax.eval_shape that can handle flax.nnx.Module / graph nodes as arguments.

Similar to jax.eval_shape, it computes the shape/dtype of a function f without performing any floating point operations (FLOPs), which can be expensive. This can be useful for performing shape inference, for example.
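
For example, an abstract model can be constructed without allocating any parameter memory; a minimal sketch (the parameters come back as jax.ShapeDtypeStruct placeholders):

from flax import nnx

# build the model abstractly: no parameters are actually initialized
abstract_model = nnx.eval_shape(lambda: nnx.Linear(2, 3, rngs=nnx.Rngs(0)))
print(abstract_model.kernel.value)  # ShapeDtypeStruct(shape=(2, 3), dtype=float32)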

flax.nnx.custom_vjp(fun=<flax.typing.Missing object>, *, nondiff_argnums=())[source]#

Reference-aware version of jax.custom_vjp.

nnx.custom_vjp accepts Modules and other Flax NNX objects as arguments. The main difference from the JAX version is that, because Modules follow reference semantics, they propagate the State updates for the inputs as auxiliary outputs. This means that the incoming gradients in the bwd function will have the form (input_updates_g, out_g), where input_updates_g is the gradient with respect to the updated state of the inputs. All Module terms in the inputs will have an associated State term in input_updates_g, while all non-Module terms will appear as None. The tangent is expected to have the same structure as the input, with State terms in place of the corresponding Module terms.

Example:

>>> import jax
>>> import jax.numpy as jnp
>>> from flax import nnx
...
>>> class Foo(nnx.Module):
...   def __init__(self, x, y):
...     self.x = nnx.Param(x)
...     self.y = nnx.Param(y)
...
>>> @nnx.custom_vjp
... def f(m: Foo):
...   return jnp.sin(m.x) * m.y
...
>>> def f_fwd(m: Foo):
...   return f(m), (jnp.cos(m.x), jnp.sin(m.x), m)
...
>>> def f_bwd(res, g):
...   input_updates_g, out_g = g
...   cos_x, sin_x, m = res
...   (m_updates_g,) = input_updates_g
...   m_g = jax.tree.map(lambda x: x, m_updates_g) # create copy
...
...   m_g['x'].value = cos_x * out_g * m.y
...   m_g['y'].value = sin_x * out_g
...   return (m_g,)
...
>>> f.defvjp(f_fwd, f_bwd)
...
>>> m = Foo(x=jnp.array(1.), y=jnp.array(2.))
>>> grads = nnx.grad(f)(m)
...
>>> jax.tree.map(jnp.shape, grads)
State({
  'x': VariableState(
    type=Param,
    value=()
  ),
  'y': VariableState(
    type=Param,
    value=()
  )
})

Note that the State objects that represent Module terms in input_updates_g have the same shape as the State objects expected in the output tangent. This means that you can usually just copy them from input_updates_g and update them with their corresponding gradient values.

You can select which substates are differentiable (have a tangent) for Modules and other graph nodes by passing a DiffState to nondiff_argnums. For example, if you want to differentiate only the x attribute of the Foo class, you can do the following:

>>> x_attribute = nnx.PathContains('x')
>>> diff_state = nnx.DiffState(0, x_attribute)
...
>>> @nnx.custom_vjp(nondiff_argnums=(diff_state,))
... def f(m: Foo):
...   return jnp.sin(m.x) * m.y  # type: ignore

>>> def f_fwd(m: Foo):
...   y = f(m)
...   res = (jnp.cos(m.x), m)  # type: ignore
...   return y, res
...
>>> def f_bwd(res, g):
...   input_updates_g, out_g = g
...   cos_x, m = res
...   (m_updates_g,) = input_updates_g
...   m_g = jax.tree.map(lambda x: x, m_updates_g) # create copy
...
...   m_g.x.value = cos_x * out_g * m.y
...   del m_g['y'] # y is not differentiable
...   return (m_g,)

>>> f.defvjp(f_fwd, f_bwd)
...
>>> m = Foo(x=jnp.array(1.), y=jnp.array(2.))
>>> grad = nnx.grad(f, argnums=nnx.DiffState(0, x_attribute))(m)
...
>>> jax.tree.map(jnp.shape, grad)
State({
  'x': VariableState(
    type=Param,
    value=()
  )
})

Note that grad cannot calculate gradients for states that don’t have a tangent defined by custom_vjp; in the example above we reuse the same x_attribute filter to keep custom_vjp and grad in sync.

Parameters
  • fun – Callable base function.

  • nondiff_argnums – Tuple of integers or DiffState objects specifying the argument indices that are not differentiated. By default all arguments are differentiated. Integers cannot be used to mark graph nodes such as Modules as non-differentiable; in this case, use a DiffState object. Contrary to what the name of this argument suggests, DiffState objects define the set of differentiable substates; this is done for compatibility with grad.

flax.nnx.cond(pred, true_fun, false_fun, *operands, **kwargs)[source]#
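
nnx.cond follows the jax.lax.cond calling convention while accepting Modules among the operands. A minimal sketch (both branches should return outputs of the same structure and make the same kind of state updates):

import jax.numpy as jnp
from flax import nnx

model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
x = jnp.ones((4, 2))

# the predicate may be a traced boolean; `model` and `x` are passed as operands
y = nnx.cond(
  x.sum() > 0,
  lambda m, x: m(x),    # true branch
  lambda m, x: -m(x),   # false branch
  model, x,
)
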
flax.nnx.switch(index, branches, *operands)[source]#
flax.nnx.while_loop(cond_fun, body_fun, init_val)[source]#

A Flax NNX transformation of jax.lax.while_loop.

Caution: for the NNX internal reference tracing mechanism to work, you cannot change the variable reference structure of init_val inside body_fun.

Example:

>>> import jax
>>> from flax import nnx
>>> def fwd_fn(input):
...   module, x, count = input
...   return module, module(x), count - 1.0

>>> module = nnx.Linear(10, 10, rngs=nnx.Rngs(0))
>>> x = jax.random.normal(jax.random.key(0), (10,))
>>> # `module` will be called three times
>>> _, y, _ = nnx.while_loop(
...   lambda input: input[-1] > 0, fwd_fn, (module, x, 3.0))
Parameters
  • cond_fun – A function for the continue condition of the while loop, taking a single input of type T and outputting a boolean.

  • body_fun – A function that takes an input of type T and outputs a T. Note that both data and modules of T must have the same reference structure between inputs and outputs.

  • init_val – The initial input for cond_fun and body_fun. Must be of type T.

flax.nnx.fori_loop(lower, upper, body_fun, init_val, *, unroll=None)[source]#

A Flax NNX transformation of jax.lax.fori_loop.

Caution: for the NNX internal reference tracing mechanism to work, you cannot change the variable reference structure of init_val inside body_fun.

Example:

>>> import jax
>>> import jax.numpy as jnp
>>> import numpy as np
>>> from flax import nnx

>>> def fwd_fn(i, input):
...   m, x = input
...   m.kernel.value = jnp.identity(10) * i
...   return m, m(x)

>>> module = nnx.Linear(10, 10, rngs=nnx.Rngs(0))
>>> x = jax.random.normal(jax.random.key(0), (10,))
>>> _, y = nnx.fori_loop(2, 4, fwd_fn, (module, x))
>>> np.testing.assert_array_equal(y, x * 2 * 3)
Parameters
  • lower – An integer representing the loop index lower bound (inclusive).

  • upper – An integer representing the loop index upper bound (exclusive).

  • body_fun – A function that takes an input of type T and outputs a T. Note that both data and modules of T must have the same reference structure between inputs and outputs.

  • init_val – the initial input for body_fun. Must be of type T.

  • unroll – An optional integer or boolean that determines how much to unroll the loop. If an integer is provided, it determines how many unrolled loop iterations to run within a single rolled iteration of the loop. If a boolean is provided, it will determine if the loop is completely unrolled (i.e. unroll=True) or left completely rolled (i.e. unroll=False). This argument is only applicable if the loop bounds are statically known.

Returns

A loop value from the final iteration, of type T.