Randomness and PRNGs in Flax#

In this guide, you will learn how Flax uses JAX’s explicit pseudorandom number generator (PRNG) keys to emulate randomness, and about the features Flax adds to make it easier to thread PRNG keys through different Flax Modules.

If you are new to JAX PRNG keys or need a refresher, check out the JAX documentation on pseudorandom numbers.
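As a quick, minimal refresher (this snippet is illustrative and not part of the original guide), JAX replaces global random state with explicit PRNG keys that you create, split and consume yourself:

import jax

key = jax.random.key(0)                   # create an explicit PRNG key from an integer seed
key, subkey = jax.random.split(key)       # split before use, so no key is ever consumed twice
sample = jax.random.normal(subkey, (2,))  # drawing with the same key always gives the same values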

Setup#

Install or upgrade Flax, and then import some necessary dependencies.

Note: This guide uses the --xla_force_host_platform_device_count=8 flag to emulate multiple devices in a CPU environment in a Google Colab/Jupyter Notebook. You don’t need this flag if you are already using a multi-device environment, for example a Cloud TPU VM or a Kaggle VM with a TPU.

!pip install -q flax
import os
os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8'
import flax, flax.linen as nn
import jax, jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec, NamedSharding
from jax.experimental import mesh_utils
from jax.experimental.shard_map import shard_map

import hashlib
jax.devices()
[CpuDevice(id=0),
 CpuDevice(id=1),
 CpuDevice(id=2),
 CpuDevice(id=3),
 CpuDevice(id=4),
 CpuDevice(id=5),
 CpuDevice(id=6),
 CpuDevice(id=7)]

Set the JAX config variable jax_threefry_partitionable to True. This will be the default value in the future and makes the PRNG more efficiently auto-parallelizable under jax.jit. Refer to JAX discussion for more details.

jax.config.update('jax_threefry_partitionable', True)
assert jax.config.jax_threefry_partitionable == True
assert jax.config.jax_default_prng_impl == 'threefry2x32'

Receiving, manipulating and creating PRNG keys with Module.make_rng#

Flax receives, manipulates and creates PRNG keys primarily through the Module method self.make_rng, which accepts a string name that identifies an “RNG stream”. Each RNG stream has a starting seed PRNG key, which the user passes in as a dictionary argument to an .init or .apply function; self.make_rng then uses this seed to generate more PRNG keys for that stream.

Note that this method can only be called on bound Modules (see The Flax Module lifecycle).

class RNGModule(nn.Module):
  @nn.compact
  def __call__(self):
    print(self.make_rng('rng_stream'))
    print(self.make_rng('rng_stream'))
    print(self.make_rng('rng_stream'))

rng_module = RNGModule()
variables = rng_module.init({'rng_stream': jax.random.key(0)})
Array((), dtype=key<fry>) overlaying:
[1428664606 3351135085]
Array((), dtype=key<fry>) overlaying:
[3456700291 3873160899]
Array((), dtype=key<fry>) overlaying:
[2411773124 4124888837]
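As an aside, .init and .apply are not the only ways to bind a Module: Module.bind returns a bound copy that can be called interactively, for example in a notebook. A minimal sketch reusing rng_module from above (not part of the original guide):

# `bind` attaches variables and RNG seeds to the Module, giving `self.make_rng`
# a scope to work with outside of `.init`/`.apply`
bound_module = rng_module.bind({}, rngs={'rng_stream': jax.random.key(0)})
bound_module()  # should print the same three PRNG keys as the `.init` call above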

Now if we use a different starting seed PRNG key, we will generate different values (as intended).

variables = rng_module.init({'rng_stream': jax.random.key(1)})
Array((), dtype=key<fry>) overlaying:
[3077990774 2166202870]
Array((), dtype=key<fry>) overlaying:
[3825832496 2886313970]
Array((), dtype=key<fry>) overlaying:
[ 791337683 1373966058]

Calling self.make_rng for one stream does not affect the values generated for another stream; that is, interleaving calls across streams does not change the results.

class RNGModuleTwoStreams(nn.Module):
  @nn.compact
  def __call__(self):
    # same value as first code snippet above
    print(f"rng_stream1: {self.make_rng('rng_stream1')}")
    # same value as second code snippet above
    print(f"rng_stream2: {self.make_rng('rng_stream2')}")
    # same value as first code snippet above
    print(f"rng_stream1: {self.make_rng('rng_stream1')}")
    # same value as second code snippet above
    print(f"rng_stream2: {self.make_rng('rng_stream2')}")
    # same value as first code snippet above
    print(f"rng_stream1: {self.make_rng('rng_stream1')}")
    # same value as second code snippet above
    print(f"rng_stream2: {self.make_rng('rng_stream2')}")

rng_module_two_streams = RNGModuleTwoStreams()
variables = rng_module_two_streams.init(
  {'rng_stream1': jax.random.key(0), 'rng_stream2': jax.random.key(1)}
)
rng_stream1: Array((), dtype=key<fry>) overlaying:
[1428664606 3351135085]
rng_stream2: Array((), dtype=key<fry>) overlaying:
[3077990774 2166202870]
rng_stream1: Array((), dtype=key<fry>) overlaying:
[3456700291 3873160899]
rng_stream2: Array((), dtype=key<fry>) overlaying:
[3825832496 2886313970]
rng_stream1: Array((), dtype=key<fry>) overlaying:
[2411773124 4124888837]
rng_stream2: Array((), dtype=key<fry>) overlaying:
[ 791337683 1373966058]

Providing the same seed PRNG key will result in the same values being generated (provided that the same operations are used for those keys).

variables = rng_module_two_streams.init(
  {'rng_stream1': jax.random.key(0), 'rng_stream2': jax.random.key(0)}
)
rng_stream1: Array((), dtype=key<fry>) overlaying:
[1428664606 3351135085]
rng_stream2: Array((), dtype=key<fry>) overlaying:
[1428664606 3351135085]
rng_stream1: Array((), dtype=key<fry>) overlaying:
[3456700291 3873160899]
rng_stream2: Array((), dtype=key<fry>) overlaying:
[3456700291 3873160899]
rng_stream1: Array((), dtype=key<fry>) overlaying:
[2411773124 4124888837]
rng_stream2: Array((), dtype=key<fry>) overlaying:
[2411773124 4124888837]

How self.make_rng works under the hood#

This is what happens when self.make_rng (flax.linen.Module.make_rng) is called:

  • The following data is collected:

    • The path of the Module as provided by self.scope.path (the top-level root module has an empty path ()).

    • The self.make_rng call count. That is, the number of times self.make_rng has been called for this specific stream (including this call).

      • Note: Each sub-Module has its own call count that is separate from other Modules. For example, if a Module has called self.make_rng('params') twice and contains a sub-Module that has called self.make_rng('params') once, their call counts for the 'params' stream are 2 and 1, respectively.

  • The data is bundled into a tuple and fed into a hash function, which produces an integer.

  • The generated integer is folded (via jax.random.fold_in) into the RNG stream’s starting seed PRNG key to produce a new, unique PRNG key.

Below is a slightly simplified version of the hash function that Flax uses for self.make_rng:

def produce_hash(data):
  m = hashlib.sha1()
  for x in data:
    if isinstance(x, str):
      m.update(x.encode('utf-8'))
    elif isinstance(x, int):
      m.update(x.to_bytes((x.bit_length() + 7) // 8, byteorder='big'))
    else:
      raise ValueError(f'Expected int or string, got: {x}')
  d = m.digest()
  hash_int = int.from_bytes(d[:4], byteorder='big')
  return hash_int

And now you can manually reproduce the PRNG keys generated from self.make_rng:

stream_seed = jax.random.key(0)
for call_count in range(1, 4):
  hash_int = produce_hash(data=(call_count,))
  print(jax.random.fold_in(stream_seed, jnp.uint32(hash_int)))
Array((), dtype=key<fry>) overlaying:
[1428664606 3351135085]
Array((), dtype=key<fry>) overlaying:
[3456700291 3873160899]
Array((), dtype=key<fry>) overlaying:
[2411773124 4124888837]
These match the PRNG keys generated by rng_module.init earlier:

variables = rng_module.init({'rng_stream': jax.random.key(0)})
Array((), dtype=key<fry>) overlaying:
[1428664606 3351135085]
Array((), dtype=key<fry>) overlaying:
[3456700291 3873160899]
Array((), dtype=key<fry>) overlaying:
[2411773124 4124888837]

Sub-Modules and self.make_rng#

This section explores how self.make_rng (flax.linen.Module.make_rng) behaves with sub-Modules.

Consider the following example:

class RNGSubSubModule(nn.Module):
  def __call__(self):
    print(f"{self.name}, count 1: {self.make_rng('rng_stream')}")
    print(f"{self.name}, count 2: {self.make_rng('rng_stream')}")

class RNGSubModule(nn.Module):
  @nn.compact
  def __call__(self):
    print(f"{self.name}, count 1: {self.make_rng('rng_stream')}")
    print(f"{self.name}, count 2: {self.make_rng('rng_stream')}")
    RNGSubSubModule()()

class RNGModule(nn.Module):
  @nn.compact
  def __call__(self):
    print(f"RNGModule, count 1: {self.make_rng('rng_stream')}")
    print(f"RNGModule, count 2: {self.make_rng('rng_stream')}")
    RNGSubModule()()

rng_module = RNGModule()
variables = rng_module.init({'rng_stream': jax.random.key(0)})
RNGModule, count 1: Array((), dtype=key<fry>) overlaying:
[1428664606 3351135085]
RNGModule, count 2: Array((), dtype=key<fry>) overlaying:
[3456700291 3873160899]
RNGSubModule_0, count 1: Array((), dtype=key<fry>) overlaying:
[3858825717 2323087578]
RNGSubModule_0, count 2: Array((), dtype=key<fry>) overlaying:
[ 601859108 3782857444]
RNGSubSubModule_0, count 1: Array((), dtype=key<fry>) overlaying:
[ 234240654 1028548813]
RNGSubSubModule_0, count 2: Array((), dtype=key<fry>) overlaying:
[3650462303 2124609379]

As previously discussed, the data that is fed into the Flax hash function consists of:

  • The path of the Module, provided by self.scope.path (the top-level root module has an empty path ()); and

  • The call count for the specific RNG stream.

In addition, note that each Flax Module and sub-Module has its own individual call count, even for the same RNG stream. The convention for sub-Module names is f'{module_name}_{module_number}'. For example, the first Dense sub-Module will be called Dense_0, the second one Dense_1, and so on.

Therefore, the following data will be fed into the hash function:

  • For RNGModule: The data is just the call count, such as (1,) and (2,), since the root Module has an empty path.

  • For RNGSubModule: The data is ('RNGSubModule_0', 1) and ('RNGSubModule_0', 2).

  • For RNGSubSubModule: The data is ('RNGSubModule_0', 'RNGSubSubModule_0', 1) and ('RNGSubModule_0', 'RNGSubSubModule_0', 2).

With this data, you can manually reproduce the PRNG keys generated from the Module and sub-Modules using self.make_rng.

For example:

stream_seed = jax.random.key(0)
for initial_data in ((), ('RNGSubModule_0',), ('RNGSubModule_0', 'RNGSubSubModule_0')):
  if initial_data:
    module_name = initial_data[-1]
  else:
    module_name = 'RNGModule'
  for call_count in (1, 2):
    hash_int = produce_hash(data=initial_data+(call_count,))
    rng_key = jax.random.fold_in(stream_seed, jnp.uint32(hash_int))
    print(f"{module_name}, count {call_count}: {rng_key}")
RNGModule, count 1: Array((), dtype=key<fry>) overlaying:
[1428664606 3351135085]
RNGModule, count 2: Array((), dtype=key<fry>) overlaying:
[3456700291 3873160899]
RNGSubModule_0, count 1: Array((), dtype=key<fry>) overlaying:
[3858825717 2323087578]
RNGSubModule_0, count 2: Array((), dtype=key<fry>) overlaying:
[ 601859108 3782857444]
RNGSubSubModule_0, count 1: Array((), dtype=key<fry>) overlaying:
[ 234240654 1028548813]
RNGSubSubModule_0, count 2: Array((), dtype=key<fry>) overlaying:
[3650462303 2124609379]

If the same sub-Module class is used multiple times, Flax increments the suffix of the sub-Module name automatically, for example RNGSubModule_0, RNGSubModule_1, and so on, and you can reproduce the keys by incrementing the suffix in the hash data accordingly.

class RNGSubModule(nn.Module):
  @nn.compact
  def __call__(self):
    print(f"{self.name}, count 1: {self.make_rng('rng_stream')}")
    print(f"{self.name}, count 2: {self.make_rng('rng_stream')}")

class RNGModule(nn.Module):
  @nn.compact
  def __call__(self):
    print(f"RNGModule, count 1: {self.make_rng('rng_stream')}")
    print(f"RNGModule, count 2: {self.make_rng('rng_stream')}")
    RNGSubModule()()
    RNGSubModule()()

rng_module = RNGModule()
variables = rng_module.init({'rng_stream': jax.random.key(0)})
RNGModule, count 1: Array((), dtype=key<fry>) overlaying:
[1428664606 3351135085]
RNGModule, count 2: Array((), dtype=key<fry>) overlaying:
[3456700291 3873160899]
RNGSubModule_0, count 1: Array((), dtype=key<fry>) overlaying:
[3858825717 2323087578]
RNGSubModule_0, count 2: Array((), dtype=key<fry>) overlaying:
[ 601859108 3782857444]
RNGSubModule_1, count 1: Array((), dtype=key<fry>) overlaying:
[ 426957352 2006350344]
RNGSubModule_1, count 2: Array((), dtype=key<fry>) overlaying:
[4006253729 4205356731]
Again, these PRNG keys can be reproduced manually:

stream_seed = jax.random.key(0)
for initial_data in ((), ('RNGSubModule_0',), ('RNGSubModule_1',)):
  if initial_data:
    module_name = initial_data[-1]
  else:
    module_name = 'RNGModule'
  for call_count in (1, 2):
    hash_int = produce_hash(data=initial_data+(call_count,))
    rng_key = jax.random.fold_in(stream_seed, jnp.uint32(hash_int))
    print(f"{module_name}, count {call_count}: {rng_key}")
RNGModule, count 1: Array((), dtype=key<fry>) overlaying:
[1428664606 3351135085]
RNGModule, count 2: Array((), dtype=key<fry>) overlaying:
[3456700291 3873160899]
RNGSubModule_0, count 1: Array((), dtype=key<fry>) overlaying:
[3858825717 2323087578]
RNGSubModule_0, count 2: Array((), dtype=key<fry>) overlaying:
[ 601859108 3782857444]
RNGSubModule_1, count 1: Array((), dtype=key<fry>) overlaying:
[ 426957352 2006350344]
RNGSubModule_1, count 2: Array((), dtype=key<fry>) overlaying:
[4006253729 4205356731]

Using self.param and self.variable#

Flax users can create additional parameters and variables in their Modules with the self.param and self.variable methods. Both methods take an init_fn argument, which generates the initial value of the parameter or variable. self.make_rng is commonly used, implicitly or explicitly, inside this init_fn, since many initializer functions are stochastic and require a PRNG key. See the full list of Flax initializers here.

There are a couple of differences between the two methods that the user should take note of:

  • self.param always creates a parameter in the 'params' collection, whereas self.variable creates a variable in whichever collection the user specifies.

  • self.param automatically calls self.make_rng('params') and implicitly passes the generated PRNG key to the init_fn as its first argument, whereas with self.variable the user must explicitly call self.make_rng inside the init_fn and choose the RNG stream (it could be 'params' or something different).
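In other words, a self.param call behaves roughly like a self.variable call into the 'params' collection with an explicit self.make_rng('params') in its init_fn. A minimal sketch (the Module and variable names here are illustrative only):

class ParamVsVariable(nn.Module):
  @nn.compact
  def __call__(self, x):
    # implicit: the key from self.make_rng('params') is passed as the init_fn's first argument
    a = self.param('a', jax.random.normal, x.shape)
    # roughly-equivalent explicit form: the collection and the RNG stream are spelled out
    b = self.variable(
      'params', 'b', lambda: jax.random.normal(self.make_rng('params'), x.shape)
    )
    return x + a + b.value

variables = ParamVsVariable().init({'params': jax.random.key(0)}, jnp.ones((2, 2)))

(The two values differ because they correspond to call counts 1 and 2 of the 'params' stream.)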

Below is an example using both self.param and self.variable:

class Model(nn.Module):
  @nn.compact
  def __call__(self, x):
    # kernel will use 'params' seed, initial data will include 'Dense_0', call count 1
    x = nn.Dense(2, kernel_init=jax.random.normal, use_bias=False)(x)
    # model_param will use 'params' seed, call count 1
    model_param = self.param('model_param', jax.random.normal, x.shape)
    # model_variable1 will use 'params' seed, call count 2
    model_variable1 = self.variable(
      'other_collection',
      'model_variable1',
      lambda: jax.random.normal(self.make_rng('params'), x.shape),
    )
    # model_variable2 will use 'other' seed, call count 1
    model_variable2 = self.variable(
      'other_collection',
      'model_variable2',
      lambda: jax.random.normal(self.make_rng('other'), x.shape),
    )
    # kernel will use 'params' seed, initial data will include 'Dense_1', call count 1
    # bias will use 'params' seed, initial data will include 'Dense_1', call count 2
    x = nn.Dense(2, kernel_init=jax.random.normal, bias_init=jax.random.normal)(
      x
    )
    return x

model = Model()
variables = model.init(
  {'params': jax.random.key(0), 'other': jax.random.key(1)}, jnp.ones((2, 2))
)
print(variables['params']['Dense_0']['kernel'])
print(variables['params']['model_param'])
print(variables['other_collection']['model_variable1'])
print(variables['other_collection']['model_variable2'])
print(variables['params']['Dense_1']['kernel'])
print(variables['params']['Dense_1']['bias'])
[[-1.6185919   0.700908  ]
 [-1.3146383  -0.79342234]]
[[ 0.0761425 -1.6157459]
 [-1.6857724  0.7126891]]
[[ 0.60175574  0.2553228 ]
 [ 0.27367848 -2.1975214 ]]
[[1.6249592  0.30813068]
 [1.6613585  1.0404155 ]]
[[ 0.0030665   0.29551846]
 [ 0.16670242 -0.78252524]]
[1.582462   0.15216611]

Remember:

  • there is a separate count for each RNG stream; this is why the count for self.make_rng('other') starts at 1 even though there were earlier calls to self.make_rng('params')

  • each sub-Module has its own separate count for each RNG stream; this is why each Dense layer has its own count for self.make_rng('params'), and why model_param and model_variable1 share the same count (since they are defined within the same top-level parent Module)

params_seed = jax.random.key(0)
other_seed = jax.random.key(1)
for initial_data, count, seed, shape in (
  (('Dense_0',), 1, params_seed, (2, 2)),
  ((), 1, params_seed, (2, 2)),
  ((), 2, params_seed, (2, 2)),
  ((), 1, other_seed, (2, 2)),
  (('Dense_1',), 1, params_seed, (2, 2)),
  (('Dense_1',), 2, params_seed, (1, 2)),
):
  hash_int = produce_hash(data=(*initial_data, count))
  rng_key = jax.random.fold_in(seed, jnp.uint32(hash_int))
  print(jax.random.normal(rng_key, shape))
[[-1.6185919   0.700908  ]
 [-1.3146383  -0.79342234]]
[[ 0.0761425 -1.6157459]
 [-1.6857724  0.7126891]]
[[ 0.60175574  0.2553228 ]
 [ 0.27367848 -2.1975214 ]]
[[1.6249592  0.30813068]
 [1.6613585  1.0404155 ]]
[[ 0.0030665   0.29551846]
 [ 0.16670242 -0.78252524]]
[[1.582462   0.15216611]]

Managing RNG streams inside a training loop#

Below is an example of managing RNG streams with self.make_rng, self.param, self.variable and nn.Dropout in a training loop. Note that nn.Dropout requires a seed PRNG key for the 'dropout' RNG stream, since it implicitly calls self.make_rng('dropout'):

class SubModule(nn.Module):
  @nn.compact
  def __call__(self, x, train):
    # variables created using `self.param` will use `self.make_rng('params')`
    kernel = self.param('submodule_kernel', jax.random.normal, x.shape)
    x = x + kernel
    # `nn.Dropout` will use self.make_rng('dropout')
    x = nn.Dropout(0.2)(x, deterministic=not train)
    # `nn.Dense` will use self.make_rng('params')
    x = nn.Dense(3)(x)
    return x

class Model(nn.Module):
  @nn.compact
  def __call__(self, x, train):
    # make kernel use `self.make_rng('other')`
    kernel = self.variable(
      'other_collection',
      'module_kernel',
      lambda: jax.random.normal(self.make_rng('other'), x.shape),
    )
    x = (
      x + kernel.value
    )  # `.value` will extract the underlying value of the variable
    x = SubModule()(x, train)
    # `nn.Dropout` will use self.make_rng('dropout')
    x = nn.Dropout(0.2)(x, deterministic=not train)
    # `nn.Dense` will use self.make_rng('params')
    x = nn.Dense(2)(x)
    return x

params_rng, other_rng, train_rng = jax.random.split(jax.random.key(0), 3)
init_rngs = {'params': params_rng, 'other': other_rng}

x = jnp.ones((1, 3))
y = jnp.ones((1, 2))

module = Model()
variables = module.init(init_rngs, x, train=False)

def update(variables, rng):
  # we don't need to provide a 'params' or 'other' rng, as only 'dropout' rng will be used during training
  # split the rng to get a dropout_rng to be used for this training iteration,
  # and to get another rng key to be used for the next training iteration
  dropout_rng, next_rng = jax.random.split(rng)
  def loss(params):
    out = module.apply(
      {'params': params, 'other_collection': variables['other_collection']},
      x,
      train=True,
      rngs={'dropout': dropout_rng},
    )
    return jnp.mean((y - out) ** 2)
  grads = jax.grad(loss)(variables['params'])
  params = jax.tree_util.tree_map(lambda p, g: p - 1e-3 * g, variables['params'], grads)
  return {
    'params': params,
    'other_collection': variables['other_collection'],
  }, next_rng

for _ in range(10):
  variables, train_rng = update(variables, train_rng)
  out = module.apply(variables, x, train=False)
  print(jnp.mean((y - out)**2))
2.518454
2.4859657
2.4171872
2.412684
2.3435805
2.2773488
2.2592616
2.2009292
2.1839895
2.1707344
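For comparison, here is a minimal sketch (not part of the original example) of the same dropout-RNG handling built on flax.training.train_state.TrainState and optax: instead of splitting and carrying a key by hand, a per-step key is derived by folding the step counter into a base key. It assumes optax is installed and reuses module, variables, x and y from above.

import optax
from flax.training import train_state

state = train_state.TrainState.create(
  apply_fn=module.apply, params=variables['params'], tx=optax.sgd(1e-3)
)
base_dropout_rng = jax.random.key(42)

@jax.jit
def train_step(state, other_variables, x, y, base_rng):
  # derive a fresh dropout key for this step by folding in the step counter
  dropout_rng = jax.random.fold_in(base_rng, state.step)
  def loss_fn(params):
    out = state.apply_fn(
      {'params': params, 'other_collection': other_variables},
      x,
      train=True,
      rngs={'dropout': dropout_rng},
    )
    return jnp.mean((y - out) ** 2)
  grads = jax.grad(loss_fn)(state.params)
  return state.apply_gradients(grads=grads)

for _ in range(10):
  state = train_step(state, variables['other_collection'], x, y, base_dropout_rng)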

🔪 Sharp edge 🔪 - unintentionally generating the same values#

There is an edge case where the same value can be unintentionally generated. See the Flax issue for more details.

class Leaf(nn.Module):
  def __call__(self, x):
    return x + jax.random.randint(self.make_rng("rng"), (), 0, 100)

class Node(nn.Module):
  leaf_name: str
  @nn.compact
  def __call__(self, x):
    return Leaf(name=self.leaf_name)(x)

class Model(nn.Module):
  @nn.compact
  def __call__(self, x):
    return (Node(name="ab", leaf_name="cdef")(x),
            Node(name="abc", leaf_name="def")(x),
    )

out1, out2 = Model().apply({}, 0, rngs={"rng": jax.random.key(33)})
out1 == out2 # same output, despite having different submodule names
Array(True, dtype=bool)

This occurs because the hash function simply concatenates the strings it is given, so the data ('AB', 'C') is equivalent to ('A', 'BC') when fed into the hash function and therefore produces the same hash int. In the example above, both module paths concatenate to 'abcdef', so both Leaf instances receive the same PRNG key.

print(produce_hash(data=('A', 'B', 'C', 1)))
print(produce_hash(data=('AB', 'C', 1)))
print(produce_hash(data=('A', 'BC', 1)))
print(produce_hash(data=('ABC', 1)))
947574064
947574064
947574064
947574064

To avoid this edge case, users can flip the flax_fix_rng_separator configuration flag to True.

flax.config.update('flax_fix_rng_separator', True)
out1, out2 = Model().apply({}, 0, rngs={"rng": jax.random.key(33)})
out1 == out2 # different output
Array(False, dtype=bool)
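Under the hood, the flag presumably removes the ambiguity by inserting a separator between the path entries before hashing. Here is a hypothetical variant of the simplified produce_hash from above that illustrates the idea (the separator and function name are illustrative, not Flax’s actual implementation):

def produce_hash_with_separator(data, separator='/'):
  m = hashlib.sha1()
  for x in data:
    if isinstance(x, str):
      # the trailing separator makes ('AB', 'C') and ('A', 'BC') hash differently
      m.update((x + separator).encode('utf-8'))
    elif isinstance(x, int):
      m.update(x.to_bytes((x.bit_length() + 7) // 8, byteorder='big'))
    else:
      raise ValueError(f'Expected int or string, got: {x}')
  return int.from_bytes(m.digest()[:4], byteorder='big')

print(produce_hash_with_separator(data=('AB', 'C', 1)))
print(produce_hash_with_separator(data=('A', 'BC', 1)))  # no longer equal to the line above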

Managing RNGs on multiple devices#

This section shows examples of how to use RNGs with jax.jit and shard_map in multi-device settings.

Using jax.jit#

When using jax.jit, we can use RNGs as we did before, but we now include the in_shardings and out_shardings arguments to specify how to shard the input and output data.

For more details on training on multiple devices in Flax using jax.jit, see our Scale up Flax Modules on multiple devices guide and lm1b example.

# Create a mesh and annotate the axis with a name.
device_mesh = mesh_utils.create_device_mesh((8,))
print(device_mesh)

mesh = Mesh(devices=device_mesh, axis_names=('data',))
print(mesh)

data_sharding = NamedSharding(mesh, PartitionSpec('data',))
print(data_sharding)
[CpuDevice(id=0) CpuDevice(id=1) CpuDevice(id=2) CpuDevice(id=3)
 CpuDevice(id=4) CpuDevice(id=5) CpuDevice(id=6) CpuDevice(id=7)]
Mesh(device_ids=array([0, 1, 2, 3, 4, 5, 6, 7]), axis_names=('data',))
NamedSharding(mesh=Mesh('data': 8), spec=PartitionSpec('data',))
class Model(nn.Module):
  @nn.compact
  def __call__(self, x, add_noise):
    x = nn.Dense(1)(x)
    # use jnp.where for control flow; for more details see: https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError
    return jnp.where(
      add_noise, x + jax.random.normal(self.make_rng('params'), x.shape), x
    )

module = Model()
init_rng, apply_rng = jax.random.split(jax.random.key(0))
x = jnp.ones((8, 1))
variables = module.init(init_rng, x, False)

# create custom forward function, since jit does not support kwargs when in_shardings is specified
def forward(variables, x, add_noise, rng):
  return module.apply(variables, x, add_noise, rngs={'params': rng})

# shard the inputs x across devices
# replicate the variables, add_noise boolean and rng key across devices
# shard the output across devices
jit_forward = jax.jit(
  forward,
  in_shardings=(None, data_sharding, None, None),
  out_shardings=data_sharding,
)
out = jit_forward(variables, x, True, apply_rng)
out
Array([[-2.2187614 ],
       [-2.8055234 ],
       [-2.5464187 ],
       [ 1.0270392 ],
       [-3.5243359 ],
       [-2.2795477 ],
       [-0.6504516 ],
       [ 0.17373264]], dtype=float32)

Each row of the output is different even though every input row is identical, which confirms that the RNG key was used to add noise to the output.

We can also confirm that the output is sharded across devices:

out.addressable_shards
[Shard(device=CpuDevice(id=0), index=(slice(0, 1, None), slice(None, None, None)), replica_id=0, data=[[-2.2187614]]),
 Shard(device=CpuDevice(id=1), index=(slice(1, 2, None), slice(None, None, None)), replica_id=0, data=[[-2.8055234]]),
 Shard(device=CpuDevice(id=2), index=(slice(2, 3, None), slice(None, None, None)), replica_id=0, data=[[-2.5464187]]),
 Shard(device=CpuDevice(id=3), index=(slice(3, 4, None), slice(None, None, None)), replica_id=0, data=[[1.0270392]]),
 Shard(device=CpuDevice(id=4), index=(slice(4, 5, None), slice(None, None, None)), replica_id=0, data=[[-3.5243359]]),
 Shard(device=CpuDevice(id=5), index=(slice(5, 6, None), slice(None, None, None)), replica_id=0, data=[[-2.2795477]]),
 Shard(device=CpuDevice(id=6), index=(slice(6, 7, None), slice(None, None, None)), replica_id=0, data=[[-0.6504516]]),
 Shard(device=CpuDevice(id=7), index=(slice(7, 8, None), slice(None, None, None)), replica_id=0, data=[[0.17373264]])]

Another way to visualize the output sharding:

jax.debug.visualize_array_sharding(out)
(sharding visualization: the output is split into one row-shard per device, CPU 0 through CPU 7)

If we choose not to add noise, then the output is the same for every batch element (as expected, since the input rows are identical):

out = jit_forward(variables, x, False, apply_rng)
out
Array([[-1.2839764],
       [-1.2839764],
       [-1.2839764],
       [-1.2839764],
       [-1.2839764],
       [-1.2839764],
       [-1.2839764],
       [-1.2839764]], dtype=float32)

We can confirm the un-jitted function produces the same values, albeit unsharded (note there may be small numerical differences due to compiler optimizations from jitting):

out = forward(variables, x, True, apply_rng)
out
Array([[-2.2187614 ],
       [-2.8055234 ],
       [-2.5464187 ],
       [ 1.0270392 ],
       [-3.5243359 ],
       [-2.2795477 ],
       [-0.6504516 ],
       [ 0.17373264]], dtype=float32)
out = forward(variables, x, False, apply_rng)
out
Array([[-1.2839764],
       [-1.2839764],
       [-1.2839764],
       [-1.2839764],
       [-1.2839764],
       [-1.2839764],
       [-1.2839764],
       [-1.2839764]], dtype=float32)

Using shard_map#

When using jax.experimental.shard_map.shard_map, keep the following in mind:

  • split your PRNG key to produce a different key for each device

  • the PRNG keys will be sharded automatically to each device (provided you use the correct partition specification), but the rank of the original batched PRNG key array will not be reduced; e.g. with a batch of 8 PRNG keys and 8 devices, each device will see a PRNG key batch of size 1 within the shard_map-ed function

    • therefore to access the PRNG key itself, we need to index slice into it (see the example below)

def forward(variables, x, add_noise, rng_key_batch):
  # rng_key_batch is a batch of size 1 containing 1 PRNG key
  # index slice into the rng_key_batch to access the PRNG key
  return module.apply(
    variables, x, add_noise, rngs={'params': rng_key_batch[0]}
  )

# define partition specifications
data_pspec = PartitionSpec('data')
no_pspec = PartitionSpec()

# shard the inputs x and rng keys across devices
# replicate the variables and add_noise boolean across devices
# shard the output across devices
shmap_forward = shard_map(
  forward,
  mesh=mesh,
  in_specs=(no_pspec, data_pspec, no_pspec, data_pspec),
  out_specs=data_pspec,
)
# get 8 different RNG keys that will be used by the 8 devices during forward inference
apply_rngs = jax.random.split(apply_rng, 8)
out = shmap_forward(variables, x, True, apply_rngs)
out
Array([[-1.2605132 ],
       [-1.2405176 ],
       [-0.99350417],
       [-1.0277128 ],
       [-1.4154483 ],
       [-0.3905797 ],
       [-2.417677  ],
       [ 0.9023453 ]], dtype=float32)

Confirm that the output is sharded across devices:

out.addressable_shards
[Shard(device=CpuDevice(id=0), index=(slice(0, 1, None), slice(None, None, None)), replica_id=0, data=[[-1.2605132]]),
 Shard(device=CpuDevice(id=1), index=(slice(1, 2, None), slice(None, None, None)), replica_id=0, data=[[-1.2405176]]),
 Shard(device=CpuDevice(id=2), index=(slice(2, 3, None), slice(None, None, None)), replica_id=0, data=[[-0.99350417]]),
 Shard(device=CpuDevice(id=3), index=(slice(3, 4, None), slice(None, None, None)), replica_id=0, data=[[-1.0277128]]),
 Shard(device=CpuDevice(id=4), index=(slice(4, 5, None), slice(None, None, None)), replica_id=0, data=[[-1.4154483]]),
 Shard(device=CpuDevice(id=5), index=(slice(5, 6, None), slice(None, None, None)), replica_id=0, data=[[-0.3905797]]),
 Shard(device=CpuDevice(id=6), index=(slice(6, 7, None), slice(None, None, None)), replica_id=0, data=[[-2.417677]]),
 Shard(device=CpuDevice(id=7), index=(slice(7, 8, None), slice(None, None, None)), replica_id=0, data=[[0.9023453]])]
jax.debug.visualize_array_sharding(out)
(sharding visualization: the output is split into one row-shard per device, CPU 0 through CPU 7)

Lifted transforms#

Flax lifted transforms let you apply JAX transforms to Flax Modules while controlling how variables and PRNG keys are handled. This section shows how to control how PRNG keys are split in Flax lifted transforms.

Refer to Lifted transformations for more detail.

nn.vmap#

We can use nn.vmap to create a batched Dense layer:

x = jnp.ones((3, 2))

BatchDense = nn.vmap(
    nn.Dense,
    in_axes=0, out_axes=0,
    variable_axes={'params': None},
    split_rngs={'params': False})

BatchDense(2).init(jax.random.key(0), x)
{'params': {'bias': Array([0., 0.], dtype=float32),
  'kernel': Array([[-1.2488099 , -0.6127134 ],
         [-0.07084481,  0.60130936]], dtype=float32)}}

By setting variable_axes={'params': 0}, we vectorize the params arrays along the first (batch) axis. However, because split_rngs={'params': False}, the generated parameter values are still identical across that axis:

BatchDense = nn.vmap(
    nn.Dense,
    in_axes=0, out_axes=0,
    variable_axes={'params': 0},
    split_rngs={'params': False})

BatchDense(2).init(jax.random.key(0), x)
{'params': {'bias': Array([[0., 0.],
         [0., 0.],
         [0., 0.]], dtype=float32),
  'kernel': Array([[[-1.2488099 , -0.6127134 ],
          [-0.07084481,  0.60130936]],
  
         [[-1.2488099 , -0.6127134 ],
          [-0.07084481,  0.60130936]],
  
         [[-1.2488099 , -0.6127134 ],
          [-0.07084481,  0.60130936]]], dtype=float32)}}

If we also set split_rngs={'params': True}, then the PRNG key we provide is split across the variable axis (in this case, batch axis 0), and different parameters are generated for each batch input:

BatchDense = nn.vmap(
    nn.Dense,
    in_axes=0, out_axes=0,
    variable_axes={'params': 0},
    split_rngs={'params': True})

BatchDense(2).init(jax.random.key(0), x)
{'params': {'bias': Array([[0., 0.],
         [0., 0.],
         [0., 0.]], dtype=float32),
  'kernel': Array([[[-0.2526208 , -0.15088455],
          [-1.1987205 , -0.40843305]],
  
         [[-0.7064888 , -1.108805  ],
          [-0.938775  ,  1.4812315 ]],
  
         [[-0.59468937, -0.2502723 ],
          [-1.33515   ,  0.5067442 ]]], dtype=float32)}}

Adding a variable via self.variable is straightforward:

class Model(nn.Module):
  @nn.compact
  def __call__(self, x):
    x = nn.Dense(2)(x)
    kernel = self.variable(
      'other_collection',
      'kernel',
      lambda: jax.random.normal(self.make_rng('other'), x.shape),
    )
    return x + kernel.value

BatchModel = nn.vmap(
  Model,
  in_axes=0,
  out_axes=0,
  variable_axes={'params': 0, 'other_collection': 0},
  split_rngs={'params': True, 'other': True},
)

BatchModel().init({'params': jax.random.key(0), 'other': jax.random.key(1)}, x)
{'params': {'Dense_0': {'bias': Array([[0., 0.],
          [0., 0.],
          [0., 0.]], dtype=float32),
   'kernel': Array([[[-0.9079084 ,  0.76390624],
           [-0.01285526,  0.4320353 ]],
   
          [[ 0.12398645,  0.7884565 ],
           [ 1.5344163 ,  1.3186085 ]],
   
          [[-0.44171348,  0.43430036],
           [-0.40732604,  0.29774475]]], dtype=float32)}},
 'other_collection': {'kernel': Array([[-0.8193048 ,  0.711106  ],
         [-0.37802765, -0.66705877],
         [-0.44808003,  0.93031347]], dtype=float32)}}

We can control which RNG streams are split. For example, if we split only the 'params' RNG stream, the variables generated from self.variable will be the same for each batch input:

BatchModel = nn.vmap(
    Model,
    in_axes=0, out_axes=0,
    variable_axes={'params': 0, 'other_collection': 0},
    split_rngs={'params': True, 'other': False})

BatchModel().init({'params': jax.random.key(0), 'other': jax.random.key(1)}, x)
{'params': {'Dense_0': {'bias': Array([[0., 0.],
          [0., 0.],
          [0., 0.]], dtype=float32),
   'kernel': Array([[[-0.9079084 ,  0.76390624],
           [-0.01285526,  0.4320353 ]],
   
          [[ 0.12398645,  0.7884565 ],
           [ 1.5344163 ,  1.3186085 ]],
   
          [[-0.44171348,  0.43430036],
           [-0.40732604,  0.29774475]]], dtype=float32)}},
 'other_collection': {'kernel': Array([[ 0.44956833, -1.1854612 ],
         [ 0.44956833, -1.1854612 ],
         [ 0.44956833, -1.1854612 ]], dtype=float32)}}

We can also control which parameters and variables get a separate copy for each batch input. For example, here only 'params' is vectorized along the batch axis, while 'other_collection' is broadcast (shared across the batch):

BatchModel = nn.vmap(
    Model,
    in_axes=0, out_axes=0,
    variable_axes={'params': 0, 'other_collection': None},
    split_rngs={'params': True, 'other': False})

BatchModel().init({'params': jax.random.key(0), 'other': jax.random.key(1)}, x)
{'params': {'Dense_0': {'bias': Array([[0., 0.],
          [0., 0.],
          [0., 0.]], dtype=float32),
   'kernel': Array([[[-0.9079084 ,  0.76390624],
           [-0.01285526,  0.4320353 ]],
   
          [[ 0.12398645,  0.7884565 ],
           [ 1.5344163 ,  1.3186085 ]],
   
          [[-0.44171348,  0.43430036],
           [-0.40732604,  0.29774475]]], dtype=float32)}},
 'other_collection': {'kernel': Array([ 0.44956833, -1.1854612 ], dtype=float32)}}
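Note that once the variables exist, applying the batched Module requires no RNG keys at all: the init_fns of self.param and self.variable are only invoked when a variable is missing. A minimal usage sketch reusing the last BatchModel and x from above:

variables = BatchModel().init({'params': jax.random.key(0), 'other': jax.random.key(1)}, x)
# no `rngs` argument is needed at apply time, since every parameter and variable already exists
out = BatchModel().apply(variables, x)  # out has shape (3, 2): one row per batch input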

nn.scan#

We can use nn.scan to create a scanned Module (this is useful for expressing repeatedly stacked submodules compactly):

x = jnp.ones((3, 2))

class ResidualMLPBlock(nn.Module):
  @nn.compact
  def __call__(self, x, _):
    h = nn.Dense(features=2)(x)
    h = nn.relu(h)
    return x + h, None  # `None` is the per-step scan output (nothing is collected)

ScanMLP = nn.scan(
      ResidualMLPBlock, variable_axes={'params': 0},
      variable_broadcast=False, split_rngs={'params': True},
      length=3)

ScanMLP().init(jax.random.key(0), x, None)  # `x` is the carry; `None` is the (empty) scanned-over input
{'params': {'Dense_0': {'bias': Array([[0., 0.],
          [0., 0.],
          [0., 0.]], dtype=float32),
   'kernel': Array([[[-0.07838312, -0.7422982 ],
           [ 0.87488323,  0.13773395]],
   
          [[ 0.97309333,  0.9087693 ],
           [-0.12564984, -1.0920651 ]],
   
          [[-0.99055105,  1.1499453 ],
           [-0.15721127, -0.62520015]]], dtype=float32)}}}

As before, we can control whether to split the RNG stream. For example, if we want all the stacked modules to be initialized to the same parameter values, we can pass in split_rngs={'params': False}:

ScanMLP = nn.scan(
      ResidualMLPBlock, variable_axes={'params': 0},
      variable_broadcast=False, split_rngs={'params': False},
      length=3)

ScanMLP().init(jax.random.key(0), x, None)
{'params': {'Dense_0': {'bias': Array([[0., 0.],
          [0., 0.],
          [0., 0.]], dtype=float32),
   'kernel': Array([[[-0.66715515, -0.0484313 ],
           [ 0.9867164 ,  0.75408363]],
   
          [[-0.66715515, -0.0484313 ],
           [ 0.9867164 ,  0.75408363]],
   
          [[-0.66715515, -0.0484313 ],
           [ 0.9867164 ,  0.75408363]]], dtype=float32)}}}
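As with nn.vmap, applying the scanned Module afterwards needs no RNG keys. A minimal usage sketch reusing ScanMLP and x from above:

variables = ScanMLP().init(jax.random.key(0), x, None)
# the carry `x` is threaded through `length=3` stacked blocks;
# the second return value collects the per-step outputs, which are all `None` here
y, _ = ScanMLP().apply(variables, x, None)
print(y.shape)  # (3, 2)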