
Preface

CAVEAT PROGRAMMER

The below is an alpha API preview and things might break. The surface syntax of the API's features is not set in stone, and we welcome feedback on any point.

Install and Import

[ ]:
# Install the newest JAXlib version.
!pip install --upgrade -q pip jax jaxlib
# Install Flax at head:
!pip install --upgrade -q git+https://github.com/google/flax.git
[2]:
import functools
from typing import Any, Callable, Sequence, Optional
import numpy as np
import jax
from jax import lax, random, numpy as jnp
import flax
from flax.core import freeze, unfreeze
from flax import linen as nn

Invoking Modules

Let’s instantiate a Dense layer. Modules are actually objects in this API, so we provide constructor arguments when initializing the Module. In this case, we only have to provide the output features dimension.

[3]:
model = nn.Dense(features=3)

We need to initialize the Module variables, which include the parameters of the Module as well as any other state variables.

We call the init method on the instantiated Module. If the Module __call__ method has args (self, *args, **kwargs) then we call init with (rngs, *args, **kwargs) so in this case, just (rng, input):

[4]:
# Make RNG Keys and a fake input.
key1, key2 = random.split(random.PRNGKey(0), 2)
x = random.uniform(key1, (4,4))

# provide key and fake input to get initialized variables
init_variables = model.init(key2, x)

init_variables
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
[4]:
FrozenDict({
    params: {
        kernel: DeviceArray([[ 0.6503669 ,  0.8678979 ,  0.46042678],
                     [ 0.05673932,  0.9909285 , -0.63536596],
                     [ 0.76134115, -0.3250529 , -0.6522163 ],
                     [-0.8243032 ,  0.4150194 ,  0.19405058]], dtype=float32),
        bias: DeviceArray([0., 0., 0.], dtype=float32),
    },
})

We call the apply method on the instantiated Module. If the Module __call__ method has args (self, *args, **kwargs) then we call apply with (variables, *args, rngs=<RNGS>, mutable=<MUTABLEKINDS>, **kwargs) where

  • <RNGS> are the optional call time RNGs for things like dropout. For simple Modules this is just a single key, but if your module has multiple kinds of data, it’s a dictionary of rng-keys per-kind, e.g. {'params': key0, 'dropout': key1} for a Module with dropout layers.

  • <MUTABLEKINDS> is an optional list of names of kinds that are expected to be mutated during the call, e.g. ['batch_stats'] for a layer updating batchnorm statistics.

So in this case, just (variables, input):

[5]:
y = model.apply(init_variables, x)
y
[5]:
DeviceArray([[ 0.5035518 ,  1.8548559 , -0.4270196 ],
             [ 0.0279097 ,  0.5589246 , -0.43061775],
             [ 0.35471284,  1.5741    , -0.3286552 ],
             [ 0.5264864 ,  1.2928858 ,  0.10089308]], dtype=float32)

Additional points:

  • If you want to init or apply a Module using a method other than __call__, you need to provide the method= kwarg to init and apply to use it instead of the default __call__, e.g. method='encode', method='decode' to apply the encode/decode methods of an autoencoder.

Defining Basic Modules

Composing submodules

We support declaring modules in setup() that can still benefit from shape inference by using Lazy Initialization that sets up variables the first time the Module is called.

[6]:
class ExplicitMLP(nn.Module):
  features: Sequence[int]

  def setup(self):
    # we automatically know what to do with lists, dicts of submodules
    self.layers = [nn.Dense(feat) for feat in self.features]
    # for single submodules, we would just write:
    # self.layer1 = nn.Dense(feat1)

  def __call__(self, inputs):
    x = inputs
    for i, lyr in enumerate(self.layers):
      x = lyr(x)
      if i != len(self.layers) - 1:
        x = nn.relu(x)
    return x

key1, key2 = random.split(random.PRNGKey(0), 2)
x = random.uniform(key1, (4,4))

