rnglib

class flax.nnx.Rngs(*args, **kwargs)

NNX RNG container class. To instantiate Rngs, pass in an integer specifying the starting seed. Rngs can hold several “streams”, each generating its own sequence of RNG keys. For example, to generate keys for the params and dropout streams:

>>> from flax import nnx
>>> import jax, jax.numpy as jnp

>>> rng1 = nnx.Rngs(0, params=1)
>>> rng2 = nnx.Rngs(0)

>>> assert rng1.params() != rng2.dropout()

Because we passed in params=1, the starting seed for params is 1. The starting seed for dropout falls back to the 0 we passed in, since we did not specify a seed for dropout. If we do not specify a seed for params either, both streams use the 0 we passed in and therefore produce the same key:

>>> rng1 = nnx.Rngs(0)
>>> rng2 = nnx.Rngs(0)

>>> assert rng1.params() == rng2.dropout()

The Rngs container keeps a separate counter for each stream. Every time a stream is called to generate a new RNG key, its counter increments by 1. The new key is produced by folding the stream's current counter value into the key derived from its starting seed. If we request a key from a stream that was not specified on instantiation, the default stream (its starting seed and its counter) is used instead. The first positional argument passed to Rngs is the default starting seed:

>>> rng1 = nnx.Rngs(100, params=42)
>>> # `params` stream starting seed is 42, counter is 0
>>> assert rng1.params() == jax.random.fold_in(jax.random.key(42), 0)
>>> # `dropout` was not specified, so the default stream is used: seed 100, counter 0
>>> assert rng1.dropout() == jax.random.fold_in(jax.random.key(100), 0)
>>> # calling rng1() directly also uses the default stream, whose counter is now 1
>>> assert rng1() == jax.random.fold_in(jax.random.key(100), 1)
>>> # `params` stream starting seed is 42, counter is 1
>>> assert rng1.params() == jax.random.fold_in(jax.random.key(42), 1)
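To make the counter bookkeeping concrete, here is a minimal pure-JAX sketch of the rules described above. `ToyRngs` and `next_key` are hypothetical names used only for illustration; this is not Flax's implementation, just a model of the documented behavior (per-stream base key plus counter, with unspecified streams falling back to the default stream).

```python
import jax
import jax.numpy as jnp


class ToyRngs:
  """Hypothetical, simplified model of the Rngs bookkeeping (not Flax's code).

  Each stream stores the key created from its starting seed plus a counter.
  Stream names without their own seed fall back to the shared default stream.
  """

  def __init__(self, default=None, **seeds):
    self._streams = {}
    if default is not None:
      seeds['default'] = default
    for name, seed in seeds.items():
      # [base key derived from the starting seed, number of keys drawn so far]
      self._streams[name] = [jax.random.key(seed), 0]

  def next_key(self, name='default'):
    # Unknown streams use the default stream and share its counter
    # (raises KeyError if no default seed was given).
    if name not in self._streams:
      name = 'default'
    base_key, count = self._streams[name]
    self._streams[name][1] = count + 1
    # New key = base key with the current counter value folded in.
    return jax.random.fold_in(base_key, count)


rngs = ToyRngs(100, params=42)
k0 = rngs.next_key('params')
expected = jax.random.fold_in(jax.random.key(42), 0)
# Compare the raw key data, since these are typed key arrays.
print(jnp.array_equal(jax.random.key_data(k0), jax.random.key_data(expected)))
```

Calling `next_key('params')`, `next_key('dropout')`, `next_key()` and `next_key('params')` on a fresh `ToyRngs(100, params=42)` reproduces the four keys asserted above, because `dropout` and the bare call share the default stream's counter.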

Let’s see an example of using Rngs in a Module and verifying the output by manually threading the Rngs:

>>> class Model(nnx.Module):
...   def __init__(self, rngs):
...     # Linear uses the `params` stream twice for kernel and bias
...     self.linear = nnx.Linear(2, 3, rngs=rngs)
...     # Dropout uses the `dropout` stream once
...     self.dropout = nnx.Dropout(0.5, rngs=rngs)
...   def __call__(self, x):
...     return self.dropout(self.linear(x))

>>> def assert_same(x, rng_seed, **rng_kwargs):
...   model = Model(rngs=nnx.Rngs(rng_seed, **rng_kwargs))
...   out = model(x)
...
...   # manual forward propagation
...   rngs = nnx.Rngs(rng_seed, **rng_kwargs)
...   kernel = nnx.initializers.lecun_normal()(rngs.params(), (2, 3))
...   assert (model.linear.kernel.value==kernel).all()
...   bias = nnx.initializers.zeros_init()(rngs.params(), (3,))
...   assert (model.linear.bias.value==bias).all()
...   mask = jax.random.bernoulli(rngs.dropout(), p=0.5, shape=(1, 3))
...   # dropout keeps each unit with probability 1 - rate and divides kept units by 1 - rate (here 0.5)
...   manual_out = mask * (jnp.dot(x, kernel) + bias) / 0.5
...   assert (out == manual_out).all()

>>> x = jnp.ones((1, 2))
>>> assert_same(x, 0)
>>> assert_same(x, 0, params=1)
>>> assert_same(x, 0, params=1, dropout=2)
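The manual check above relies on inverted dropout: with rate r, each activation is kept with probability 1 - r and kept activations are divided by 1 - r, so the expected output equals the input. Below is a minimal pure-JAX sketch of that rule. `inverted_dropout` is a hypothetical helper, and we assume, as the assertion above implies, that nnx.Dropout draws its mask with jax.random.bernoulli at p = 1 - rate.

```python
import jax
import jax.numpy as jnp


def inverted_dropout(key, x, rate):
  """Hypothetical sketch of inverted dropout (assumed to match nnx.Dropout's math)."""
  keep_prob = 1.0 - rate
  # Each element is kept with probability keep_prob.
  mask = jax.random.bernoulli(key, p=keep_prob, shape=x.shape)
  # Kept elements are rescaled so the expected value of the output equals x.
  return jnp.where(mask, x / keep_prob, 0.0)


key = jax.random.key(0)
y = jnp.arange(1.0, 4.0).reshape(1, 3)
out = inverted_dropout(key, y, rate=0.5)

# Same formula as the doctest above: mask * y / 0.5, drawn with the same key.
mask = jax.random.bernoulli(key, p=0.5, shape=(1, 3))
print(jnp.array_equal(out, mask * y / 0.5))
```

Writing the rule as `jnp.where(mask, x / keep_prob, 0.0)` or as `mask * x / keep_prob` gives identical values; the doctest uses the latter form.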

__init__(default=None, /, **rngs)
Parameters
  • default – the starting seed for the default stream. Any key requested from a stream that is not specified in the **rngs keyword arguments is generated from this starting seed.

  • **rngs – optional keyword arguments that specify starting seeds for individual RNG streams. Each keyword is a stream name, and its value is the starting seed for that stream.

class flax.nnx.RngStream(*args: 'Any', **kwargs: 'Any')

flax.nnx.reseed(node, /, **stream_keys)

Update the keys of the specified RNG streams with new keys.

Parameters
  • node – the node whose RNG streams should be reseeded.

  • **stream_keys – keyword arguments mapping stream names to new keys. Each key can be either an integer or a JAX array; if an integer is passed, the key is generated with jax.random.key.

Raises
  • ValueError – if an existing stream key is not a scalar.

Example:

>>> from flax import nnx
>>> import jax.numpy as jnp
...
>>> class Model(nnx.Module):
...   def __init__(self, rngs):
...     self.linear = nnx.Linear(2, 3, rngs=rngs)
...     self.dropout = nnx.Dropout(0.5, rngs=rngs)
...   def __call__(self, x):
...     return self.dropout(self.linear(x))
...
>>> model = Model(nnx.Rngs(params=0, dropout=42))
>>> x = jnp.ones((1, 2))
...
>>> y1 = model(x)
...
>>> # reset the ``dropout`` stream key to 42
>>> nnx.reseed(model, dropout=42)
>>> y2 = model(x)
...
>>> jnp.allclose(y1, y2)
Array(True, dtype=bool)
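To make the documented reseed rules concrete, here is a minimal dict-based sketch. `reseed_streams` is a hypothetical helper, not Flax's implementation, and each stream is modeled as a `[key, counter]` pair. Two points are assumptions: "not a scalar" is taken to mean a key array whose shape is not `()` (for example a batch of split keys), and reseeding is assumed to reset the stream's counter to 0, which is what lets `y2` reproduce `y1` in the example above.

```python
import jax


def reseed_streams(streams, **stream_keys):
  """Hypothetical sketch of the reseed rules (not Flax's implementation).

  `streams` maps a stream name to a [key, counter] pair.
  """
  for name, value in stream_keys.items():
    current_key, _ = streams[name]
    # Documented error: the existing stream key must be a scalar key
    # (assumed here to mean shape (), i.e. not a batch of split keys).
    if current_key.shape != ():
      raise ValueError(
          f"stream {name!r} has a non-scalar key of shape {current_key.shape}")
    # An int is turned into a key with jax.random.key; a key array is used as-is.
    new_key = jax.random.key(value) if isinstance(value, int) else value
    # Assumption: the counter restarts at 0, so the next key drawn matches
    # the first key drawn from a freshly created stream with the same seed.
    streams[name] = [new_key, 0]
  return streams


streams = {"dropout": [jax.random.key(42), 3]}
reseed_streams(streams, dropout=42)
print(streams["dropout"][1])
```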