rnglib

class flax.nnx.Rngs(self, default=None, **rngs)

A small abstraction to manage RNG state.

Rngs allows the creation of RngStream objects, which are used to easily generate new, unique random keys on demand. An RngStream is a wrapper around a JAX random key and a counter. Every time a key is requested, the counter is incremented and the key is generated from the seed key and the counter using jax.random.fold_in.
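As a rough illustration of the mechanism just described, the sketch below re-creates a single stream in plain JAX. `ToyRngStream` is a hypothetical class written for this example, not the flax implementation; it only assumes the seed-key-plus-counter scheme with `jax.random.fold_in` stated above, and it is not guaranteed to reproduce the exact keys that `nnx.Rngs` returns.

```python
import jax
import jax.numpy as jnp


class ToyRngStream:
    """Hypothetical stand-in for an RNG stream: a seed key plus a counter."""

    def __init__(self, seed: int):
        self.base_key = jax.random.key(seed)  # starting seed key for the stream
        self.count = 0                        # number of keys handed out so far

    def __call__(self) -> jax.Array:
        # Derive a key from the seed key and the current counter,
        # then advance the counter so the next call yields a different key.
        key = jax.random.fold_in(self.base_key, self.count)
        self.count += 1
        return key


stream = ToyRngStream(0)
k0 = stream()
k1 = stream()

# Compare keys through the samples they generate.
x0 = jax.random.normal(k0, (4,))
x1 = jax.random.normal(k1, (4,))
print(stream.count)  # 2
```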

To create an Rngs object, pass an integer or a jax.random.key to the constructor as a keyword argument named after the stream. The key is used as the starting seed for the stream, and the counter is initialized to zero. Then call the stream to get a key:

>>> from flax import nnx
>>> import jax, jax.numpy as jnp

>>> rngs = nnx.Rngs(params=0, dropout=1)

>>> param_key1 = rngs.params()
>>> param_key2 = rngs.params()
>>> dropout_key1 = rngs.dropout()
>>> dropout_key2 = rngs.dropout()
...
>>> assert param_key1 != dropout_key1

Trying to generate a key for a stream that was not specified during construction will result in an error being raised:

>>> rngs = nnx.Rngs(params=0, dropout=1)
>>> try:
...   key = rngs.unknown_stream()
... except AttributeError as e:
...   print(e)
No RngStream named 'unknown_stream' found in Rngs.

The default stream is created by passing a key to the constructor without specifying a stream name. When the default stream is set, the rngs object can be called directly to get a key, and streams that were not specified during construction fall back to the default stream:

>>> rngs = nnx.Rngs(0, params=1)
...
>>> key1 = rngs.default()       # uses 'default'
>>> key2 = rngs()               # uses 'default'
>>> key3 = rngs.params()        # uses 'params'
>>> key4 = rngs.dropout()       # uses 'default'
>>> key5 = rngs.unknown_stream() # uses 'default'
__init__(default=None, **rngs)
Parameters:
  • default – the starting seed for the default stream, defaults to None.

  • **rngs – keyword arguments specifying the starting seed for each stream. The key can be an integer or a jax.random.key.

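class flax.nnx.RngStream(self, key, *, tag)

The sampling methods below mirror the jax.random functions of the same name. The key listed among each method's parameters is not passed by the caller; each call draws a fresh key from the stream.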
ball(d, p=2, shape=(), dtype=None, *, out_sharding=None)

Sample uniformly from the unit Lp ball.

Reference: https://arxiv.org/abs/math/0503650.

Parameters:
  • key – a PRNG key used as the random key.

  • d – a nonnegative int representing the dimensionality of the ball.

  • p – a float representing the p parameter of the Lp norm.

  • shape – optional, the batch dimensions of the result. Default ().

  • dtype – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).

  • out_sharding – optional, specifies how the output array should be sharded across devices in multi-device computation. Can be a NamedSharding, a PartitionSpec (P), or None (default). When specified, the output will be sharded according to the given sharding specification. Primarily used in explicit sharding mode. See the explicit sharding tutorial for more details.

Returns:

A random array of shape (*shape, d) and specified dtype.

bernoulli(p=0.5, shape=None, mode='low', *, out_sharding=None)

Sample Bernoulli random values with given shape and mean.

The values are distributed according to the probability mass function:

\[f(k; p) = p^k(1 - p)^{1 - k}\]

where \(k \in \{0, 1\}\) and \(0 \le p \le 1\).

Parameters:
  • key – a PRNG key used as the random key.

  • p – optional, a float or array of floats for the mean of the random variables. Must be broadcast-compatible with shape. Default 0.5.

  • shape – optional, a tuple of nonnegative integers representing the result shape. Must be broadcast-compatible with p.shape. The default (None) produces a result shape equal to p.shape.

  • mode – optional, “high” or “low” for how many bits to use when sampling. default=’low’. Set to “high” for correct sampling at small values of p. When sampling in float32, bernoulli samples with mode=’low’ produce incorrect results for p < ~1E-7. mode=”high” approximately doubles the cost of sampling.

  • out_sharding

    Optional. Specifies how the output array should be sharded across devices in multi-device computation. Can be a NamedSharding, a PartitionSpec (P), or None (default). When specified, the output will be sharded according to the given sharding specification. Primarily used in explicit sharding mode. See the explicit sharding tutorial for more details.

Returns:

A random array with boolean dtype and shape given by shape if shape is not None, or else p.shape.

beta(a, b, shape=None, dtype=None, *, out_sharding=None)

Sample Beta random values with given shape and float dtype.

The values are distributed according to the probability density function:

\[f(x;a,b) \propto x^{a - 1}(1 - x)^{b - 1}\]

on the domain \(0 \le x \le 1\).

Parameters:
  • key – a PRNG key used as the random key.

  • a – a float or array of floats broadcast-compatible with shape representing the first parameter “alpha”.

  • b – a float or array of floats broadcast-compatible with shape representing the second parameter “beta”.

  • shape – optional, a tuple of nonnegative integers specifying the result shape. Must be broadcast-compatible with a and b. The default (None) produces a result shape by broadcasting a and b.

  • dtype – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).

  • out_sharding

    Optional. Specifies how the output array should be sharded across devices in multi-device computation. Can be a NamedSharding, a PartitionSpec (P), or None (default). When specified, the output will be sharded according to the given sharding specification. Primarily used in explicit sharding mode. See the explicit sharding tutorial for more details.

Returns:

A random array with the specified dtype and shape given by shape if shape is not None, or else by broadcasting a and b.

binomial(n, p, shape=None, dtype=None)

Sample Binomial random values with given shape and float dtype.

The values are returned according to the probability mass function:

\[f(k;n,p) = \binom{n}{k}p^k(1-p)^{n-k}\]

on the domain \(0 < p < 1\), and where \(n\) is a nonnegative integer representing the number of trials and \(p\) is a float representing the probability of success of an individual trial.

Parameters:
  • key – a PRNG key used as the random key.

  • n – a float or array of floats broadcast-compatible with shape representing the number of trials.

  • p – a float or array of floats broadcast-compatible with shape representing the probability of success of an individual trial.

  • shape – optional, a tuple of nonnegative integers specifying the result shape. Must be broadcast-compatible with n and p. The default (None) produces a result shape equal to np.broadcast(n, p).shape.

  • dtype – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).

Returns:

A random array with the specified dtype and with shape given by np.broadcast(n, p).shape.

bits(shape=(), dtype=None, *, out_sharding=None)

Sample uniform bits in the form of unsigned integers.

Parameters:
  • key – a PRNG key used as the random key.

  • shape – optional, a tuple of nonnegative integers representing the result shape. Default ().

  • dtype – optional, an unsigned integer dtype for the returned values (default uint64 if jax_enable_x64 is true, otherwise uint32).

  • out_sharding

    Optional. Specifies how the output array should be sharded across devices in multi-device computation. Can be a NamedSharding, a PartitionSpec (P), or None (default). When specified, the output will be sharded according to the given sharding specification. Primarily used in explicit sharding mode. See the explicit sharding tutorial for more details.

Returns:

A random array with the specified shape and dtype.

broadcast(k)

Broadcasts the RNG stream to a new shape.

This method draws a new key from the stream and broadcasts it to a shape formed by prepending k to the key’s shape. It returns a new RngStream instance that holds the broadcast key.

Parameters:

k – The shape to broadcast to. If an integer is provided, it is treated as a single-element tuple (k,).

categorical(logits, axis=-1, shape=None, replace=True, mode=None, *, out_sharding=None)

Sample random values from categorical distributions.

Sampling with replacement uses the Gumbel max trick. Sampling without replacement uses the Gumbel top-k trick. See [1] for reference.
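The Gumbel-max trick can be sketched directly with `jax.random.gumbel`: adding independent standard Gumbel noise to the logits and taking the argmax yields an index distributed according to `softmax(logits)`. `gumbel_max_sample` is a hypothetical helper for illustration, not necessarily how `categorical` is implemented internally.

```python
import jax
import jax.numpy as jnp


def gumbel_max_sample(key, logits, num_samples):
    """Draw categorical samples with the Gumbel-max trick (illustrative sketch)."""
    # Perturb every logit with independent standard Gumbel noise...
    noise = jax.random.gumbel(key, (num_samples,) + logits.shape)
    # ...and pick the index of the largest perturbed logit.
    return jnp.argmax(logits + noise, axis=-1)


logits = jnp.log(jnp.array([0.2, 0.3, 0.5]))
samples = gumbel_max_sample(jax.random.key(0), logits, 20_000)

# Empirical frequencies; expected to approach softmax(logits) = [0.2, 0.3, 0.5].
freqs = jnp.bincount(samples, length=3) / samples.shape[0]
print(freqs)
```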

Parameters:
  • key – a PRNG key used as the random key.

  • logits – Unnormalized log probabilities of the categorical distribution(s) to sample from, so that softmax(logits, axis) gives the corresponding probabilities.

  • axis – Axis along which logits belong to the same categorical distribution.

  • shape – Optional, a tuple of nonnegative integers representing the result shape. Must be broadcast-compatible with np.delete(logits.shape, axis). The default (None) produces a result shape equal to np.delete(logits.shape, axis).

  • replace – If True (default), perform sampling with replacement. If False, perform sampling without replacement.

  • mode – optional, “high” or “low” for how many bits to use in the gumbel sampler. The default is determined by the use_high_dynamic_range_gumbel config, which defaults to “low”. With mode=”low”, in float32 sampling will be biased for events with probability less than about 1E-7; with mode=”high” this limit is pushed down to about 1E-14. mode=”high” approximately doubles the cost of sampling.

  • out_sharding

    Optional. Specifies how the output array should be sharded across devices in multi-device computation. Can be a NamedSharding, a PartitionSpec (P), or None (default). When specified, the output will be sharded according to the given sharding specification. Primarily used in explicit sharding mode. See the explicit sharding tutorial for more details.

Returns:

A random array with int dtype and shape given by shape if shape is not None, or else np.delete(logits.shape, axis).

References

cauchy(shape=(), dtype=None, *, out_sharding=None)

Sample Cauchy random values with given shape and float dtype.

The values are distributed according to the probability density function:

\[f(x) \propto \frac{1}{x^2 + 1}\]

on the domain \(-\infty < x < \infty\).

Parameters:
  • key – a PRNG key used as the random key.

  • shape – optional, a tuple of nonnegative integers representing the result shape. Default ().

  • dtype – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).

  • out_sharding

    Optional. Specifies how the output array should be sharded across devices in multi-device computation. Can be a NamedSharding, a PartitionSpec (P), or None (default). When specified, the output will be sharded according to the given sharding specification. Primarily used in explicit sharding mode. See the explicit sharding tutorial for more details.

Returns:

A random array with the specified shape and dtype.

chisquare(df, shape=None, dtype=None, *, out_sharding=None)

Sample Chisquare random values with given shape and float dtype.

The values are distributed according to the probability density function:

\[f(x; \nu) \propto x^{\nu/2 - 1}e^{-x/2}\]

on the domain \(0 < x < \infty\), where \(\nu > 0\) represents the degrees of freedom, given by the parameter df.

Parameters:
  • key – a PRNG key used as the random key.

  • df – a float or array of floats broadcast-compatible with shape representing the parameter of the distribution.

  • shape – optional, a tuple of nonnegative integers specifying the result shape. Must be broadcast-compatible with df. The default (None) produces a result shape equal to df.shape.

  • dtype – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).

  • out_sharding

    Optional. Specifies how the output array should be sharded across devices in multi-device computation. Can be a NamedSharding, a PartitionSpec (P), or None (default). When specified, the output will be sharded according to the given sharding specification. Primarily used in explicit sharding mode. See the explicit sharding tutorial for more details.

Returns:

A random array with the specified dtype and with shape given by shape if shape is not None, or else by df.shape.

choice(a, shape=(), replace=True, p=None, axis=0, mode=None)

Generates a random sample from a given array.

Warning

If p has fewer non-zero elements than the requested number of samples, as specified in shape, and replace=False, the output of this function is ill-defined. Please make sure to use appropriate inputs.

Parameters:
  • key – a PRNG key used as the random key.

  • a – array or int. If an ndarray, a random sample is generated from its elements. If an int, the random sample is generated as if a were arange(a).

  • shape – tuple of ints, optional. Output shape. If the given shape is, e.g., (m, n), then m * n samples are drawn. Default is (), in which case a single value is returned.

  • replace – boolean. Whether the sample is with or without replacement. Default is True.

  • p – 1-D array-like, optional. The probabilities associated with each entry in a. If not given, the sample assumes a uniform distribution over all entries in a.

  • axis – int, optional. The axis along which the selection is performed. The default, 0, selects by row.

  • mode – optional, “high” or “low” for how many bits to use in the gumbel sampler when p is None and replace = False. The default is determined by the use_high_dynamic_range_gumbel config, which defaults to “low”. With mode=”low”, in float32 sampling will be biased for choices with probability less than about 1E-7; with mode=”high” this limit is pushed down to about 1E-14. mode=”high” approximately doubles the cost of sampling.

Returns:

An array of shape shape containing samples from a.

dirichlet(alpha, shape=None, dtype=None, *, out_sharding=None)

Sample Dirichlet random values with given shape and float dtype.

The values are distributed according to the probability density function:

\[f(\{x_i\}; \{\alpha_i\}) \propto \prod_{i=1}^k x_i^{\alpha_i - 1}\]

where \(k\) is the dimension, and \(\{x_i\}\) satisfies

\[\sum_{i=1}^k x_i = 1\]

and \(0 \le x_i \le 1\) for all \(x_i\).

Parameters:
  • key – a PRNG key used as the random key.

  • alpha – an array of shape (..., n) used as the concentration parameter of the random variables.

  • shape – optional, a tuple of nonnegative integers specifying the result batch shape; that is, the prefix of the result shape excluding the last element of value n. Must be broadcast-compatible with alpha.shape[:-1]. The default (None) produces a result shape equal to alpha.shape.

  • dtype – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).

  • out_sharding

    Optional. Specifies how the output array should be sharded across devices in multi-device computation. Can be a NamedSharding, a PartitionSpec (P), or None (default). When specified, the output will be sharded according to the given sharding specification. Primarily used in explicit sharding mode. See the explicit sharding tutorial for more details.

Returns:

A random array with the specified dtype and shape given by shape + (alpha.shape[-1],) if shape is not None, or else alpha.shape.

double_sided_maxwell(loc, scale, shape=(), dtype=None)

Sample from a double sided Maxwell distribution.

The values are distributed according to the probability density function:

\[f(x;\mu,\sigma) \propto z^2 e^{-z^2 / 2}\]

where \(z = (x - \mu) / \sigma\), with the center \(\mu\) specified by loc and the scale \(\sigma\) specified by scale.

Parameters:
  • key – a PRNG key.

  • loc – The location parameter of the distribution.

  • scale – The scale parameter of the distribution.

  • shape – The batch shape added to the broadcast shape of the parameters loc and scale.

  • dtype – The type used for samples.

Returns:

A jnp.array of samples.

exponential(shape=(), dtype=None, *, out_sharding=None)

Sample Exponential random values with given shape and float dtype.

The values are distributed according to the probability density function:

\[f(x) = e^{-x}\]

on the domain \(0 \le x < \infty\).

Parameters:
  • key – a PRNG key used as the random key.

  • shape – optional, a tuple of nonnegative integers representing the result shape. Default ().

  • dtype – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).

  • out_sharding

    Optional. Specifies how the output array should be sharded across devices in multi-device computation. Can be a NamedSharding, a PartitionSpec (P), or None (default). When specified, the output will be sharded according to the given sharding specification. Primarily used in explicit sharding mode. See the explicit sharding tutorial for more details.

Returns:

A random array with the specified shape and dtype.

f(dfnum, dfden, shape=None, dtype=None)

Sample F-distribution random values with given shape and float dtype.

The values are distributed according to the probability density function:

\[f(x; \nu_1, \nu_2) \propto x^{\nu_1/2 - 1}\left(1 + \frac{\nu_1}{\nu_2}x\right)^{ -(\nu_1 + \nu_2) / 2}\]

on the domain \(0 < x < \infty\). Here \(\nu_1\) is the degrees of freedom of the numerator (dfnum), and \(\nu_2\) is the degrees of freedom of the denominator (dfden).

Parameters:
  • key – a PRNG key used as the random key.

  • dfnum – a float or array of floats broadcast-compatible with shape representing the numerator’s df of the distribution.

  • dfden – a float or array of floats broadcast-compatible with shape representing the denominator’s df of the distribution.

  • shape – optional, a tuple of nonnegative integers specifying the result shape. Must be broadcast-compatible with dfnum and dfden. The default (None) produces a result shape by broadcasting dfnum.shape and dfden.shape.

  • dtype – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).

Returns:

A random array with the specified dtype and with shape given by shape if shape is not None, or else by broadcasting dfnum and dfden.

gamma(a, shape=None, dtype=None, *, out_sharding=None)

Sample Gamma random values with given shape and float dtype.

The values are distributed according to the probability density function:

\[f(x;a) \propto x^{a - 1} e^{-x}\]

on the domain \(0 \le x < \infty\), with \(a > 0\).

This is the standard gamma density, with a unit scale/rate parameter. Dividing the sample output by the rate is equivalent to sampling from gamma(a, rate), and multiplying the sample output by the scale is equivalent to sampling from gamma(a, scale).

Parameters:
  • key – a PRNG key used as the random key.

  • a – a float or array of floats broadcast-compatible with shape representing the parameter of the distribution.

  • shape – optional, a tuple of nonnegative integers specifying the result shape. Must be broadcast-compatible with a. The default (None) produces a result shape equal to a.shape.

  • dtype – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).

  • out_sharding

    Optional. Specifies how the output array should be sharded across devices in multi-device computation. Can be a NamedSharding, a PartitionSpec (P), or None (default). When specified, the output will be sharded according to the given sharding specification. Primarily used in explicit sharding mode. See the explicit sharding tutorial for more details.

Returns:

A random array with the specified dtype and with shape given by shape if shape is not None, or else by a.shape.

See also

loggamma : sample gamma values in log-space, which can provide improved accuracy for small values of a.

generalized_normal(p, shape=(), dtype=None)

Sample from the generalized normal distribution.

The values are returned according to the probability density function:

\[f(x;p) \propto e^{-|x|^p}\]

on the domain \(-\infty < x < \infty\), where \(p > 0\) is the shape parameter.

Parameters:
  • key – a PRNG key used as the random key.

  • p – a float representing the shape parameter.

  • shape – optional, the batch dimensions of the result. Default ().

  • dtype – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).

Returns:

A random array with the specified shape and dtype.

geometric(p, shape=None, dtype=None)

Sample Geometric random values with given shape and integer dtype.

The values are returned according to the probability mass function:

\[f(k;p) = p(1-p)^{k-1}\]

on the domain \(0 < p < 1\).

Parameters:
  • key – a PRNG key used as the random key.

  • p – a float or array of floats broadcast-compatible with shape representing the probability of success of an individual trial.

  • shape – optional, a tuple of nonnegative integers specifying the result shape. Must be broadcast-compatible with p. The default (None) produces a result shape equal to np.shape(p).

  • dtype – optional, an int dtype for the returned values (default int64 if jax_enable_x64 is true, otherwise int32).

Returns:

A random array with the specified dtype and with shape given by shape if shape is not None, or else by p.shape.

gumbel(shape=(), dtype=None, mode=None, *, out_sharding=None)

Sample Gumbel random values with given shape and float dtype.

The values are distributed according to the probability density function:

\[f(x) = e^{-(x + e^{-x})}\]

Parameters:
  • key – a PRNG key used as the random key.

  • shape – optional, a tuple of nonnegative integers representing the result shape. Default ().

  • dtype – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).

  • mode – optional, “high” or “low” for how many bits to use when sampling. The default is determined by the use_high_dynamic_range_gumbel config, which defaults to “low”. When drawing float32 samples, with mode=”low” the uniform resolution is such that the largest possible gumbel logit is ~16; with mode=”high” this is increased to ~32, at approximately double the computational cost.

  • out_sharding

    Optional. Specifies how the output array should be sharded across devices in multi-device computation. Can be a NamedSharding, a PartitionSpec (P), or None (default). When specified, the output will be sharded according to the given sharding specification. Primarily used in explicit sharding mode. See the explicit sharding tutorial for more details.

Returns:

A random array with the specified shape and dtype.

laplace(shape=(), dtype=None, *, out_sharding=None)

Sample Laplace random values with given shape and float dtype.

The values are distributed according to the probability density function:

\[f(x) = \frac{1}{2}e^{-|x|}\]

Parameters:
  • key – a PRNG key used as the random key.

  • shape – optional, a tuple of nonnegative integers representing the result shape. Default ().

  • dtype – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).

  • out_sharding

    Optional. Specifies how the output array should be sharded across devices in multi-device computation. Can be a NamedSharding, a PartitionSpec (P), or None (default). When specified, the output will be sharded according to the given sharding specification. Primarily used in explicit sharding mode. See the explicit sharding tutorial for more details.

Returns:

A random array with the specified shape and dtype.

loggamma(a, shape=None, dtype=None, *, out_sharding=None)

Sample log-gamma random values with given shape and float dtype.

This function is implemented such that the following will hold for a dtype-appropriate tolerance:

np.testing.assert_allclose(jnp.exp(loggamma(*args)), gamma(*args), rtol=rtol)

The benefit of log-gamma is that for samples very close to zero (which occur frequently when a << 1) sampling in log space provides better precision.
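Because every call to an RngStream sampling method draws a new key from the stream, two calls such as `gamma` followed by `loggamma` receive different keys, and the identity above does not apply to them. To check it, the sketch below calls `jax.random` directly with one shared key; the tolerance `rtol=1e-3` is an assumption chosen for float32, not a documented bound.

```python
import jax
import jax.numpy as jnp

# One shared key: the identity only holds when both samplers see the same key.
key = jax.random.key(0)
a = 0.5

g = jax.random.gamma(key, a, shape=(1000,))
log_g = jax.random.loggamma(key, a, shape=(1000,))

# exp(loggamma) should match gamma up to floating-point tolerance.
matches = bool(jnp.allclose(jnp.exp(log_g), g, rtol=1e-3))
print(matches)
```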

Parameters:
  • key – a PRNG key used as the random key.

  • a – a float or array of floats broadcast-compatible with shape representing the parameter of the distribution.

  • shape – optional, a tuple of nonnegative integers specifying the result shape. Must be broadcast-compatible with a. The default (None) produces a result shape equal to a.shape.

  • dtype – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).

  • out_sharding

    Optional. Specifies how the output array should be sharded across devices in multi-device computation. Can be a NamedSharding, a PartitionSpec (P), or None (default). When specified, the output will be sharded according to the given sharding specification. Primarily used in explicit sharding mode. See the explicit sharding tutorial for more details.

Returns:

A random array with the specified dtype and with shape given by shape if shape is not None, or else by a.shape.

See also

gamma : standard gamma sampler.

logistic(shape=(), dtype=None, *, out_sharding=None)

Sample logistic random values with given shape and float dtype.

The values are distributed according to the probability density function:

\[f(x) = \frac{e^{-x}}{(1 + e^{-x})^2}\]

Parameters:
  • key – a PRNG key used as the random key.

  • shape – optional, a tuple of nonnegative integers representing the result shape. Default ().

  • dtype – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).

  • out_sharding

    Optional. Specifies how the output array should be sharded across devices in multi-device computation. Can be a NamedSharding, a PartitionSpec (P), or None (default). When specified, the output will be sharded according to the given sharding specification. Primarily used in explicit sharding mode. See the explicit sharding tutorial for more details.

Returns:

A random array with the specified shape and dtype.

lognormal(sigma=np.float32(1.0), shape=None, dtype=None)

Sample lognormal random values with given shape and float dtype.

The values are distributed according to the probability density function:

\[f(x) = \frac{1}{x\sqrt{2\pi\sigma^2}}\exp\left(-\frac{(\log x)^2}{2\sigma^2}\right)\]

on the domain \(x > 0\).

Parameters:
  • key – a PRNG key used as the random key.

  • sigma – a float or array of floats broadcast-compatible with shape representing the standard deviation of the underlying normal distribution. Default 1.

  • shape – optional, a tuple of nonnegative integers specifying the result shape. The default (None) produces a result shape equal to ().

  • dtype – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).

Returns:

A random array with the specified dtype and with shape given by shape.

maxwell(shape=(), dtype=None)

Sample from a one sided Maxwell distribution.

The values are distributed according to the probability density function:

\[f(x) \propto x^2 e^{-x^2 / 2}\]

on the domain \(0 \le x < \infty\).

Parameters:
  • key – a PRNG key.

  • shape – The shape of the returned samples.

  • dtype – The type used for samples.

Returns:

A jnp.array of samples, of shape shape.

multinomial(n, p, *, shape=None, dtype=None, unroll=1)

Sample from a multinomial distribution.

The probability mass function is

\[f(x;n,p) = \frac{n!}{x_1! \ldots x_k!} p_1^{x_1} \ldots p_k^{x_k}\]

Parameters:
  • key – PRNG key.

  • n – number of trials. Should have shape broadcastable to p.shape[:-1].

  • p – probability of each outcome, with outcomes along the last axis.

  • shape – optional, a tuple of nonnegative integers specifying the result batch shape, that is, the prefix of the result shape excluding the last axis. Must be broadcast-compatible with p.shape[:-1]. The default (None) produces a result shape equal to p.shape.

  • dtype – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).

  • unroll – optional, unroll parameter passed to jax.lax.scan() inside the implementation of this function.

Returns:

An array of counts for each outcome with the specified dtype and with shape p.shape if shape is None, otherwise shape + (p.shape[-1],).

multivariate_normal(mean, cov, shape=None, dtype=None, method='cholesky', *, out_sharding=None)

Sample multivariate normal random values with given mean and covariance.

The values are returned according to the probability density function:

\[f(x;\mu, \Sigma) = (2\pi)^{-k/2} \det(\Sigma)^{-1/2}e^{-\frac{1}{2}(x - \mu)^T \Sigma^{-1} (x - \mu)}\]

where \(k\) is the dimension, \(\mu\) is the mean (given by mean) and \(\Sigma\) is the covariance matrix (given by cov).

Parameters:
  • key – a PRNG key used as the random key.

  • mean – a mean vector of shape (..., n).

  • cov – a positive definite covariance matrix of shape (..., n, n). The batch shape ... must be broadcast-compatible with that of mean.

  • shape – optional, a tuple of nonnegative integers specifying the result batch shape; that is, the prefix of the result shape excluding the last axis. Must be broadcast-compatible with mean.shape[:-1] and cov.shape[:-2]. The default (None) produces a result batch shape by broadcasting together the batch shapes of mean and cov.

  • dtype – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).

  • method – optional, a method to compute the factor of cov. Must be one of ‘svd’, ‘eigh’, or ‘cholesky’. Default ‘cholesky’. For singular covariance matrices, use ‘svd’ or ‘eigh’.

  • out_sharding

    Optional. Specifies how the output array should be sharded across devices in multi-device computation. Can be a NamedSharding, a PartitionSpec (P), or None (default). When specified, the output will be sharded according to the given sharding specification. Primarily used in explicit sharding mode. See the explicit sharding tutorial for more details.

Returns:

A random array with the specified dtype and shape given by shape + mean.shape[-1:] if shape is not None, or else broadcast_shapes(mean.shape[:-1], cov.shape[:-2]) + mean.shape[-1:].

normal(shape=(), dtype=None, *, out_sharding=None)

Sample standard normal random values with given shape and float dtype.

The values are returned according to the probability density function:

\[f(x) = \frac{1}{\sqrt{2\pi}}e^{-x^2/2}\]

on the domain \(-\infty < x < \infty\).

Parameters:
  • key – a PRNG key used as the random key.

  • shape – optional, a tuple of nonnegative integers representing the result shape. Default ().

  • dtype – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).

  • out_sharding

    Optional. Specifies how the output array should be sharded across devices in multi-device computation. Can be a NamedSharding, a PartitionSpec (P), or None (default). When specified, the output will be sharded according to the given sharding specification. Primarily used in explicit sharding mode. See the explicit sharding tutorial for more details.

Returns:

A random array with the specified shape and dtype.

orthogonal(n, shape=(), dtype=None, m=None, *, out_sharding=None)

Sample uniformly from the orthogonal group O(n).

If the dtype is complex, sample uniformly from the unitary group U(n).

For unequal rows and columns, this samples a semi-orthogonal matrix instead. That is, if \(A\) is the resulting matrix and \(A^*\) is its conjugate transpose, then:

  • If \(n \leq m\), the rows are mutually orthonormal: \(A A^* = I_n\).

  • If \(m \leq n\), the columns are mutually orthonormal: \(A^* A = I_m\).

Parameters:
  • key – a PRNG key used as the random key.

  • n – an integer indicating the number of rows.

  • shape – optional, the batch dimensions of the result. Default ().

  • dtype – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).

  • m – an integer indicating the number of columns. Defaults to n.

  • out_sharding

    Optional. Specifies how the output array should be sharded across devices in multi-device computation. Can be a NamedSharding, a PartitionSpec (P), or None (default). When specified, the output will be sharded according to the given sharding specification. Primarily used in explicit sharding mode. See the explicit sharding tutorial for more details.

Returns:

A random array of shape (*shape, n, m) and specified dtype.
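The square case can be verified numerically. This sketch assumes the method mirrors jax.random.orthogonal and calls it directly. It checks that each matrix in a batch satisfies \(Q^T Q = I\) up to float32 round-off.

```python
import jax
import jax.numpy as jnp

key = jax.random.key(0)
# Batch of two 3x3 orthogonal matrices: result shape is (*shape, n, n).
q = jax.random.orthogonal(key, 3, shape=(2,))

# Batched Q^T Q: sum over the row index i for each batch element b.
gram = jnp.einsum("bij,bik->bjk", q, q)
is_orthogonal = bool(jnp.allclose(gram, jnp.eye(3), atol=1e-5))
```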


pareto(b, shape=None, dtype=None, *, out_sharding=None)#

Sample Pareto random values with given shape and float dtype.

The values are distributed according to the probability density function:

\[f(x; b) = b / x^{b + 1}\]

on the domain \(1 \le x < \infty\) with \(b > 0\).

Parameters:
  • key – a PRNG key used as the random key.

  • b – a float or array of floats broadcast-compatible with shape representing the parameter of the distribution.

  • shape – optional, a tuple of nonnegative integers specifying the result shape. Must be broadcast-compatible with b. The default (None) produces a result shape equal to b.shape.

  • dtype – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).

  • out_sharding

    Optional. Specifies how the output array should be sharded across devices in multi-device computation. Can be a NamedSharding, a PartitionSpec (P), or None (default). When specified, the output will be sharded according to the given sharding specification. Primarily used in explicit sharding mode. See the explicit sharding tutorial for more details.

Returns:

A random array with the specified dtype and with shape given by shape if shape is not None, or else by b.shape.
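A short sketch of broadcasting b against an explicit shape. It assumes the method forwards to jax.random.pareto, which is called here with an explicit key. Every sample should lie in the support \(x \ge 1\).

```python
import jax
import jax.numpy as jnp

key = jax.random.key(0)
b = jnp.array([1.0, 3.0])  # one shape parameter per column

# shape=(1000, 2) is broadcast-compatible with b.shape == (2,).
x = jax.random.pareto(key, b, shape=(1000, 2))
min_value = float(x.min())
```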

permutation(x, axis=0, independent=False, *, out_sharding=None)#

Returns a randomly permuted array or range.

Parameters:
  • key – a PRNG key used as the random key.

  • x – int or array. If x is an integer, randomly shuffle np.arange(x). If x is an array, randomly shuffle its elements.

  • axis – int, optional. The axis which x is shuffled along. Default is 0.

  • independent – bool, optional. If set to True, each individual vector along the given axis is shuffled independently. Default is False.

  • out_sharding

    Optional. Specifies how the output array should be sharded across devices in multi-device computation. Can be a NamedSharding, a PartitionSpec (P), or None (default). When specified, the output will be sharded according to the given sharding specification. Primarily used in explicit sharding mode. See the explicit sharding tutorial for more details.

Returns:

A shuffled version of x, or of the range np.arange(x) if x is an integer.
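The two input forms and the independent flag can be sketched with jax.random.permutation, assuming this method forwards to it. An integer input yields a shuffled arange. With axis=1 and independent=True, each row is shuffled on its own, so every row keeps its original elements.

```python
import jax
import jax.numpy as jnp

key = jax.random.key(0)

# Integer input: a shuffled copy of arange(5).
p = jax.random.permutation(key, 5)

# Array input: each vector along axis=1 (each row) is shuffled independently.
x = jnp.arange(12).reshape(3, 4)
rows = jax.random.permutation(key, x, axis=1, independent=True)
```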

poisson(lam, shape=None, dtype=None, *, out_sharding=None)#

Sample Poisson random values with given shape and integer dtype.

The values are distributed according to the probability mass function:

\[f(k; \lambda) = \frac{\lambda^k e^{-\lambda}}{k!}\]

Where k is a non-negative integer and \(\lambda > 0\).

Parameters:
  • key – a PRNG key used as the random key.

  • lam – rate parameter (mean of the distribution), must be >= 0. Must be broadcast-compatible with shape.

  • shape – optional, a tuple of nonnegative integers representing the result shape. Default (None) produces a result shape equal to lam.shape.

  • dtype – optional, an integer dtype for the returned values (default int64 if jax_enable_x64 is true, otherwise int32).

  • out_sharding

    Optional. Specifies how the output array should be sharded across devices in multi-device computation. Can be a NamedSharding, a PartitionSpec (P), or None (default). When specified, the output will be sharded according to the given sharding specification. Primarily used in explicit sharding mode. See the explicit sharding tutorial for more details.

Returns:

A random array with the specified dtype and with shape given by shape if shape is not None, or else by lam.shape.
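A sketch with one rate per column, assuming the method forwards to jax.random.poisson. Because lam is the mean of the distribution, each column's sample mean should be close to its rate. The samples are non-negative integers.

```python
import jax
import jax.numpy as jnp

key = jax.random.key(0)
lam = jnp.array([0.5, 4.0, 20.0])  # one rate per column

# shape=(5000, 3) is broadcast-compatible with lam.shape == (3,).
k = jax.random.poisson(key, lam, shape=(5000, 3))
col_means = k.mean(axis=0)  # should approximate lam
```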

rademacher(shape=(), dtype=None, *, out_sharding=None)#

Sample from a Rademacher distribution.

The values are distributed according to the probability mass function:

\[f(k) = \frac{1}{2}(\delta(k - 1) + \delta(k + 1))\]

on the domain \(k \in \{-1, 1\}\), where \(\delta(x)\) is the Kronecker delta (1 at \(x = 0\), 0 elsewhere).

Parameters:
  • key – a PRNG key.

  • shape – The shape of the returned samples. Default ().

  • dtype – The type used for samples.

  • out_sharding

    Optional. Specifies how the output array should be sharded across devices in multi-device computation. Can be a NamedSharding, a PartitionSpec (P), or None (default). When specified, the output will be sharded according to the given sharding specification. Primarily used in explicit sharding mode. See the explicit sharding tutorial for more details.

Returns:

A jnp.array of samples of shape shape. Each element in the output has a 50% chance of being 1 or -1.
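A quick check of the two-point support, assuming the method forwards to jax.random.rademacher. Only -1 and 1 should appear, each roughly half the time.

```python
import jax
import jax.numpy as jnp

key = jax.random.key(0)
s = jax.random.rademacher(key, shape=(1000,), dtype=jnp.int32)

values = set(s.tolist())          # expected to be a subset of {-1, 1}
frac_pos = float((s == 1).mean())  # expected to be near 0.5
```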

randint(shape, minval, maxval, dtype=None, *, out_sharding=None)#

Sample uniform random values in [minval, maxval) with given shape/dtype.

Parameters:
  • key – a PRNG key used as the random key.

  • shape – a tuple of nonnegative integers representing the shape.

  • minval – int or array of ints broadcast-compatible with shape, a minimum (inclusive) value for the range.

  • maxval – int or array of ints broadcast-compatible with shape, a maximum (exclusive) value for the range.

  • dtype – optional, an int dtype for the returned values (default int64 if jax_enable_x64 is true, otherwise int32).

  • out_sharding

    Optional. Specifies how the output array should be sharded across devices in multi-device computation. Can be a NamedSharding, a PartitionSpec (P), or None (default). When specified, the output will be sharded according to the given sharding specification. Primarily used in explicit sharding mode. See the explicit sharding tutorial for more details.

Returns:

A random array with the specified shape and dtype.
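The half-open range convention is easiest to see with a die roll. This sketch assumes the method forwards to jax.random.randint. Because maxval is exclusive, six-sided die values need minval=1 and maxval=7.

```python
import jax
import jax.numpy as jnp

key = jax.random.key(0)
# minval is inclusive and maxval is exclusive, so this samples 1..6.
rolls = jax.random.randint(key, shape=(1000,), minval=1, maxval=7)
```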

Note

randint() uses a modulus-based computation that is known to produce slightly biased values in some cases. The magnitude of the bias scales as (maxval - minval) * ((2 ** nbits ) % (maxval - minval)) / 2 ** nbits: in words, the bias goes to zero when (maxval - minval) is a power of 2, and otherwise the bias will be small whenever (maxval - minval) is small compared to the range of the sampled type.
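The bias scaling expression in this note can be evaluated directly. The helper name randint_bias_scale is hypothetical and only transcribes the formula. It shows that a power-of-two span gives zero, and that a small span gives a far smaller value than a span of 1000.

```python
def randint_bias_scale(minval: int, maxval: int, nbits: int = 32) -> float:
    """Hypothetical helper: (maxval - minval) * ((2**nbits) % (maxval - minval)) / 2**nbits."""
    span = maxval - minval
    return span * ((2 ** nbits) % span) / 2 ** nbits

# A power-of-two span divides 2**32 exactly, so the remainder (and the bias) is zero.
power_of_two = randint_bias_scale(0, 16)
# 2**32 = 3 * 1431655765 + 1, so a span of 3 leaves remainder 1.
small_span = randint_bias_scale(0, 3)
# 2**32 % 1000 == 296, so a span of 1000 gives a noticeably larger value.
span_1000 = randint_bias_scale(0, 1000)
```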

To reduce this bias, 8-bit and 16-bit values will always be sampled at 32-bit and then cast to the requested type. If you find yourself sampling values for which this bias may be problematic, a possible alternative is to sample via uniform:

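import jax

def randint_via_uniform(key, shape, minval, maxval, dtype):
  # Shift the float range by -0.5 so that rounding maps it (approximately) onto minval..maxval-1.
  u = jax.random.uniform(key, shape, minval=minval - 0.5, maxval=maxval - 0.5)
  return u.round().astype(dtype)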

But keep in mind this method has its own biases due to floating point rounding errors, and in particular there may be some integers in the range [minval, maxval) that are impossible to produce with this approach.

rayleigh(scale, shape=None, dtype=None, *, out_sharding=None)#

Sample Rayleigh random values with given shape and float dtype.

The values are returned according to the probability density function:

\[f(x;\sigma) \propto xe^{-x^2/(2\sigma^2)}\]

on the domain \(0 \le x < \infty\), where \(\sigma > 0\) is the scale parameter of the distribution.

Parameters:
  • key – a PRNG key used as the random key.

  • scale – a float or array of floats broadcast-compatible with shape representing the parameter of the distribution.

  • shape – optional, a tuple of nonnegative integers specifying the result shape. Must be broadcast-compatible with scale. The default (None) produces a result shape equal to scale.shape.

  • dtype – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).

  • out_sharding

    Optional. Specifies how the output array should be sharded across devices in multi-device computation. Can be a NamedSharding, a PartitionSpec (P), or None (default). When specified, the output will be sharded according to the given sharding specification. Primarily used in explicit sharding mode. See the explicit sharding tutorial for more details.

Returns:

A random array with the specified dtype and with shape given by shape if shape is not None, or else by scale.shape.
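A sketch that checks the samples against a known statistic, assuming the method forwards to jax.random.rayleigh. The median of a Rayleigh distribution is \(\sigma\sqrt{2\ln 2}\). The median is used rather than the mean because it is insensitive to rare extreme draws.

```python
import jax
import jax.numpy as jnp

key = jax.random.key(0)
scale = jnp.array([0.5, 2.0])  # one sigma per column

x = jax.random.rayleigh(key, scale, shape=(20_000, 2))

# Theoretical median of Rayleigh(sigma) is sigma * sqrt(2 ln 2).
expected_median = scale * jnp.sqrt(2.0 * jnp.log(2.0))
sample_median = jnp.median(x, axis=0)
```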

t(df, shape=(), dtype=None, *, out_sharding=None)#

Sample Student’s t random values with given shape and float dtype.

The values are distributed according to the probability density function:

\[f(t; \nu) \propto \left(1 + \frac{t^2}{\nu}\right)^{-(\nu + 1)/2}\]

Where \(\nu > 0\) is the degrees of freedom, given by the parameter df.

Parameters:
  • key – a PRNG key used as the random key.

  • df – a float or array of floats broadcast-compatible with shape representing the degrees of freedom parameter of the distribution.

  • shape – optional, a tuple of nonnegative integers specifying the result shape. Must be broadcast-compatible with df. The default (None) produces a result shape equal to df.shape.

  • dtype – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).

  • out_sharding

    Optional. Specifies how the output array should be sharded across devices in multi-device computation. Can be a NamedSharding, a PartitionSpec (P), or None (default). When specified, the output will be sharded according to the given sharding specification. Primarily used in explicit sharding mode. See the explicit sharding tutorial for more details.

Returns:

A random array with the specified dtype and with shape given by shape if shape is not None, or else by df.shape.
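A minimal sketch with a scalar df, assuming the method forwards to jax.random.t. A scalar df is broadcast-compatible with any explicit shape. The distribution is symmetric about zero, so the sample median should be near 0.

```python
import jax
import jax.numpy as jnp

key = jax.random.key(0)
# Scalar degrees of freedom; heavier tails than a standard normal.
x = jax.random.t(key, 3.0, shape=(20_000,))
sample_median = float(jnp.median(x))
```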

triangular(left, mode, right, shape=None, dtype=None)#

Sample Triangular random values with given shape and float dtype.

The values are returned according to the probability density function:

\[\begin{split}f(x; a, b, c) = \frac{2}{c-a} \left\{ \begin{array}{ll} \frac{x-a}{b-a} & a \leq x \leq b \\ \frac{c-x}{c-b} & b \leq x \leq c \end{array} \right.\end{split}\]

on the domain \(a \leq x \leq c\).

Parameters:
  • key – a PRNG key used as the random key.

  • left – a float or array of floats broadcast-compatible with shape representing the lower limit parameter of the distribution.

  • mode – a float or array of floats broadcast-compatible with shape representing the peak value parameter of the distribution, value must fulfill the condition left <= mode <= right.

  • right – a float or array of floats broadcast-compatible with shape representing the upper limit parameter of the distribution, must be larger than left.

  • shape – optional, a tuple of nonnegative integers specifying the result shape. Must be broadcast-compatible with left, mode and right. The default (None) produces a result shape equal to left.shape, mode.shape and right.shape.

  • dtype – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).

Returns:

A random array with the specified dtype and with shape given by shape if shape is not None, or else by left.shape, mode.shape and right.shape.
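A sketch with left=0, mode=1, right=4, assuming the method forwards to jax.random.triangular. Every sample should fall in [left, right]. The sample mean should approach (left + mode + right) / 3 = 5/3.

```python
import jax
import jax.numpy as jnp

key = jax.random.key(0)
x = jax.random.triangular(key, left=0.0, mode=1.0, right=4.0, shape=(20_000,))

# Mean of a triangular distribution is (left + mode + right) / 3.
sample_mean = float(x.mean())
```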

truncated_normal(lower, upper, shape=None, dtype=None, *, out_sharding=None)#

Sample truncated standard normal random values with given shape and dtype.

The values are returned according to the probability density function:

\[f(x) \propto e^{-x^2/2}\]

on the domain \(\rm{lower} < x < \rm{upper}\).

Parameters:
  • key – a PRNG key used as the random key.

  • lower – a float or array of floats representing the lower bound for truncation. Must be broadcast-compatible with upper.

  • upper – a float or array of floats representing the upper bound for truncation. Must be broadcast-compatible with lower.

  • shape – optional, a tuple of nonnegative integers specifying the result shape. Must be broadcast-compatible with lower and upper. The default (None) produces a result shape by broadcasting lower and upper.

  • dtype – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).

  • out_sharding

    Optional. Specifies how the output array should be sharded across devices in multi-device computation. Can be a NamedSharding, a PartitionSpec (P), or None (default). When specified, the output will be sharded according to the given sharding specification. Primarily used in explicit sharding mode. See the explicit sharding tutorial for more details.

Returns:

A random array with the specified dtype and shape given by shape if shape is not None, or else by broadcasting lower and upper. Returns values in the open interval (lower, upper).
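A sketch with per-column bounds, assuming the method forwards to jax.random.truncated_normal. lower and upper broadcast against the requested shape. Per the Returns note, every value lies strictly inside (lower, upper).

```python
import jax
import jax.numpy as jnp

key = jax.random.key(0)
lower = jnp.array([-1.0, 0.0])
upper = jnp.array([1.0, 3.0])

# Column 0 is truncated to (-1, 1) and column 1 to (0, 3).
x = jax.random.truncated_normal(key, lower, upper, shape=(1000, 2))
```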

uniform(shape=(), dtype=None, minval=0.0, maxval=1.0, *, out_sharding=None)#

Sample uniform random values in [minval, maxval) with given shape/dtype.

Parameters:
  • key – a PRNG key used as the random key.

  • shape – optional, a tuple of nonnegative integers representing the result shape. Default ().

  • dtype – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).

  • minval – optional, a minimum (inclusive) value broadcast-compatible with shape for the range (default 0).

  • maxval – optional, a maximum (exclusive) value broadcast-compatible with shape for the range (default 1).

  • out_sharding

    Optional. Specifies how the output array should be sharded across devices in multi-device computation. Can be a NamedSharding, a PartitionSpec (P), or None (default). When specified, the output will be sharded according to the given sharding specification. Primarily used in explicit sharding mode. See the explicit sharding tutorial for more details.

Returns:

A random array with the specified shape and dtype.
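A sketch of array-valued bounds, assuming the method forwards to jax.random.uniform. minval and maxval broadcast against shape, so each column gets its own half-open range [minval, maxval).

```python
import jax
import jax.numpy as jnp

key = jax.random.key(0)
minval = jnp.array([0.0, -5.0])
maxval = jnp.array([1.0, 5.0])

# Column 0 samples [0, 1) and column 1 samples [-5, 5).
x = jax.random.uniform(key, shape=(1000, 2), minval=minval, maxval=maxval)
```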

wald(mean, shape=None, dtype=None)#

Sample Wald random values with given shape and float dtype.

The values are returned according to the probability density function:

\[f(x;\mu) = \frac{1}{\sqrt{2\pi x^3}} \exp\left(-\frac{(x - \mu)^2}{2\mu^2 x}\right)\]

on the domain \(0 < x < \infty\), where \(\mu > 0\) is the location parameter of the distribution.

Parameters:
  • key – a PRNG key used as the random key.

  • mean – a float or array of floats broadcast-compatible with shape representing the mean parameter of the distribution.

  • shape – optional, a tuple of nonnegative integers specifying the result shape. Must be broadcast-compatible with mean. The default (None) produces a result shape equal to np.shape(mean).

  • dtype – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).

Returns:

A random array with the specified dtype and with shape given by shape if shape is not None, or else by mean.shape.
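A sketch assuming the method forwards to jax.random.wald. For the density above, the distribution mean equals the mean parameter \(\mu\). Each column's sample mean should therefore approach its entry of mean, and all samples should be non-negative.

```python
import jax
import jax.numpy as jnp

key = jax.random.key(0)
mean = jnp.array([1.0, 2.0])  # one mu per column

x = jax.random.wald(key, mean, shape=(50_000, 2))
col_means = x.mean(axis=0)  # should approximate mean
```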

weibull_min(scale, concentration, shape=(), dtype=None)#

Sample from a Weibull distribution.

The values are distributed according to the probability density function:

\[f(x;\sigma,c) \propto x^{c - 1} \exp(-(x / \sigma)^c)\]

on the domain \(0 < x < \infty\), where \(c > 0\) is the concentration parameter, and \(\sigma > 0\) is the scale parameter.

Parameters:
  • key – a PRNG key.

  • scale – The scale parameter of the distribution.

  • concentration – The concentration parameter of the distribution.

  • shape – The sample shape; it is broadcast together with the shapes of scale and concentration to give the shape of the result.

  • dtype – The type used for samples.

Returns:

A jnp.array of samples.
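A sketch of broadcasting shape against an array-valued scale, assuming the method forwards to jax.random.weibull_min. The samples lie on the non-negative half-line.

```python
import jax
import jax.numpy as jnp

key = jax.random.key(0)
scale = jnp.array([1.0, 2.0])  # one sigma per column
concentration = 1.5            # scalar c, shared by both columns

# shape=(1000, 2) broadcasts with scale.shape == (2,).
x = jax.random.weibull_min(key, scale, concentration, shape=(1000, 2))
```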

flax.nnx.split_rngs(node=<flax.typing.Missing object>, /, *, splits, only=Ellipsis, squeeze=False, graph=None, graph_updates=None)[source]#

Splits the (nested) Rng states of the given node.

Parameters:
  • node – the base node containing the rng states to split.

  • splits – an integer or tuple of integers specifying the shape of the split rng keys.

  • only – a Filter selecting which rng states to split.

  • graph – If True (default), uses graph-mode which supports the full NNX feature set including shared references. If False, uses tree-mode which treats Modules as regular JAX pytrees, avoiding the overhead of the graph protocol.

  • graph_updates – If True, applies the splits in-place on the node. If False, returns a new node with split rng states.

Returns:

If node is provided and both graph=True and graph_updates=True, splits the rng states of node and returns a SplitBackups iterable, allowing you to restore the previous state with nnx.restore_rngs.

If node is not provided, returns a decorator. The decorated function runs split_rngs on its first argument before processing it.

If node is provided and graph_updates=False, returns a copy of node with split rng states. The original node is left unmodified.
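The splits argument gives the shape of the resulting key array. As an analogy only, not a claim about how split_rngs is implemented, jax.random.split produces key arrays with the same shapes that the examples report for splits=5 and splits=(2, 5).

```python
import jax

key = jax.random.key(0)
flat = jax.random.split(key, 5)       # key array of shape (5,)
grid = jax.random.split(key, (2, 5))  # key array of shape (2, 5)
```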

Example:

>>> from flax import nnx
...
>>> rngs = nnx.Rngs(params=0, dropout=1)
>>> result = nnx.split_rngs(rngs, splits=5, graph_updates=False)
>>> result.params.key.shape, result.dropout.key.shape
((5,), (5,))

>>> rngs = nnx.Rngs(params=0, dropout=1)
>>> result = nnx.split_rngs(rngs, splits=(2, 5), graph_updates=False)
>>> result.params.key.shape, result.dropout.key.shape
((2, 5), (2, 5))


>>> rngs = nnx.Rngs(params=0, dropout=1)
>>> result = nnx.split_rngs(rngs, splits=5, only='params', graph_updates=False)
>>> result.params.key.shape, result.dropout.key.shape
((5,), ())

Once split, random state can be used with transforms like nnx.vmap():

>>> class Model(nnx.Module):
...   def __init__(self, rngs):
...     self.linear = nnx.Linear(2, 3, rngs=rngs)
...     self.dropout = nnx.Dropout(0.5, rngs=rngs)
...
>>> rngs = nnx.Rngs(params=0, dropout=1)
>>> in_axes = nnx.prefix(rngs, {'params': 0, ...: None}, graph=False)
>>> out_axes = nnx.prefix(Model(rngs), {nnx.Param: 0, ...: None}, graph=False)
...
>>> @nnx.vmap(in_axes=(in_axes,), out_axes=out_axes, graph=False)
... def create_model(rngs):
...   return Model(rngs)
...
>>> batch = nnx.split_rngs(rngs, splits=5, only='params', graph_updates=False)
>>> model = create_model(batch)
...
>>> model.dropout.rngs.key.shape
()

When split_rngs is not given a node argument, it acts as a decorator:

>>> rngs = nnx.Rngs(params=0, dropout=1)
>>> in_axes = nnx.prefix(rngs, {'params': 0, ...: None}, graph=False)
>>> out_axes = nnx.prefix(Model(rngs), {nnx.Param: 0, ...: None}, graph=False)
...
>>> @nnx.split_rngs(splits=5, only='params')
... @nnx.vmap(in_axes=(in_axes,), out_axes=out_axes, graph=False)
... def auto_split_create_model(rngs):
...   return Model(rngs)
...
>>> rngs = nnx.Rngs(params=0, dropout=1)
>>> model = auto_split_create_model(rngs)

When graph=True and graph_updates=True, split_rngs returns a SplitBackups object that can be used to restore the original unsplit rng states with nnx.restore_rngs(). This is useful when you only want to split the rng states temporarily:

>>> rngs = nnx.Rngs(params=0, dropout=1)
...
>>> backups = nnx.split_rngs(rngs, splits=5, only='params', graph=True, graph_updates=True)
>>> nnx.restore_rngs(backups)
...
>>> rngs.params.key.shape
()

SplitBackups can also be used as a context manager to automatically restore the rng states when exiting the context:

>>> rngs = nnx.Rngs(params=0, dropout=1)
...
>>> with nnx.split_rngs(rngs, splits=5, only='params', graph=True, graph_updates=True):
...   model = create_model(rngs)
...
>>> model.dropout.rngs.key.shape
()
flax.nnx.fork_rngs(node=<flax.typing.Missing object>, /, *, split=None, graph=None, graph_updates=None)[source]#

Forks the (nested) Rng states of the given node.

Parameters:
  • node – the base node containing the rng states to fork.

  • graph – If True (default), uses graph-mode which supports the full NNX feature set including shared references. If False, uses tree-mode which treats Modules as regular JAX pytrees, avoiding the overhead of the graph protocol.

  • graph_updates – If True, applies the forks in-place on the node. If False, returns a new node with forked rng states.

Returns:

If node is provided and both graph=True and graph_updates=True, forks the rng states of node and returns a SplitBackups iterable, allowing you to restore the previous state with nnx.restore_rngs.

If node is not provided but graph=True and graph_updates=True, returns a decorator that forks the rng states of the inputs to the decorated function.

If node is provided and graph_updates=False, returns a copy of node with forked rng states. The original node is left unmodified.

Example:

>>> from flax import nnx
...
>>> rngs = nnx.Rngs(params=0, dropout=1)
>>> _ = nnx.fork_rngs(rngs)

fork_rngs with graph_updates=True returns a SplitBackups object that can be used to restore the original unforked rng states with nnx.restore_rngs(). This is useful when you only want to fork the rng states temporarily:

>>> rngs = nnx.Rngs(params=0, dropout=1)
...
>>> backups = nnx.fork_rngs(rngs, graph=True, graph_updates=True)
>>> model = nnx.Linear(2, 3, rngs=rngs)
>>> nnx.restore_rngs(backups)
...

SplitBackups can also be used as a context manager to automatically restore the rng states when exiting the context:

>>> rngs = nnx.Rngs(params=0, dropout=1)
...
>>> with nnx.fork_rngs(rngs, graph=True, graph_updates=True):
...   model = nnx.Linear(2, 3, rngs=rngs)
flax.nnx.reseed(node, /, *, graph=None, policy='scalars_only', **stream_keys)[source]#

Update the keys of the specified RNG streams with new keys.

Parameters:
  • node – the node to reseed the RNG streams in.

  • graph – If True (default), uses graph-mode which supports the full NNX feature set including shared references. If False, uses tree-mode which treats Modules as regular JAX pytrees, avoiding the overhead of the graph protocol.

  • policy – defines how the new scalar key for each RngStream is used to reseed the stream. If 'scalars_only' is given (the default), an error is raised if the target stream key is not a scalar. If 'match_shape' is given, the new scalar key is split to match the shape of the target stream key. A callable of the form (path, scalar_key, target_shape) -> new_key can be passed to define a custom reseeding policy.

  • **stream_keys – a mapping of stream names to new keys. The keys can be either integers or jax.random.key.
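A custom policy is a callable of the form (path, scalar_key, target_shape) -> new_key. The sketch below is hypothetical: the name match_shape_policy is illustrative, and it only approximates the described 'match_shape' behavior by splitting the scalar key into the target shape. The structure of path is not specified here, so the sketch ignores it.

```python
import jax

def match_shape_policy(path, scalar_key, target_shape):
    """Hypothetical custom reseed policy: split scalar_key to target_shape."""
    # path is ignored; its exact structure is not assumed here.
    if tuple(target_shape) == ():
        return scalar_key  # scalar target: use the new key as-is
    return jax.random.split(scalar_key, tuple(target_shape))

new_key = match_shape_policy(("dropout",), jax.random.key(42), (5,))
scalar = match_shape_policy(("dropout",), jax.random.key(42), ())
```

Following the parameter description, such a callable would be passed as policy=match_shape_policy alongside the stream keys.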

Example:

>>> from flax import nnx
>>> import jax.numpy as jnp
...
>>> class Model(nnx.Module):
...   def __init__(self, rngs):
...     self.linear = nnx.Linear(2, 3, rngs=rngs)
...     self.dropout = nnx.Dropout(0.5, rngs=rngs)
...   def __call__(self, x):
...     return self.dropout(self.linear(x))
...
>>> model = Model(nnx.Rngs(params=0, dropout=42))
>>> x = jnp.ones((1, 2))
...
>>> y1 = model(x)
...
>>> # reset the ``dropout`` stream key to 42
>>> nnx.reseed(model, dropout=42)
>>> y2 = model(x)
...
>>> jnp.allclose(y1, y2)
Array(True, dtype=bool)
flax.nnx.with_rngs(node=<flax.typing.Missing object>, /, *, split=None, fork=None, broadcast=None, only=True, graph=None, graph_updates=None)[source]#

Returns a copy of node with RngStream objects replaced according to the split, fork, and broadcast rules.

split controls which streams are split: after splitting, each call to the stream produces one key from an array of pre-generated keys rather than a single key. fork controls which of the remaining streams are forked: each call to a forked stream produces a unique key derived from the parent counter. Streams that match neither rule are returned unchanged.
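To make the split/fork distinction concrete, here is a toy model built only on jax.random. ToyStream and its fork and split methods are hypothetical. They follow the base key plus counter plus fold_in scheme described for Rngs, but they are not Flax's implementation. A forked stream holds a single fresh key, while a split stream holds an array of pre-generated keys.

```python
import jax

class ToyStream:
    """Hypothetical stand-in for an RngStream: a base key plus a call counter."""

    def __init__(self, key):
        self.key = key
        self.count = 0

    def __call__(self):
        # Derive a new key from the base key and the counter, then advance the counter.
        out = jax.random.fold_in(self.key, self.count)
        self.count += 1
        return out

    def fork(self):
        # Forked copy: seeded with one key drawn from this stream (scalar key).
        return ToyStream(self())

    def split(self, n):
        # Split copy: holds n pre-generated keys; how they are handed out is not modeled.
        return ToyStream(jax.random.split(self(), n))


parent = ToyStream(jax.random.key(0))
forked = parent.fork()
split = parent.split(4)
```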

Parameters:
  • node – A pytree that may contain RngStream objects (e.g. an Rngs instance, a module, or any nested structure).

  • split – Specifies which streams to split and into what shape. Can be either an int or tuple[int, ...], which splits all streams into this shape (equivalent to {...: split}), or a Filter-keyed mapping where each value is an int or tuple[int, ...]; the first matching filter wins.

  • fork – A Filter, a sequence of filters, or None selecting which streams not already handled by split should be forked. Pass ... to fork all remaining streams.

  • broadcast – Specifies which streams to broadcast and into what shape. Can be either an int or tuple[int, ...], which broadcasts all streams into this shape (equivalent to {...: broadcast}), or a Filter-keyed mapping where each value is an int or tuple[int, ...]; the first matching filter wins.

  • only – A Filter selecting which streams to process. Pass True (default) to process all streams.

  • graph – If True, uses graph-mode which supports the full NNX feature set including shared references. If False, uses tree-mode which treats Modules as regular JAX pytrees, avoiding the overhead of the graph protocol.
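The "first matching filter wins" rule for split can be sketched in plain Python. resolve_split_shape is a hypothetical helper, not part of nnx. It models only two filter kinds, a stream name and Ellipsis (match everything), and it normalizes a bare int or tuple to {...: split} as the parameter description states.

```python
from typing import Any, Optional, Tuple

def resolve_split_shape(stream_name: str, split: Any) -> Optional[Tuple[int, ...]]:
    """Hypothetical helper: split shape for a stream, or None if no filter matches."""
    if isinstance(split, (int, tuple)):
        split = {...: split}  # a bare shape applies to every stream
    for flt, shape in split.items():  # dicts preserve insertion order
        if flt is ... or flt == stream_name:
            # First match wins; normalize an int shape to a 1-tuple.
            return (shape,) if isinstance(shape, int) else tuple(shape)
    return None

rules = {"params": 4, ...: (2, 4)}
```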

Returns:

A new tree of the same structure as node with RngStream objects replaced by split, forked, or broadcast copies as specified.

Example (split all streams):

>>> from flax import nnx
...
>>> rngs = nnx.Rngs(params=0, dropout=1)
>>> new_rngs = nnx.with_rngs(rngs, split=4, graph=False)
>>> new_rngs.params.key.shape
(4,)
>>> new_rngs.dropout.key.shape
(4,)

Example (split some streams, fork the rest):

>>> rngs = nnx.Rngs(params=0, dropout=1)
>>> new_rngs = nnx.with_rngs(
...   rngs, split={'params': 4}, fork=nnx.Not('params'), graph=False
... )
>>> new_rngs.params.key.shape
(4,)
>>> new_rngs.dropout.key.shape   # forked: scalar key, advanced counter
()

Example (per-filter split shapes):

>>> rngs = nnx.Rngs(params=0, dropout=1, noise=2)
>>> new_rngs = nnx.with_rngs(rngs, split={
...   'params': 4,    # split params into 4 keys
...   ...: (2, 4),    # split anything else into 2×4 keys
... }, graph=False)
>>> new_rngs.params.key.shape
(4,)
>>> new_rngs.noise.key.shape
(2, 4)