model = ExplicitMLP(features=[3,4,5])
init_variables = model.init(key2, x)
y = model.apply(init_variables, x)

print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, unfreeze(init_variables)))
print('output:\n', y)
initialized parameter shapes:
 {'params': {'layers_0': {'bias': (3,), 'kernel': (4, 3)}, 'layers_1': {'bias': (4,), 'kernel': (3, 4)}, 'layers_2': {'bias': (5,), 'kernel': (4, 5)}}}
output:
 [[ 4.2292815e-02 -4.3807115e-02  2.9323792e-02  6.5492536e-03
  -1.7147182e-02]
 [ 1.2967804e-01 -1.4551792e-01  9.4432175e-02  1.2521386e-02
  -4.5417294e-02]
 [ 0.0000000e+00  0.0000000e+00  0.0000000e+00  0.0000000e+00
   0.0000000e+00]
 [ 9.3024090e-04  2.7864411e-05  2.4478839e-04  8.1344356e-04
  -1.0110775e-03]]

Here we show the equivalent compact form of the MLP that declares the submodules inline using the @compact decorator.

[7]:
class SimpleMLP(nn.Module):
  features: Sequence[int]

  @nn.compact
  def __call__(self, inputs):
    x = inputs
    for i, feat in enumerate(self.features):
      x = nn.Dense(feat, name=f'layers_{i}')(x)
      if i != len(self.features) - 1:
        x = nn.relu(x)
      # providing a name is optional though!
      # the default autonames would be "Dense_0", "Dense_1", ...
      # x = nn.Dense(feat)(x)
    return x

key1, key2 = random.split(random.PRNGKey(0), 2)
x = random.uniform(key1, (4,4))

model = SimpleMLP(features=[3,4,5])
init_variables = model.init(key2, x)
y = model.apply(init_variables, x)

print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, unfreeze(init_variables)))
print('output:\n', y)
initialized parameter shapes:
 {'params': {'layers_0': {'bias': (3,), 'kernel': (4, 3)}, 'layers_1': {'bias': (4,), 'kernel': (3, 4)}, 'layers_2': {'bias': (5,), 'kernel': (4, 5)}}}
output:
 [[ 4.2292815e-02 -4.3807115e-02  2.9323792e-02  6.5492536e-03
  -1.7147182e-02]
 [ 1.2967804e-01 -1.4551792e-01  9.4432175e-02  1.2521386e-02
  -4.5417294e-02]
 [ 0.0000000e+00  0.0000000e+00  0.0000000e+00  0.0000000e+00
   0.0000000e+00]
 [ 9.3024090e-04  2.7864411e-05  2.4478839e-04  8.1344356e-04
  -1.0110775e-03]]

Declaring and using variables

Flax uses lazy initialization, which allows declared variables to be initialized only at the first site of their use, using whatever shape information is available at the local call site for shape inference. Once a variable has been initialized, a reference to the data is kept for use in subsequent calls.

For declaring parameters that aren’t mutated inside the model, but rather by gradient descent, we use the syntax:

self.param(parameter_name, parameter_init_fn, *init_args)

with arguments:

  • parameter_name - just the name, a string

  • parameter_init_fn - a function taking an RNG key and a variable number of other arguments, i.e. fn(rng, *args). Typically those in nn.initializers take an rng and a shape argument.

  • the remaining arguments to feed to the init function when initializing.

Again, we’ll demonstrate declaring things inline as we typically do using the @compact decorator.

[8]:
class SimpleDense(nn.Module):
  features: int
  kernel_init: Callable = nn.initializers.lecun_normal()
  bias_init: Callable = nn.initializers.zeros

  @nn.compact
  def __call__(self, inputs):
    kernel = self.param('kernel',
                        self.kernel_init,  # RNG passed implicitly.
                        (inputs.shape[-1], self.features))  # shape info.
    y = lax.dot_general(inputs, kernel,
                        (((inputs.ndim - 1,), (0,)), ((), ())),)
    bias = self.param('bias', self.bias_init, (self.features,))
    y = y + bias
    return y

