bridge

class flax.nnx.bridge.ToNNX(*args, **kwargs)

A wrapper to turn any Linen module into an NNX module.

The resulting NNX module can be used standalone with all NNX APIs, or as a submodule of another NNX module.

Since Linen module initialization requires a sample input, you need to call lazy_init with a sample input to initialize the variables.

Example:

>>> from flax import linen as nn, nnx
>>> import jax
>>> linen_module = nn.Dense(features=64)
>>> x = jax.numpy.ones((1, 32))
>>> # Like Linen init(), initialize with a sample input
>>> model = nnx.bridge.ToNNX(linen_module, rngs=nnx.Rngs(0)).lazy_init(x)
>>> # Like Linen apply(), but using NNX's direct call method
>>> y = model(x)
>>> model.kernel.shape
(32, 64)
Parameters
  • module – The Linen Module instance.

  • rngs – An nnx.Rngs instance, as normally passed to any NNX module.

Returns

A stateful NNX module that behaves the same as the wrapped Linen module.

__call__(*args, rngs=None, method=None, **kwargs)

Call self as a function.

lazy_init(*args, **kwargs)

A shortcut for calling nnx.bridge.lazy_init() on this module.

class flax.nnx.bridge.ToLinen(nnx_class, args=(), kwargs=<factory>, skip_rng=False, metadata_type=<class 'flax.nnx.bridge.variables.NNXMeta'>, parent=<flax.linen.module._Sentinel object>, name=None)

A wrapper to turn any NNX module into a Linen module.

The resulting Linen module can be used standalone with all Linen APIs, or as a submodule of another Linen module.

Since NNX modules are stateful and own their state, the wrapper creates the NNX module only once, at init time, and tracks its state and static data as separate Linen variables.

Example:

>>> from flax import linen as nn, nnx
>>> import jax
>>> model = nnx.bridge.ToLinen(nnx.Linear, args=(32, 64))
>>> x = jax.numpy.ones((1, 32))
>>> y, variables = model.init_with_output(jax.random.key(0), x)
>>> y.shape
(1, 64)
>>> variables['params']['kernel'].shape
(32, 64)
>>> # The 'nnx' collection holds the static GraphDef of the underlying NNX module
>>> variables.keys()
dict_keys(['nnx', 'params'])
>>> type(variables['nnx']['graphdef'])
<class 'flax.nnx.graph.NodeDef'>
Parameters
  • nnx_class – The NNX Module class (not instance!).

  • args – The arguments that normally would be passed in to create the NNX module.

  • kwargs – The keyword arguments that normally would be passed in to create the NNX module.

  • skip_rng – True if this NNX module does not take an rngs argument during initialization (uncommon).

Returns

A Linen module that behaves the same as the wrapped NNX module.

__call__(*args, **kwargs)

Call self as a function.

flax.nnx.bridge.to_linen(nnx_class, *args, name=None, **kwargs)

A shortcut for nnx.bridge.ToLinen when the user is not changing any of its default fields.

class flax.nnx.bridge.NNXMeta(var_type, value, metadata)

Default Flax metadata class for nnx.VariableState.

__call__(**kwargs)

Call self as a function.

add_axis(index, params)

Adds a new axis to the axis metadata.

Note that add_axis and remove_axis should act as each other’s inverse (meaning: x.add_axis(i, p).remove_axis(i, p) == x).

Parameters
  • index – The position at which the new axis will be inserted.

  • params – An arbitrary dictionary of parameters passed by the transformation that introduces the new axis (e.g.: nn.scan or nn.vmap). The user passes this dictionary as the metadata_param argument to the transformation.

Returns

A new instance of the same type as self, with the same unboxed content and updated axis metadata.

get_partition_spec()

Returns the PartitionSpec for this partitioned value.

remove_axis(index, params)

Removes an axis from the axis metadata.

Note that add_axis and remove_axis should act as each other’s inverse (meaning: x.remove_axis(i, p).add_axis(i, p) == x).

Parameters
  • index – The position of the axis that is to be removed.

  • params – An arbitrary dictionary of parameters passed by the transformation that introduced the axis (e.g.: nn.scan or nn.vmap). The user passes this dictionary as the metadata_param argument to the transformation.

Returns

A new instance of the same type as self, with the same unboxed content and updated axis metadata.

replace(**updates)

Returns a new object replacing the specified fields with new values.

replace_boxed(val)

Replaces the boxed value with the provided value.

Parameters

val – The new value to be boxed by this AxisMetadata wrapper.

Returns

A new instance of the same type as self, with val as the new unboxed content.

to_nnx_variable()

unbox()

Returns the content of the AxisMetadata box.

Note that, unlike meta.unbox, the unbox call should not recursively unbox metadata. It should simply return the value it wraps directly, even if that value is itself an instance of AxisMetadata.

In practice, AxisMetadata subclasses should be registered as PyTree nodes to support passing instances to JAX and Flax APIs. The leaves returned for this node should correspond to the value returned by unbox.

Returns

The unboxed value.
