rnglib
- class flax.nnx.Rngs(*args, **kwargs)
NNX RNG container class. To instantiate Rngs, pass in an integer specifying the starting seed. Rngs can have different “streams”, allowing the user to generate different RNG keys. For example, to generate a key for the params and dropout streams:
>>> from flax import nnx
>>> import jax, jax.numpy as jnp
>>> rng1 = nnx.Rngs(0, params=1)
>>> rng2 = nnx.Rngs(0)
>>> assert rng1.params() != rng2.dropout()
Because we passed in params=1, the starting seed for params is 1, whereas the starting seed for dropout defaults to the 0 we passed in, since we didn’t specify a seed for dropout. If we don’t specify a seed for params, then both streams default to using the 0 we passed in:
>>> rng1 = nnx.Rngs(0)
>>> rng2 = nnx.Rngs(0)
>>> assert rng1.params() == rng2.dropout()
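The seed-resolution rule above can be written out as a small pure-JAX sketch. It assumes the counter mechanism described in the next paragraph, so the first key of a fresh stream is the key of its starting seed with counter 0 folded in. first_key and same_key are hypothetical helpers for illustration, not Flax APIs.

```python
import jax
import jax.numpy as jnp


def first_key(stream, default, **seeds):
    """Hypothetical helper: first key a fresh nnx.Rngs(default, **seeds) stream yields.

    A stream named in `seeds` uses its own starting seed; any other stream
    falls back to `default`. The first draw folds counter 0 into that seed's key.
    """
    seed = seeds.get(stream, default)
    return jax.random.fold_in(jax.random.key(seed), 0)


def same_key(a, b):
    # Typed PRNG keys are compared through their raw uint32 key data.
    return bool(jnp.array_equal(jax.random.key_data(a), jax.random.key_data(b)))


# rng1 = Rngs(0, params=1) vs rng2 = Rngs(0): params starts from seed 1,
# dropout falls back to seed 0, so the first keys differ.
explicit = same_key(first_key("params", 0, params=1), first_key("dropout", 0))
# rng1 = Rngs(0) vs rng2 = Rngs(0): both streams fall back to seed 0.
fallback = same_key(first_key("params", 0), first_key("dropout", 0))
```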
The Rngs container class keeps a separate counter for each stream. Every time a stream is called to generate a new RNG key, its counter increments by 1. To generate the new key, the stream's current counter value is folded into the key derived from its starting seed. If we try to generate an RNG key for a stream we did not specify on instantiation, then the default stream is used instead (i.e. the first positional argument passed to Rngs during instantiation is the default starting seed), and such unspecified streams share the default stream's counter:
>>> rng1 = nnx.Rngs(100, params=42)
>>> # `params` stream starting seed is 42, counter is 0
>>> assert rng1.params() == jax.random.fold_in(jax.random.key(42), 0)
>>> # `dropout` is unspecified, so it uses the `default` stream: starting seed 100, counter is 0
>>> assert rng1.dropout() == jax.random.fold_in(jax.random.key(100), 0)
>>> # calling rng1() directly also uses the `default` stream: starting seed 100, counter is now 1
>>> assert rng1() == jax.random.fold_in(jax.random.key(100), 1)
>>> # `params` stream starting seed is 42, counter is 1
>>> assert rng1.params() == jax.random.fold_in(jax.random.key(42), 1)
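To make the counter scheme concrete, here is a simplified toy model in plain JAX that reproduces the four assertions above. ToyRngs and its draw method are hypothetical and only mirror the documented behaviour (per-stream counters, fold-in of the counter, fallback to a shared default stream); they are not the Flax implementation, which exposes streams as attributes such as rngs.params().

```python
import jax
import jax.numpy as jnp


class ToyRngs:
    """Hypothetical, simplified model of the counter scheme described above.

    Each named stream keeps a starting-seed key and a counter; streams that
    were not named at construction fall back to the shared `default` stream.
    """

    def __init__(self, default=None, **seeds):
        self._streams = {}
        if default is not None:
            self._streams["default"] = [jax.random.key(default), 0]
        for name, seed in seeds.items():
            self._streams[name] = [jax.random.key(seed), 0]

    def draw(self, name="default"):
        # Unknown stream names resolve to the default stream and its counter.
        if name in self._streams:
            stream = self._streams[name]
        else:
            stream = self._streams["default"]
        key = jax.random.fold_in(stream[0], stream[1])
        stream[1] += 1  # every draw advances this stream's counter by 1
        return key


def same_key(a, b):
    # Typed PRNG keys are compared through their raw uint32 key data.
    return bool(jnp.array_equal(jax.random.key_data(a), jax.random.key_data(b)))


rngs = ToyRngs(100, params=42)
k_params0 = rngs.draw("params")   # seed 42, counter 0
k_dropout = rngs.draw("dropout")  # falls back to default: seed 100, counter 0
k_default = rngs.draw()           # default stream again: seed 100, counter 1
k_params1 = rngs.draw("params")   # seed 42, counter 1
```

Because dropout and the direct call share one counter in this model, the direct draw sees counter 1 rather than 0, matching the doctest.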
Let’s see an example of using Rngs in a Module and verifying the output by manually threading the Rngs:
>>> class Model(nnx.Module):
...   def __init__(self, rngs):
...     # Linear uses the `params` stream twice for kernel and bias
...     self.linear = nnx.Linear(2, 3, rngs=rngs)
...     # Dropout uses the `dropout` stream once
...     self.dropout = nnx.Dropout(0.5, rngs=rngs)
...   def __call__(self, x):
...     return self.dropout(self.linear(x))
>>> def assert_same(x, rng_seed, **rng_kwargs):
...   model = Model(rngs=nnx.Rngs(rng_seed, **rng_kwargs))
...   out = model(x)
...
...   # manual forward propagation
...   rngs = nnx.Rngs(rng_seed, **rng_kwargs)
...   kernel = nnx.initializers.lecun_normal()(rngs.params(), (2, 3))
...   assert (model.linear.kernel.value==kernel).all()
...   bias = nnx.initializers.zeros_init()(rngs.params(), (3,))
...   assert (model.linear.bias.value==bias).all()
...   mask = jax.random.bernoulli(rngs.dropout(), p=0.5, shape=(1, 3))
...   # dropout divides the kept values by the keep probability (1 - rate = 0.5)
...   manual_out = mask * (jnp.dot(x, kernel) + bias) / 0.5
...   assert (out == manual_out).all()
>>> x = jnp.ones((1, 2))
>>> assert_same(x, 0)
>>> assert_same(x, 0, params=1)
>>> assert_same(x, 0, params=1, dropout=2)
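The manual check above reproduces the dropout output by masking and rescaling. A standalone sketch of that inverted-dropout rule in plain JAX, assuming (as the check implies) that each element is kept with probability 1 - rate and kept values are divided by that keep probability. inverted_dropout is a hypothetical helper, not a Flax API. With rate=0.5 the keep probability and the rate coincide, which is why the doctest can use 0.5 for both p and the divisor.

```python
import jax
import jax.numpy as jnp


def inverted_dropout(key, y, rate):
    """Sketch of inverted dropout (hypothetical helper, not nnx.Dropout itself).

    Each element is kept with probability 1 - rate; kept elements are divided
    by that keep probability so the expected value of the output equals y.
    """
    keep_prob = 1.0 - rate
    mask = jax.random.bernoulli(key, p=keep_prob, shape=y.shape)
    return jnp.where(mask, y / keep_prob, 0.0)


y = jnp.arange(1.0, 7.0).reshape(2, 3)
out = inverted_dropout(jax.random.key(0), y, rate=0.5)
```

Every entry of out is either dropped to 0 or scaled to 2 * y, since 1 / keep_prob = 2 here.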
- __init__(default=None, /, **rngs)
- Parameters
default – the starting seed for the default stream. Any key generated from a stream that isn’t specified in the **rngs keyword arguments will default to using this starting seed.
**rngs – optional keyword arguments to specify starting seeds for different RNG streams. The keyword is the stream name and its value is the corresponding starting seed for that stream.
- flax.nnx.reseed(node, /, **stream_keys)
Update the keys of the specified RNG streams with new keys.
- Parameters
node – the node in which to reseed the RNG streams.
**stream_keys – a mapping of stream names to new keys. The keys can be either integers or JAX arrays. If an integer is passed in, the key is generated using jax.random.key.
- Raises
ValueError – if an existing stream key is not a scalar.
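The two documented rules for the arguments, integer seeds converted with jax.random.key and a ValueError when an existing stream key is not a scalar, can be sketched as follows. normalize_stream_keys and check_scalar_key are hypothetical helpers, and the error message text is illustrative only; this is not the Flax source.

```python
import jax
import jax.numpy as jnp


def normalize_stream_keys(**stream_keys):
    """Hypothetical helper: turn reseed-style arguments into PRNG keys.

    Python ints are converted with jax.random.key, as documented for
    nnx.reseed; anything else is assumed to already be a JAX key array.
    """
    return {
        name: jax.random.key(value) if isinstance(value, int) else value
        for name, value in stream_keys.items()
    }


def check_scalar_key(name, key):
    # Mirrors the documented ValueError: an existing stream key must be a
    # scalar key (shape ()), not e.g. a batch of keys produced by split.
    if key.shape != ():
        raise ValueError(f"Key for stream {name!r} is not a scalar, got shape {key.shape}")


keys = normalize_stream_keys(dropout=42, params=jax.random.key(7))
```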
Example:
>>> from flax import nnx
>>> import jax.numpy as jnp
...
>>> class Model(nnx.Module):
...   def __init__(self, rngs):
...     self.linear = nnx.Linear(2, 3, rngs=rngs)
...     self.dropout = nnx.Dropout(0.5, rngs=rngs)
...   def __call__(self, x):
...     return self.dropout(self.linear(x))
...
>>> model = Model(nnx.Rngs(params=0, dropout=42))
>>> x = jnp.ones((1, 2))
...
>>> y1 = model(x)
...
>>> # reset the ``dropout`` stream key to 42
>>> nnx.reseed(model, dropout=42)
>>> y2 = model(x)
...
>>> jnp.allclose(y1, y2)
Array(True, dtype=bool)
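The dropout stream in this example was already seeded with 42, and the first forward pass advanced its counter, so y2 reproducing y1 suggests that reseed resets the counter along with the key. A toy single-stream sketch of that behaviour, under that assumption; ToyStream is hypothetical and not a Flax class.

```python
import jax
import jax.numpy as jnp


class ToyStream:
    """Hypothetical single RNG stream: a starting key plus a draw counter."""

    def __init__(self, seed):
        self.key = jax.random.key(seed)
        self.count = 0

    def __call__(self):
        # Fold the current counter into the starting key, then advance it.
        key = jax.random.fold_in(self.key, self.count)
        self.count += 1
        return key

    def reseed(self, seed):
        # Assumption: reseeding replaces the key AND resets the counter to 0,
        # which is what lets a repeated seed reproduce the same draws.
        self.key = jax.random.key(seed)
        self.count = 0


def same_key(a, b):
    # Typed PRNG keys are compared through their raw uint32 key data.
    return bool(jnp.array_equal(jax.random.key_data(a), jax.random.key_data(b)))


dropout = ToyStream(42)
first = dropout()    # the key behind y1
dropout.reseed(42)
second = dropout()   # the key behind y2: same key and counter as `first`
unreset = dropout()  # a further draw without reseeding uses counter 1
```

If only the key were replaced and the counter kept running, the second pass would draw with counter 1 and would in general produce a different dropout mask.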