key1, key2 = random.split(random.PRNGKey(0), 2)
x = random.uniform(key1, (4,4))

model = SimpleDense(features=3)
init_variables = model.init(key2, x)
y = model.apply(init_variables, x)

print('initialized parameters:\n', init_variables)
print('output:\n', y)
initialized parameters:
 FrozenDict({
    params: {
        kernel: DeviceArray([[ 0.6503669 ,  0.8678979 ,  0.46042678],
                     [ 0.05673932,  0.9909285 , -0.63536596],
                     [ 0.76134115, -0.3250529 , -0.6522163 ],
                     [-0.8243032 ,  0.4150194 ,  0.19405058]], dtype=float32),
        bias: DeviceArray([0., 0., 0.], dtype=float32),
    },
})
output:
 [[ 0.5035518   1.8548559  -0.4270196 ]
 [ 0.0279097   0.5589246  -0.43061775]
 [ 0.35471284  1.5741     -0.3286552 ]
 [ 0.5264864   1.2928858   0.10089308]]

We can also declare variables in setup, though in doing so you can’t take advantage of shape inference and have to provide explicit shape information at initialization. The syntax is a little repetitive in this case right now, but we do force agreement of the assigned names.

[9]:
class ExplicitDense(nn.Module):
  features_in: int  # <-- explicit input shape
  features: int
  kernel_init: Callable = nn.initializers.lecun_normal()
  bias_init: Callable = nn.initializers.zeros

  def setup(self):
    self.kernel = self.param('kernel',
                             self.kernel_init,
                             (self.features_in, self.features))
    self.bias = self.param('bias', self.bias_init, (self.features,))

  def __call__(self, inputs):
    y = lax.dot_general(inputs, self.kernel,
                        (((inputs.ndim - 1,), (0,)), ((), ())),)
    y = y + self.bias
    return y

key1, key2 = random.split(random.PRNGKey(0), 2)
x = random.uniform(key1, (4,4))

model = ExplicitDense(features_in=4, features=3)
init_variables = model.init(key2, x)
y = model.apply(init_variables, x)

print('initialized parameters:\n', init_variables)
print('output:\n', y)
initialized parameters:
 FrozenDict({
    params: {
        kernel: DeviceArray([[ 0.6503669 ,  0.8678979 ,  0.46042678],
                     [ 0.05673932,  0.9909285 , -0.63536596],
                     [ 0.76134115, -0.3250529 , -0.6522163 ],
                     [-0.8243032 ,  0.4150194 ,  0.19405058]], dtype=float32),
        bias: DeviceArray([0., 0., 0.], dtype=float32),
    },
})
output:
 [[ 0.5035518   1.8548559  -0.4270196 ]
 [ 0.0279097   0.5589246  -0.43061775]
 [ 0.35471284  1.5741     -0.3286552 ]
 [ 0.5264864   1.2928858   0.10089308]]

General Variables

For declaring generally mutable variables that may be mutated inside the model we use the call:

self.variable(variable_kind, variable_name, variable_init_fn, *init_args)

with arguments:

  • variable_kind - the “kind” of state this variable is, i.e. the name of the nested-dict collection that this will be stored in inside the top Module's variables, e.g. batch_stats for the moving statistics for a batch norm layer or cache for autoregressive cache data. Note that parameters also have a kind, but they’re set to the default params kind.

  • variable_name - just the name, a string

  • variable_init_fn - a function taking a variable number of other arguments, i.e. fn(*args). Note that we don’t assume the need for an RNG. If you do want an RNG, provide it via a self.make_rng(variable_kind) call in the provided arguments.

  • the remaining arguments to feed to the init function when initializing.

⚠️ Unlike parameters, we expect these to be mutated, so self.variable returns not a constant, but a reference to the variable. To get the raw value, you’d write myvariable.value and to set it myvariable.value = new_value.

