flax.linen.RNN

class flax.linen.RNN(cell, cell_size=<flax.linen.recurrent._Never object>, time_major=False, return_carry=False, reverse=False, keep_order=False, unroll=1, variable_axes=FrozenDict({}), variable_broadcast='params', variable_carry=False, split_rngs=FrozenDict({params: False}), parent=<flax.linen.module._Sentinel object>, name=None)

The RNN module takes any RNNCellBase instance and applies it over a sequence using flax.linen.scan().

Example:

>>> import jax.numpy as jnp
>>> import jax
>>> import flax.linen as nn
...
>>> x = jnp.ones((10, 50, 32)) # (batch, time, features)
>>> lstm = nn.RNN(nn.LSTMCell(64))
>>> variables = lstm.init(jax.random.key(0), x)
>>> y = lstm.apply(variables, x)
>>> y.shape # (batch, time, cell_size)
(10, 50, 64)

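As shown above, you do not pass a carry size to RNN. The carry is created by the cell's initialize_carry method from the shape of a single time step of the input (the input shape with the time axis removed), so its size follows from the cell's own configuration; for most cells this is the number of hidden units (features). Leave the cell_size attribute unset: its default is a sentinel value. The exact carry shape depends on the cell you are using; for example, the carry of ConvLSTMCell keeps the spatial dimensions of the input and has the form (*batch, height, width, features):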

>>> x = jnp.ones((10, 50, 32, 32, 3)) # (batch, time, height, width, features)
>>> conv_lstm = nn.RNN(nn.ConvLSTMCell(64, kernel_size=(3, 3)))
>>> y, variables = conv_lstm.init_with_output(jax.random.key(0), x)
>>> y.shape # (batch, time, height, width, features)
(10, 50, 32, 32, 64)

By default, RNN expects the time dimension to follow the batch dimensions ((*batch, time, *features)). If you set time_major=True, RNN instead expects the time dimension to come first ((time, *batch, *features)):

>>> x = jnp.ones((50, 10, 32)) # (time, batch, features)
>>> lstm = nn.RNN(nn.LSTMCell(64), time_major=True)
>>> variables = lstm.init(jax.random.key(0), x)
>>> y = lstm.apply(variables, x)
>>> y.shape # (time, batch, cell_size)
(50, 10, 64)

By default, the output is an array of shape (*batch, time, *cell_size), where cell_size is the cell's output size (for LSTMCell, its number of features). If you set return_carry=True, RNN instead returns a tuple of the final carry and the output:

>>> x = jnp.ones((10, 50, 32)) # (batch, time, features)
>>> lstm = nn.RNN(nn.LSTMCell(64), return_carry=True)
>>> variables = lstm.init(jax.random.key(0), x)
>>> carry, y = lstm.apply(variables, x)
>>> jax.tree_util.tree_map(jnp.shape, carry) # ((batch, cell_size), (batch, cell_size))
((10, 64), (10, 64))
>>> y.shape # (batch, time, cell_size)
(10, 50, 64)

To support variable-length sequences, you can pass seq_lengths, an integer array of shape (*batch) in which each element is the number of valid (non-padding) steps in the corresponding sequence. For example, for a batch of three sequences:

>>> seq_lengths = jnp.array([3, 2, 5])

Output elements at padding positions are NOT zeroed out. If return_carry=True, the returned carry is the state at the last valid element of each sequence.

RNN also accepts some of the arguments of flax.linen.scan(). Their defaults are chosen to work with cells such as LSTMCell and GRUCell, but they can be overridden as needed. Overriding the defaults passed to scan looks like this:

>>> lstm = nn.RNN(
...   nn.LSTMCell(64),
...   unroll=1, variable_axes={}, variable_broadcast='params',
...   variable_carry=False, split_rngs={'params': False})
cell

an instance of RNNCellBase.

Type

flax.linen.recurrent.RNNCellBase

time_major

if time_major=False (default) it will expect inputs with shape (*batch, time, *features), else it will expect inputs with shape (time, *batch, *features).

Type

bool

return_carry

if return_carry=False (default) only the output sequence is returned, else it will return a tuple of the final carry and the output sequence.

Type

bool

reverse

if reverse=False (default) the sequence is processed from left to right and returned in the original order, else it will be processed from right to left, and returned in reverse order. If seq_lengths is passed, padding will always remain at the end of the sequence.

Type

bool
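"Padding will always remain at the end" means that, when seq_lengths is passed, only the valid prefix of each sequence is reversed. The sketch below illustrates that reordering in plain JAX; flip_valid_prefix is a hypothetical helper written for this illustration, not part of Flax's public API.

```python
import jax.numpy as jnp


def flip_valid_prefix(x, seq_lengths):
    """Hypothetical helper (not a Flax API).

    Reverses the first seq_lengths[i] time steps of each sequence in a
    (batch, time, features) array and leaves trailing padding in place.
    """
    time = x.shape[1]
    t = jnp.arange(time)[None, :]          # (1, time)
    lengths = seq_lengths[:, None]         # (batch, 1)
    # Valid step t reads from position length - 1 - t; padding stays put.
    idx = jnp.where(t < lengths, lengths - 1 - t, t)   # (batch, time)
    idx = jnp.broadcast_to(idx[:, :, None], x.shape)   # (batch, time, features)
    return jnp.take_along_axis(x, idx, axis=1)


x = jnp.array([[1, 2, 3, 0, 0],
               [1, 2, 3, 4, 5]])[:, :, None]  # (batch=2, time=5, features=1)
seq_lengths = jnp.array([3, 5])
flipped = flip_valid_prefix(x, seq_lengths)
print(flipped[:, :, 0].tolist())  # [[3, 2, 1, 0, 0], [5, 4, 3, 2, 1]]
```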

keep_order

if keep_order=True and reverse=True, the output is reversed back to the original order after processing; this is useful for aligning the forward and backward outputs of a bidirectional RNN. If keep_order=False (default), the output remains in the order specified by reverse.

Type

bool

unroll

how many scan iterations to unroll within a single iteration of a loop, defaults to 1. This argument will be passed to nn.scan.

Type

int

variable_axes

a dictionary mapping each collection to either an integer i (meaning we scan over dimension i) or None (replicate rather than scan). This argument is forwarded to nn.scan.

Type

Mapping[Union[bool, str, Collection[str], DenyList], Union[int, flax.core.lift.In[int], flax.core.lift.Out[int]]]

variable_broadcast

Specifies the broadcasted variable collections. A broadcasted variable should not depend on any computation that cannot be lifted out of the loop. This is typically used to define parameters that are shared across all iterations of the scanned function (here, the cell). This argument is forwarded to nn.scan.

Type

Union[bool, str, Collection[str], DenyList]

variable_carry

Specifies the variable collections that are carried through the loop. Mutations to these variables are carried to the next iteration and will be preserved when the scan finishes. This argument is forwarded to nn.scan.

Type

Union[bool, str, Collection[str], DenyList]

split_rngs

a mapping from PRNGSequenceFilter to bool specifying whether a collection’s PRNG key should be split such that its values are different at each step, or replicated such that its values remain the same at each step. This argument is forwarded to nn.scan.

Type

Mapping[Union[bool, str, Collection[str], DenyList], bool]

__call__(inputs, *, initial_carry=None, init_key=None, seq_lengths=None, return_carry=None, time_major=None, reverse=None, keep_order=None)

Applies the RNN to the inputs.

__call__ allows you to optionally override some attributes like return_carry and time_major defined in the constructor.

Parameters
  • inputs – the input sequence.

  • initial_carry – the initial carry, if not provided it will be initialized using the cell’s RNNCellBase.initialize_carry() method.

  • init_key – a PRNG key used to initialize the carry; if not provided, jax.random.key(0) is used. Most cells ignore this argument.

  • seq_lengths – an optional integer array of shape (*batch) indicating the length of each sequence. Elements whose index in the time dimension is greater than or equal to the corresponding length are considered padding and are ignored.

  • return_carry – overrides the return_carry attribute; if return_carry=False (default) only the output sequence is returned, else it will return a tuple of the final carry and the output sequence.

  • time_major – overrides the time_major attribute; if time_major=False (default) it will expect inputs with shape (*batch, time, *features), else it will expect inputs with shape (time, *batch, *features).

  • reverse – overrides the reverse attribute, if reverse=False (default) the sequence is processed from left to right and returned in the original order, else it will be processed from right to left, and returned in reverse order. If seq_lengths is passed, padding will always remain at the end of the sequence.

  • keep_order – overrides the keep_order attribute; if keep_order=True and reverse=True, the output is reversed back to the original order after processing, which is useful for aligning the forward and backward outputs of a bidirectional RNN. If keep_order=False (default), the output remains in the order specified by reverse.

Returns

if return_carry=False (default) only the output sequence is returned, else it will return a tuple of the final carry and the output sequence.

Methods