bridge
- class flax.nnx.bridge.ToNNX(*args, **kwargs)
A wrapper to turn any Linen module into an NNX module.
The resulting NNX module can be used standalone with all NNX APIs or as a submodule of another NNX module.
Since Linen module initialization requires a sample input, you need to call lazy_init with a sample input to initialize the variables.
Example:
>>> from flax import linen as nn, nnx
>>> import jax
>>> linen_module = nn.Dense(features=64)
>>> x = jax.numpy.ones((1, 32))
>>> # Like Linen init(), initialize with a sample input
>>> model = nnx.bridge.ToNNX(linen_module, rngs=nnx.Rngs(0)).lazy_init(x)
>>> # Like Linen apply(), but using NNX's direct call method
>>> y = model(x)
>>> model.kernel.shape
(32, 64)
- Parameters
module – The Linen Module instance.
rngs – The nnx.Rngs instance being passed to any NNX module.
- Returns
A stateful NNX module that behaves the same as the wrapped Linen module.
Methods
lazy_init(*args, **kwargs) – A shortcut for calling nnx.bridge.lazy_init() on this module.
- class flax.nnx.bridge.ToLinen(nnx_class, args=(), kwargs=<factory>, skip_rng=False, metadata_type=<class 'flax.nnx.bridge.variables.NNXMeta'>, parent=<flax.linen.module._Sentinel object>, name=None)
A wrapper to turn any NNX module into a Linen module.
The resulting Linen module can be used standalone with all Linen APIs or as a submodule of another Linen module.
Because an NNX module is stateful and owns its state, the wrapper creates it only once, at init time, and then tracks its state and static data as separate variables.
Example:
>>> from flax import linen as nn, nnx
>>> import jax
>>> model = nnx.bridge.ToLinen(nnx.Linear, args=(32, 64))
>>> x = jax.numpy.ones((1, 32))
>>> y, variables = model.init_with_output(jax.random.key(0), x)
>>> y.shape
(1, 64)
>>> variables['params']['kernel'].shape
(32, 64)
>>> # The static GraphDef of the underlying NNX module
>>> variables.keys()
dict_keys(['nnx', 'params'])
>>> type(variables['nnx']['graphdef'])
<class 'flax.nnx.graph.NodeDef'>
- Parameters
nnx_class – The NNX Module class (not instance!).
args – The arguments that normally would be passed in to create the NNX module.
kwargs – The keyword arguments that normally would be passed in to create the NNX module.
skip_rng – Set to True if this NNX module does not take an rngs argument during initialization (uncommon).
- Returns
A Linen module that behaves the same as the wrapped NNX module.
- flax.nnx.bridge.to_linen(nnx_class, *args, name=None, **kwargs)
A shortcut for nnx.bridge.ToLinen when the user does not change any of its default fields.
- class flax.nnx.bridge.NNXMeta(var_type, value, metadata)
Default Flax metadata class for nnx.VariableState.
- __call__(**kwargs)
Call self as a function.
- add_axis(index, params)
Adds a new axis to the axis metadata.
Note that add_axis and remove_axis should act as each other’s inverse, meaning that x.add_axis(i, p).remove_axis(i, p) == x.
- Parameters
index – The position at which the new axis will be inserted
params – An arbitrary dictionary of parameters passed by the transformation that introduces the new axis (e.g. nn.scan or nn.vmap). The user passes this dictionary as the metadata_params argument to the transformation.
- Returns
A new instance of the same type as self, with the same unbox content and updated axis metadata.
- remove_axis(index, params)
Removes an axis from the axis metadata.
Note that add_axis and remove_axis should act as each other’s inverse, meaning that x.remove_axis(i, p).add_axis(i, p) == x.
- Parameters
index – The position of the axis that is to be removed
params – An arbitrary dictionary of parameters passed by the transformation that introduced the axis (e.g. nn.scan or nn.vmap). The user passes this dictionary as the metadata_params argument to the transformation.
- Returns
A new instance of the same type as self, with the same unbox content and updated axis metadata.
- replace(**updates)
Returns a new object replacing the specified fields with new values.
- replace_boxed(val)
Replaces the boxed value with the provided value.
- Parameters
val – The new value to be boxed by this AxisMetadata wrapper
- Returns
A new instance of the same type as self, with val as the new unbox content.
- unbox()
Returns the content of the AxisMetadata box.
Note that, unlike meta.unbox, the unbox call should not recursively unbox metadata. It should simply return the value that it wraps directly, even if that value itself is an instance of AxisMetadata.
In practice, AxisMetadata subclasses should be registered as PyTree nodes to support passing instances to JAX and Flax APIs. The leaves returned for this node should correspond to the value returned by unbox.
- Returns
The unboxed value.
Methods
add_axis(index, params) – Adds a new axis to the axis metadata.
Returns the PartitionSpec for this partitioned value.
remove_axis(index, params) – Removes an axis from the axis metadata.
replace(**updates) – Returns a new object replacing the specified fields with new values.
replace_boxed(val) – Replaces the boxed value with the provided value.
unbox() – Returns the content of the AxisMetadata box.