[10]:
class Counter(nn.Module):
  @nn.compact
  def __call__(self):
    # easy pattern to detect if we're initializing
    is_initialized = self.has_variable('counter', 'count')
    counter = self.variable('counter', 'count', lambda: jnp.zeros((), jnp.int32))
    if is_initialized:
      counter.value += 1
    return counter.value


key1 = random.PRNGKey(0)

model = Counter()
init_variables = model.init(key1)
print('initialized variables:\n', init_variables)

y, mutated_variables = model.apply(init_variables, mutable=['counter'])

print('mutated variables:\n', mutated_variables)
print('output:\n', y)
initialized variables:
 FrozenDict({
    counter: {
        count: DeviceArray(0, dtype=int32),
    },
})
mutated variables:
 FrozenDict({
    counter: {
        count: DeviceArray(1, dtype=int32),
    },
})
output:
 1

Another Mutability and RNGs Example

Let’s make an artificial, goofy example that mixes differentiable parameters, stochastic layers, and mutable variables:

[11]:
class Block(nn.Module):
  features: int
  training: bool
  @nn.compact
  def __call__(self, inputs):
    x = nn.Dense(self.features)(inputs)
    x = nn.Dropout(rate=0.5)(x, deterministic=not self.training)
    x = nn.BatchNorm(use_running_average=not self.training)(x)
    return x

key1, key2, key3, key4 = random.split(random.PRNGKey(0), 4)
x = random.uniform(key1, (3,4,4))

model = Block(features=3, training=True)

init_variables = model.init({'params': key2, 'dropout': key3}, x)
init_params = init_variables['params']

# When calling `apply` with mutable kinds, returns a pair of output,
# mutated_variables.
y, mutated_variables = model.apply(
    init_variables, x, rngs={'dropout': key4}, mutable=['batch_stats'])

# Now we reassemble the full variables from the updates (in a real training
# loop, with the updated params from an optimizer).
updated_variables = freeze(dict(params=init_params,
                                **mutated_variables))

print('updated variables:\n', updated_variables)
print('initialized variable shapes:\n',
      jax.tree_util.tree_map(jnp.shape, init_variables))
print('output:\n', y)

# Let's run these model variables during "evaluation":
eval_model = Block(features=3, training=False)
y = eval_model.apply(updated_variables, x)  # Nothing mutable; single return value.
print('eval output:\n', y)

updated variables:
 FrozenDict({
    params: {
        Dense_0: {
            kernel: DeviceArray([[ 0.6498898 , -0.5000124 ,  0.78573596],
                         [-0.25609785, -0.7132329 ,  0.2500864 ],
                         [-0.64630085,  0.39321756, -1.0203307 ],
                         [ 0.38721725,  0.86828285,  0.10860055]], dtype=float32),
            bias: DeviceArray([0., 0., 0.], dtype=float32),
        },
        BatchNorm_0: {
            scale: DeviceArray([1., 1., 1.], dtype=float32),
            bias: DeviceArray([0., 0., 0.], dtype=float32),
        },
    },
    batch_stats: {
        BatchNorm_0: {
            mean: DeviceArray([ 0.00059601, -0.00103457,  0.00166948], dtype=float32),
            var: DeviceArray([0.9907686, 0.9923046, 0.992195 ], dtype=float32),
        },
    },
})
initialized variable shapes:
 FrozenDict({
    batch_stats: {
        BatchNorm_0: {
            mean: (3,),
            var: (3,),
        },
    },
    params: {
        BatchNorm_0: {
            bias: (3,),
            scale: (3,),
        },
        Dense_0: {
            bias: (3,),
            kernel: (4, 3),
        },
    },
})
output:
 [[[-0.21496922  0.21550177 -0.35633382]
  [-0.21496922 -2.0458      1.3015485 ]
  [-0.21496922 -0.925116   -0.35633382]
  [-0.6595459   0.21550177  0.3749205 ]]

 [[-0.21496922  1.642865   -0.35633382]
  [-0.21496922  1.3094063  -0.88034123]
  [ 2.5726683   0.21550177  0.34353197]
  [-0.21496922  0.21550177  1.6778195 ]]

 [[-1.6060593   0.21550177 -1.9460517 ]
  [ 1.4126908  -1.4898677   1.2790381 ]
  [-0.21496922  0.21550177 -0.35633382]
  [-0.21496922  0.21550177 -0.7251308 ]]]
