rnglib
- class flax.nnx.Rngs(default=None, **rngs)
NNX rng container class. To instantiate Rngs, pass in an integer specifying the starting seed. Rngs can have different “streams”, allowing the user to generate different rng keys. For example, to generate a key for the params and dropout streams:
>>> from flax import nnx
>>> import jax, jax.numpy as jnp
>>> rng1 = nnx.Rngs(0, params=1)
>>> rng2 = nnx.Rngs(0)
>>> assert rng1.params() != rng2.dropout()
Because we passed in params=1, the starting seed for params is 1, whereas the starting seed for dropout defaults to the 0 we passed in, since we didn’t specify a seed for dropout. If we don’t specify a seed for params either, both streams default to the 0 we passed in:
>>> rng1 = nnx.Rngs(0)
>>> rng2 = nnx.Rngs(0)
>>> assert rng1.params() == rng2.dropout()
The Rngs container class keeps a separate counter for each stream. Every time a stream is called to generate a new rng key, its counter increments by 1. To generate a new rng key, the stream’s current counter value is folded (with jax.random.fold_in) into the key created from that stream’s starting seed. If we try to generate an rng key for a stream we did not specify on instantiation, the default stream is used instead (i.e. the first positional argument passed to Rngs during instantiation is the default starting seed), so such a stream also shares the default stream’s counter:
>>> rng1 = nnx.Rngs(100, params=42)
>>> # `params` stream starting seed is 42, counter is 0
>>> assert rng1.params() == jax.random.fold_in(jax.random.key(42), 0)
>>> # `dropout` has no seed of its own, so it falls back to the default stream (seed 100, counter 0)
>>> assert rng1.dropout() == jax.random.fold_in(jax.random.key(100), 0)
>>> # calling rng1() directly also uses the default stream, whose counter is now 1
>>> assert rng1() == jax.random.fold_in(jax.random.key(100), 1)
>>> # `params` stream starting seed is 42, counter is 1
>>> assert rng1.params() == jax.random.fold_in(jax.random.key(42), 1)
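The counter bookkeeping in the example above can be reproduced in plain JAX, which makes the fallback rule easy to inspect. ToyRngs is a hypothetical class written only for illustration, not Flax’s implementation; it assumes, as the dropout and rng1() calls above suggest, that streams without their own seed share the default stream’s counter.

```python
import jax


class ToyRngs:
    """Hypothetical stand-in for nnx.Rngs (illustration only, not Flax code)."""

    def __init__(self, default, **seeds):
        # One starting key per named stream, plus the shared `default` stream.
        self._keys = {"default": jax.random.key(default)}
        self._keys.update({name: jax.random.key(s) for name, s in seeds.items()})
        # One counter per stream, all starting at 0.
        self._counts = {name: 0 for name in self._keys}

    def _next(self, name):
        # Streams without their own seed fall back to `default` and share its counter.
        if name not in self._keys:
            name = "default"
        # Fold the current counter into the stream's starting key, then advance it.
        key = jax.random.fold_in(self._keys[name], self._counts[name])
        self._counts[name] += 1
        return key

    def __call__(self):
        return self._next("default")

    def __getattr__(self, name):
        # Only reached for names not found normally, e.g. `params` or `dropout`.
        if name.startswith("_"):
            raise AttributeError(name)
        return lambda: self._next(name)


rngs = ToyRngs(100, params=42)
k_params0 = rngs.params()   # fold_in(key(42), 0)
k_dropout = rngs.dropout()  # fold_in(key(100), 0): falls back to `default`
k_default = rngs()          # fold_in(key(100), 1): shares `default`'s counter
k_params1 = rngs.params()   # fold_in(key(42), 1)
```

Comparing keys through jax.random.key_data sidesteps any question of how typed key arrays handle equality.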
Let’s see an example of using Rngs in a Module and verifying the output by manually threading the Rngs:
>>> class Model(nnx.Module):
...   def __init__(self, rngs):
...     # Linear uses the `params` stream twice for kernel and bias
...     self.linear = nnx.Linear(2, 3, rngs=rngs)
...     # Dropout uses the `dropout` stream once per call
...     self.dropout = nnx.Dropout(0.5, rngs=rngs)
...   def __call__(self, x):
...     return self.dropout(self.linear(x))
>>> def assert_same(x, rng_seed, **rng_kwargs):
...   model = Model(rngs=nnx.Rngs(rng_seed, **rng_kwargs))
...   out = model(x)
...
...   # manual forward propagation
...   rngs = nnx.Rngs(rng_seed, **rng_kwargs)
...   kernel = nnx.initializers.lecun_normal()(rngs.params(), (2, 3))
...   assert (model.linear.kernel.value==kernel).all()
...   bias = nnx.initializers.zeros_init()(rngs.params(), (3,))
...   assert (model.linear.bias.value==bias).all()
...   mask = jax.random.bernoulli(rngs.dropout(), p=0.5, shape=(1, 3))
...   # dropout divides kept values by the keep probability 1 - rate (0.5 here)
...   manual_out = mask * (jnp.dot(x, kernel) + bias) / 0.5
...   assert (out == manual_out).all()
>>> x = jnp.ones((1, 2))
>>> assert_same(x, 0)
>>> assert_same(x, 0, params=1)
>>> assert_same(x, 0, params=1, dropout=2)
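The manual check divides by 0.5 because inverted dropout divides the kept values by the keep probability 1 - rate, and with rate = 0.5 the keep probability is also 0.5. For a general rate the computation can be sketched in plain JAX as follows; inverted_dropout is a hypothetical helper that mirrors the manual check above, not Flax’s own Dropout code.

```python
import jax
import jax.numpy as jnp


def inverted_dropout(key, x, rate):
    # Keep each entry with probability 1 - rate (hypothetical helper, not Flax API).
    keep_prob = 1.0 - rate
    mask = jax.random.bernoulli(key, p=keep_prob, shape=x.shape)
    # Survivors are divided by keep_prob so E[output] == x; dropped entries become 0.
    return jnp.where(mask, x / keep_prob, 0.0)


x = jnp.arange(1.0, 7.0).reshape(2, 3)
key = jax.random.key(0)
out = inverted_dropout(key, x, rate=0.25)
same = inverted_dropout(key, x, rate=0.25)  # same key -> same mask -> same output
```

Because the mask depends only on the key, replaying the same dropout key reproduces the same output, which is what the manual check in assert_same relies on.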
- __init__(default=None, **rngs)
- Parameters:
default – the starting seed for the default stream. Any key generated from a stream that isn’t specified in the **rngs keyword arguments will default to using this starting seed.
**rngs – optional keyword arguments specifying starting seeds for different rng streams. Each keyword is a stream name, and its value is the corresponding starting seed for that stream.
- flax.nnx.reseed(node, /, **stream_keys)
Update the keys of the specified RNG streams with new keys.
- Parameters:
node – the node to reseed the RNG streams in.
**stream_keys – a mapping of stream names to new keys. The keys can be either integers or JAX arrays. If an integer is passed in, then the key will be generated using jax.random.key.
- Raises:
ValueError – if an existing stream key is not a scalar.
Example:
>>> from flax import nnx
>>> import jax.numpy as jnp
...
>>> class Model(nnx.Module):
...   def __init__(self, rngs):
...     self.linear = nnx.Linear(2, 3, rngs=rngs)
...     self.dropout = nnx.Dropout(0.5, rngs=rngs)
...   def __call__(self, x):
...     return self.dropout(self.linear(x))
...
>>> model = Model(nnx.Rngs(params=0, dropout=42))
>>> x = jnp.ones((1, 2))
...
>>> y1 = model(x)
...
>>> # reset the ``dropout`` stream key to 42
>>> nnx.reseed(model, dropout=42)
>>> y2 = model(x)
...
>>> jnp.allclose(y1, y2)
Array(True, dtype=bool)
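To see why y1 and y2 match, the documented behavior (integer seeds become keys via jax.random.key, and a non-scalar existing stream key raises ValueError) can be mimicked over a plain dictionary of streams. This is a hypothetical sketch, not Flax’s implementation: toy_reseed and next_key are illustrative names, and restarting the counter at 0 is an assumption that is consistent with the example’s output.

```python
import jax


def toy_reseed(streams, **stream_keys):
    """Hypothetical stand-in for nnx.reseed over a dict of toy streams.

    Each stream is a dict {"key": key, "count": int}.
    """
    for name, stream in streams.items():
        if name not in stream_keys:
            continue  # streams that were not named keep their state
        if stream["key"].ndim != 0:
            raise ValueError(f"stream {name!r} key is not a scalar")
        new_key = stream_keys[name]
        if isinstance(new_key, int):
            new_key = jax.random.key(new_key)  # integers become keys
        stream["key"] = new_key
        stream["count"] = 0  # assumption: reseeding restarts the counter


def next_key(stream):
    # Same fold-in-then-increment scheme described for Rngs above.
    key = jax.random.fold_in(stream["key"], stream["count"])
    stream["count"] += 1
    return key


streams = {
    "params": {"key": jax.random.key(0), "count": 0},
    "dropout": {"key": jax.random.key(42), "count": 0},
}
mask1 = jax.random.bernoulli(next_key(streams["dropout"]), 0.5, (1, 3))
toy_reseed(streams, dropout=42)  # analogous to nnx.reseed(model, dropout=42)
mask2 = jax.random.bernoulli(next_key(streams["dropout"]), 0.5, (1, 3))
```

Without the counter reset, the second draw would fold in 1 instead of 0 and would generally produce a different mask.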