transforms
- class flax.experimental.nnx.Remat(module_constructor, prevent_cse=True, static_argnums=(), policy=None)
- class flax.experimental.nnx.Scan(module_constructor, *, length=None, reverse=False, unroll=1, _split_transpose=False, in_axes=0, in_axes_kwargs=0, out_axes=0, carry_argnum=1, state_axes=FrozenDict({Ellipsis: 0}), split_rngs=Ellipsis, transform_metadata=FrozenDict({}), scan_output=True)
- class flax.experimental.nnx.Vmap(module_constructor, *, in_axes=0, out_axes=0, axis_name=None, axis_size=None, spmd_axis_name=None, in_axes_kwargs=0, state_axes=FrozenDict({Ellipsis: 0}), split_rngs=Ellipsis, transform_metadata=FrozenDict({}))
- flax.experimental.nnx.grad(f, argnums=0, has_aux=False, holomorphic=False, allow_int=False, reduce_axes=(), *, wrt=<class 'flax.experimental.nnx.nnx.variables.Param'>)
Lifted version of jax.grad that can handle Modules / graph nodes as arguments. The differentiable state of each graph node is defined by the wrt filter, which by default is set to nnx.Param. Internally, the State of each graph node is extracted, filtered according to the wrt filter, and passed to the underlying jax.grad function. The gradients of graph nodes are of type State.
Example:
>>> import jax
>>> import jax.numpy as jnp
>>> from flax.experimental import nnx
>>> m = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
>>> x = jnp.ones((1, 2))
>>> y = jnp.ones((1, 3))
>>> loss_fn = lambda m, x, y: jnp.mean((m(x) - y) ** 2)
>>> grad_fn = nnx.grad(loss_fn, wrt=nnx.Param)
>>> grads = grad_fn(m, x, y)
>>> jax.tree_util.tree_map(jnp.shape, grads)
State({
  'bias': VariableState(
    type=Param,
    value=(3,)
  ),
  'kernel': VariableState(
    type=Param,
    value=(2, 3)
  )
})
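The extract, filter, and differentiate pipeline described above can be sketched in plain JAX. This is an illustration only, not the nnx implementation: the ParamLike and StatLike variable types and the split_state and lifted_grad helpers are hypothetical, the "module" is reduced to a flat dict of variables, and the wrt filter is reduced to an isinstance check on the variable type.

```python
from dataclasses import dataclass

import jax
import jax.numpy as jnp


@dataclass
class ParamLike:
    """Hypothetical stand-in for a trainable variable type (not nnx.Param)."""
    value: jax.Array


@dataclass
class StatLike:
    """Hypothetical stand-in for a non-trainable variable type."""
    value: jax.Array


def split_state(variables, wrt):
    """Partition {name: variable} into (selected, rest) dicts of raw arrays."""
    selected = {k: v.value for k, v in variables.items() if isinstance(v, wrt)}
    rest = {k: v.value for k, v in variables.items() if not isinstance(v, wrt)}
    return selected, rest


def lifted_grad(loss_fn, wrt=ParamLike):
    """Differentiate loss_fn(state, *args) only w.r.t. variables matching wrt."""

    def grad_fn(variables, *args):
        selected, rest = split_state(variables, wrt)

        def pure_loss(selected_state):
            # Merge the differentiable arrays back with the constant ones.
            return loss_fn({**rest, **selected_state}, *args)

        return jax.grad(pure_loss)(selected)

    return grad_fn


def loss_fn(state, x, y):
    pred = x @ state["kernel"] + state["bias"]
    return jnp.mean((pred - y) ** 2)


variables = {
    "kernel": ParamLike(jnp.ones((2, 3))),
    "bias": ParamLike(jnp.zeros((3,))),
    "count": StatLike(jnp.array(0.0)),  # excluded by the wrt filter
}
grads = lifted_grad(loss_fn)(variables, jnp.ones((1, 2)), jnp.ones((1, 3)))
print(jax.tree_util.tree_map(jnp.shape, grads))
```

As in the nnx example, only kernel and bias receive gradients; count is held constant because it does not match the filter.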
- Parameters
f – Function to be differentiated. Its arguments at positions specified by argnums should be arrays, scalars, graph nodes, or standard Python containers. Argument arrays in the positions specified by argnums must be of inexact (i.e., floating-point or complex) type. It should return a scalar (which includes arrays with shape () but not arrays with shape (1,) etc.).
argnums – Optional, integer or sequence of integers. Specifies which positional argument(s) to differentiate with respect to (default 0).
has_aux – Optional, bool. Indicates whether f returns a pair where the first element is considered the output of the mathematical function to be differentiated and the second element is auxiliary data. Default False.
holomorphic – Optional, bool. Indicates whether f is promised to be holomorphic. If True, inputs and outputs must be complex. Default False.
allow_int – Optional, bool. Whether to allow differentiating with respect to integer-valued inputs. The gradient of an integer input will have a trivial vector-space dtype (float0). Default False.
reduce_axes – Optional, tuple of axis names. If an axis is listed here, and f implicitly broadcasts a value over that axis, the backward pass will perform a psum of the corresponding gradient. Otherwise, the gradient will be per-example over named axes. For example, if 'batch' is a named batch axis, grad(f, reduce_axes=('batch',)) will create a function that computes the total gradient, while grad(f) will create one that computes the per-example gradient.
wrt – Optional, filterlib.Filter. Filter to extract the differentiable state of each graph node. Default is nnx.Param.
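The argnums and has_aux options determine the shape of the returned value. The plain-JAX sketch below does not use nnx; it assumes nnx.grad keeps jax.grad's conventions for these two options, as the descriptions above indicate, and loss_with_aux is a hypothetical function. It differentiates with respect to two positional arguments and returns auxiliary data alongside the gradients.

```python
import jax
import jax.numpy as jnp


def loss_with_aux(w, b, x):
    """Return (scalar loss, auxiliary dict); only the loss is differentiated."""
    pred = x * w + b
    loss = jnp.sum(pred ** 2)
    return loss, {"pred": pred}


x = jnp.array([1.0, 2.0])
w, b = jnp.array(3.0), jnp.array(0.5)

# argnums=(0, 1): one gradient per selected positional argument, as a tuple.
# has_aux=True: the call returns (gradients, aux) instead of gradients alone.
grad_fn = jax.grad(loss_with_aux, argnums=(0, 1), has_aux=True)
(dw, db), aux = grad_fn(w, b, x)

# d/dw sum((x*w + b)^2) = sum(2 * pred * x); d/db = sum(2 * pred)
print(dw, db, aux["pred"])
```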
- flax.experimental.nnx.jit(fun, *, in_shardings=<object object>, out_shardings=<object object>, static_argnums=None, static_argnames=None, donate_argnums=None, donate_argnames=None, keep_unused=False, device=None, backend=None, inline=False, abstracted_axes=None, donate_state=False, constrain_state=False)
Lifted version of jax.jit that can handle Modules / graph nodes as arguments.
- Parameters
fun – Function to be jitted. fun should be a pure function, as side-effects may only be executed once.
The arguments and return value of fun should be arrays, scalars, graph nodes, or (nested) standard Python containers (tuple/list/dict) thereof. Positional arguments indicated by static_argnums can be anything at all, provided they are hashable and have an equality operation defined. Static arguments are included as part of a compilation cache key, which is why hash and equality operators must be defined.
JAX keeps a weak reference to fun for use as a compilation cache key, so the object fun must be weakly-referenceable. Most Callable objects will already satisfy this requirement.
in_shardings – Pytree of structure matching that of arguments to fun, with all actual arguments replaced by resource assignment specifications. It is also valid to specify a pytree prefix (e.g. one value in place of a whole subtree), in which case the leaves get broadcast to all values in that subtree.
The in_shardings argument is optional. JAX will infer the shardings from the input jax.Arrays and defaults to replicating the input if the sharding cannot be inferred.
The valid resource assignment specifications are:
- XLACompatibleSharding, which will decide how the value will be partitioned. With this, using a mesh context manager is not required.
- None, which will give JAX the freedom to choose whatever sharding it wants. For in_shardings, JAX will mark it as replicated, but this behavior can change in the future. For out_shardings, JAX will rely on the XLA GSPMD partitioner to determine the output shardings.
The size of every dimension has to be a multiple of the total number of resources assigned to it. This is similar to pjit’s in_shardings.
out_shardings – Like in_shardings, but specifies resource assignment for function outputs. This is similar to pjit’s out_shardings.
The out_shardings argument is optional. If not specified, jax.jit() will use GSPMD’s sharding propagation to figure out what the sharding of the output(s) should be.
static_argnums – An optional int or collection of ints that specify which positional arguments to treat as static (compile-time constant). Operations that only depend on static arguments will be constant-folded in Python (during tracing), and so the corresponding argument values can be any Python object.
Static arguments should be hashable, meaning both __hash__ and __eq__ are implemented, and immutable. Calling the jitted function with different values for these constants will trigger recompilation. Arguments that are not arrays or containers thereof must be marked as static.
If neither static_argnums nor static_argnames is provided, no arguments are treated as static. If static_argnums is not provided but static_argnames is, or vice versa, JAX uses inspect.signature(fun) to find any positional arguments that correspond to static_argnames (or vice versa). If both static_argnums and static_argnames are provided, inspect.signature is not used, and only actual parameters listed in either static_argnums or static_argnames will be treated as static.
static_argnames – An optional string or collection of strings specifying which named arguments to treat as static (compile-time constant). See the comment on static_argnums for details. If not provided but static_argnums is set, the default is based on calling inspect.signature(fun) to find corresponding named arguments.
donate_argnums – Specify which positional argument buffers are “donated” to the computation. It is safe to donate argument buffers if you no longer need them once the computation has finished. In some cases XLA can make use of donated buffers to reduce the amount of memory needed to perform a computation, for example recycling one of your input buffers to store a result. You should not reuse buffers that you donate to a computation; JAX will raise an error if you try to. By default, no argument buffers are donated.
If neither donate_argnums nor donate_argnames is provided, no arguments are donated. If donate_argnums is not provided but donate_argnames is, or vice versa, JAX uses inspect.signature(fun) to find any positional arguments that correspond to donate_argnames (or vice versa). If both donate_argnums and donate_argnames are provided, inspect.signature is not used, and only actual parameters listed in either donate_argnums or donate_argnames will be donated.
For more details on buffer donation see the FAQ.
donate_argnames – An optional string or collection of strings specifying which named arguments are donated to the computation. See the comment on donate_argnums for details. If not provided but donate_argnums is set, the default is based on calling inspect.signature(fun) to find corresponding named arguments.
keep_unused – If False (the default), arguments that JAX determines to be unused by fun may be dropped from resulting compiled XLA executables. Such arguments will not be transferred to the device nor provided to the underlying executable. If True, unused arguments will not be pruned.
device – This is an experimental feature and the API is likely to change. Optional, the Device the jitted function will run on. (Available devices can be retrieved via jax.devices().) The default is inherited from XLA’s DeviceAssignment logic and is usually to use jax.devices()[0].
backend – This is an experimental feature and the API is likely to change. Optional, a string representing the XLA backend: 'cpu', 'gpu', or 'tpu'.
inline – Specify whether this function should be inlined into enclosing jaxprs (rather than being represented as an application of the xla_call primitive with its own subjaxpr). Default False.
donate_state – Optional, bool. If True, the state of the graph nodes passed as arguments will be donated to the computation. Default False.
constrain_state – Optional, bool or callable. If True, the state of each graph node will be constrained to the partition specified by that graph node’s partition spec, as computed by nnx.spmd.get_partition_spec(). If a callable, the State will be passed to the callable, which must return the constrained State. If False, the state will not be constrained. Default False.
- Returns
A wrapped version of fun, set up for just-in-time compilation.
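The idea of lifting jax.jit over graph nodes can be sketched in plain JAX by splitting each object into a hashable static part, passed through static_argnums, and an array state pytree that jax.jit traces. This is an illustration under that assumption, not the nnx implementation: TinyLinear, split, merge, and lifted_jit are hypothetical names, and the sketch does not write any updated state back to the object. The trace counter also illustrates the static_argnums behavior described above: changing only array values reuses the compiled function, while a new static value triggers retracing.

```python
from functools import partial

import jax
import jax.numpy as jnp


class TinyLinear:
    """Hypothetical module-like object; not an nnx.Module."""

    def __init__(self, kernel, bias, activation):
        self.kernel = kernel          # array state
        self.bias = bias              # array state
        self.activation = activation  # static configuration (a hashable str)


def split(module):
    """Separate hashable static structure from traceable array state."""
    static = (type(module), module.activation)
    state = {"kernel": module.kernel, "bias": module.bias}
    return static, state


def merge(static, state):
    """Rebuild a module-like object from its static structure and state."""
    cls, activation = static
    return cls(state["kernel"], state["bias"], activation)


trace_count = 0


def lifted_jit(fn):
    """Hypothetical lift of jax.jit for functions of the form fn(module, x)."""

    @partial(jax.jit, static_argnums=0)
    def pure_fn(static, state, x):
        global trace_count
        trace_count += 1  # Python side effect: runs only while tracing
        return fn(merge(static, state), x)

    def wrapper(module, x):
        static, state = split(module)
        return pure_fn(static, state, x)

    return wrapper


def forward(m, x):
    y = x @ m.kernel + m.bias
    # m.activation is a concrete Python value here because it was static.
    return jax.nn.relu(y) if m.activation == "relu" else y


jitted_forward = lifted_jit(forward)
x = jnp.ones((1, 2))
kernel, bias = jnp.ones((2, 3)), jnp.array([-3.0, 0.0, 1.0])

out1 = jitted_forward(TinyLinear(kernel, bias, "relu"), x)        # traces
out2 = jitted_forward(TinyLinear(kernel * 2.0, bias, "relu"), x)  # cache hit
out3 = jitted_forward(TinyLinear(kernel, bias, "identity"), x)    # retraces
```

Keeping the static part hashable is what lets it act as part of the compilation cache key, as required for static arguments.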
- flax.experimental.nnx.scan(f, *, length=None, reverse=False, unroll=1, _split_transpose=False, in_axes=0, in_axes_kwargs=0, out_axes=0, carry_argnum=0, state_axes=FrozenDict({Ellipsis: 0}), split_rngs=Ellipsis, transform_metadata=FrozenDict({}), scan_output=True)