eval output:
 [[[ 3.2246590e-01  2.6108384e-02  4.4821960e-01]
  [ 8.5726947e-02 -5.4385906e-01  3.8821870e-01]
  [-2.3933809e-01 -2.7381191e-01 -1.7526165e-01]
  [-6.2515378e-02 -5.2414006e-01  1.7029770e-01]]

 [[ 1.5014435e-01  3.4498507e-01 -1.3554120e-01]
  [-3.6971044e-04  2.6463276e-01 -1.2491019e-01]
  [ 3.8763803e-01  2.9023719e-01  1.6291586e-01]
  [ 4.1320035e-01  4.1468274e-02  4.7670874e-01]]

 [[-1.9433719e-01  5.2831882e-01 -3.7554008e-01]
  [ 2.2608691e-01 -4.0989807e-01  3.8292480e-01]
  [-2.4945706e-01  1.6170470e-01 -2.5247774e-01]
  [-7.2220474e-02  1.2077977e-01 -8.8408351e-02]]]

JAX transformations inside modules

JIT

It’s not immediately clear what use this has, but you can compile specific submodules if there’s a reason to.

Known Gotcha: at the moment, the decorator changes the RNG stream slightly, so comparing jitted and unjitted initializations will look different.

[12]:
class MLP(nn.Module):
  features: Sequence[int]

  @nn.compact
  def __call__(self, inputs):
    x = inputs
    for i, feat in enumerate(self.features):
      # JIT the Module (its __call__ fn by default).
      x = nn.jit(nn.Dense)(feat, name=f'layers_{i}')(x)
      if i != len(self.features) - 1:
        x = nn.relu(x)
    return x

key1, key2 = random.split(random.PRNGKey(3), 2)
x = random.uniform(key1, (4,4))

model = MLP(features=[3,4,5])
init_variables = model.init(key2, x)
y = model.apply(init_variables, x)

print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, unfreeze(init_variables)))
print('output:\n', y)
initialized parameter shapes:
 {'params': {'layers_0': {'bias': (3,), 'kernel': (4, 3)}, 'layers_1': {'bias': (4,), 'kernel': (3, 4)}, 'layers_2': {'bias': (5,), 'kernel': (4, 5)}}}
output:
 [[ 0.2524199   0.11621253  0.5246693   0.19144788  0.2096542 ]
 [ 0.08557513 -0.04126885  0.2502836   0.03910369  0.16575359]
 [ 0.2804383   0.27751124  0.44969672  0.26016283  0.05875347]
 [ 0.2440843   0.17069656  0.45499086  0.20377949  0.13428023]]

Remat

For memory-expensive computations, we can remat our method to recompute a Module’s output during a backwards pass.

Known Gotcha: at the moment, the decorator changes the RNG stream slightly, so comparing remat’d and undecorated initializations will look different.

[13]:
class RematMLP(nn.Module):
  features: Sequence[int]
  # For all transforms, we can annotate a method, or wrap an existing
  # Module class. Here we annotate the method.
  @nn.remat
  @nn.compact
  def __call__(self, inputs):
    x = inputs
    for i, feat in enumerate(self.features):
      x = nn.Dense(feat, name=f'layers_{i}')(x)
      if i != len(self.features) - 1:
        x = nn.relu(x)
    return x

key1, key2 = random.split(random.PRNGKey(3), 2)
x = random.uniform(key1, (4,4))

model = RematMLP(features=[3,4,5])
init_variables = model.init(key2, x)
y = model.apply(init_variables, x)

print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, unfreeze(init_variables)))
print('output:\n', y)
initialized parameter shapes:
 {'params': {'layers_0': {'bias': (3,), 'kernel': (4, 3)}, 'layers_1': {'bias': (4,), 'kernel': (3, 4)}, 'layers_2': {'bias': (5,), 'kernel': (4, 5)}}}
output:
 [[-0.14814317  0.06889858 -0.19695625  0.12019286  0.02068037]
 [-0.04439102 -0.06698258 -0.11579747 -0.19906905 -0.04342325]
 [-0.08875751 -0.13392815 -0.23153095 -0.39802808 -0.0868225 ]
 [-0.01606487 -0.02424064 -0.04190649 -0.07204203 -0.01571464]]

Vmap

You can now vmap Modules inside. The transform has a lot of arguments. First, it takes the usual jax vmap args:

  • in_axes - an integer or None for each input argument

  • out_axes - an integer or None for each output argument

  • axis_size - the axis size if you need to give it explicitly

In addition, we provide for each kind of variable its axis rules:

  • variable_axes - a dict from kinds to a single integer or None, specifying the axis to map over for variables of that kind (None shares them across the mapped axis)

  • split_rngs - a dict from RNG-kinds to a bool, specifying whether to split the rng along the axis.

Below we show an example defining a batched, multiheaded attention module from a single-headed unbatched attention implementation.

[14]:
class RawDotProductAttention(nn.Module):
  attn_dropout_rate: float = 0.1
  train: bool = False

  @nn.compact
  def __call__(self, query, key, value, bias=None, dtype=jnp.float32):
    assert key.ndim == query.ndim
    assert key.ndim == value.ndim

    n = query.ndim
    attn_weights = lax.dot_general(
        query, key,
        (((n-1,), (n - 1,)), ((), ())))
    if bias is not None:
      attn_weights += bias
    norm_dims = tuple(range(attn_weights.ndim // 2, attn_weights.ndim))
    attn_weights = jax.nn.softmax(attn_weights, axis=norm_dims)
    attn_weights = nn.Dropout(self.attn_dropout_rate)(attn_weights,
                                                      deterministic=not self.train)
    attn_weights = attn_weights.astype(dtype)

    contract_dims = (
        tuple(range(n - 1, attn_weights.ndim)),
        tuple(range(0, n  - 1)))
    y = lax.dot_general(
        attn_weights, value,
        (contract_dims, ((), ())))
    return y

class DotProductAttention(nn.Module):
  qkv_features: Optional[int] = None
  out_features: Optional[int] = None
  train: bool = False

  @nn.compact
  def __call__(self, inputs_q, inputs_kv, bias=None, dtype=jnp.float32):
    qkv_features = self.qkv_features or inputs_q.shape[-1]
    out_features = self.out_features or inputs_q.shape[-1]

    QKVDense = functools.partial(
      nn.Dense, features=qkv_features, use_bias=False, dtype=dtype)
    query = QKVDense(name='query')(inputs_q)
    key = QKVDense(name='key')(inputs_kv)
    value = QKVDense(name='value')(inputs_kv)

    y = RawDotProductAttention(train=self.train)(
        query, key, value, bias=bias, dtype=dtype)

    y = nn.Dense(features=out_features, dtype=dtype, name='out')(y)
    return y

class MultiHeadDotProductAttention(nn.Module):
  qkv_features: Optional[int] = None
  out_features: Optional[int] = None
  batch_axes: Sequence[int] = (0,)
  num_heads: int = 1
  broadcast_dropout: bool = False
  train: bool = False
  @nn.compact
  def __call__(self, inputs_q, inputs_kv, bias=None, dtype=jnp.float32):
    qkv_features = self.qkv_features or inputs_q.shape[-1]
    out_features = self.out_features or inputs_q.shape[-1]

    # Make multiheaded attention from single-headed dimension.
    Attn = nn.vmap(DotProductAttention,
                   in_axes=(None, None, None),
                   out_axes=2,
                   axis_size=self.num_heads,
                   variable_axes={'params': 0},
                   split_rngs={'params': True,
                               'dropout': not self.broadcast_dropout})

    # Vmap across batch dimensions.
    for axis in reversed(sorted(self.batch_axes)):
      Attn = nn.vmap(Attn,
                     in_axes=(axis, axis, axis),
                     out_axes=axis,
                     variable_axes={'params': None},
                     split_rngs={'params': False, 'dropout': False})

    # Run the vmap'd class on inputs.
    y = Attn(qkv_features=qkv_features // self.num_heads,
             out_features=out_features,
             train=self.train,
             name='attention')(inputs_q, inputs_kv, bias)

    return y.mean(axis=-2)


key1, key2, key3, key4 = random.split(random.PRNGKey(0), 4)
x = random.uniform(key1, (3, 13, 64))

model = functools.partial(
  MultiHeadDotProductAttention,
  broadcast_dropout=False,
  num_heads=2,
  batch_axes=(0,))

init_variables = model(train=False).init({'params': key2}, x, x)
print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, unfreeze(init_variables)))

y = model(train=True).apply(init_variables, x, x, rngs={'dropout': key4})
print('output:\n', y.shape)
initialized parameter shapes:
 {'params': {'attention': {'key': {'kernel': (2, 64, 32)}, 'out': {'bias': (2, 64), 'kernel': (2, 32, 64)}, 'query': {'kernel': (2, 64, 32)}, 'value': {'kernel': (2, 64, 32)}}}}
output:
 (3, 13, 2)

Scan

Scan allows us to apply lax.scan to Modules, including their parameters and mutable variables. To use it we have to specify how we want each “kind” of variable to be transformed. For each variable kind we either:

  • broadcast it across the scan steps as a constant, by listing it in variable_broadcast

  • scan along an axis <axis:int>, e.g. for unique parameters at each step, by specifying it in variable_axes, similar to vmap

OR we specify that the variable kind is to be treated like a “carry” by passing it to the variable_carry argument.

For scanned variable kinds, we also specify via split_rngs whether or not to split the rng at each step.

[15]:
class SimpleScan(nn.Module):
  @nn.compact
  def __call__(self, xs):
    dummy_rng = random.PRNGKey(0)
    init_carry = nn.LSTMCell.initialize_carry(dummy_rng,
                                              xs.shape[:1],
                                              xs.shape[-1])
    LSTM = nn.scan(nn.LSTMCell,
                   in_axes=1, out_axes=1,
                   variable_broadcast='params',
                   split_rngs={'params': False})
    return LSTM(name="lstm_cell")(init_carry, xs)

key1, key2 = random.split(random.PRNGKey(0), 2)
xs = random.uniform(key1, (1, 5, 2))

model = SimpleScan()
init_variables = model.init(key2, xs)

print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, unfreeze(init_variables)))

y = model.apply(init_variables, xs)
print('output:\n', y)
initialized parameter shapes:
 {'params': {'lstm_cell': {'hf': {'bias': (2,), 'kernel': (2, 2)}, 'hg': {'bias': (2,), 'kernel': (2, 2)}, 'hi': {'bias': (2,), 'kernel': (2, 2)}, 'ho': {'bias': (2,), 'kernel': (2, 2)}, 'if': {'kernel': (2, 2)}, 'ig': {'kernel': (2, 2)}, 'ii': {'kernel': (2, 2)}, 'io': {'kernel': (2, 2)}}}}
output:
 ((DeviceArray([[-0.562219  ,  0.92847174]], dtype=float32), DeviceArray([[-0.31570646,  0.2885693 ]], dtype=float32)), DeviceArray([[[-0.08265854,  0.01302483],
              [-0.10249066,  0.21991298],
              [-0.26609066,  0.22519003],
              [-0.27982554,  0.28393182],
              [-0.31570646,  0.2885693 ]]], dtype=float32))