flax.linen.RNN
- class flax.linen.RNN(cell, cell_size=<flax.linen.recurrent._Never object>, time_major=False, return_carry=False, reverse=False, keep_order=False, unroll=1, variable_axes=FrozenDict({}), variable_broadcast='params', variable_carry=False, split_rngs=FrozenDict({ params: False, }), parent=<flax.linen.module._Sentinel object>, name=None)
The RNN module takes any RNNCellBase instance and applies it over a sequence using flax.linen.scan().
Example:
>>> import jax.numpy as jnp
>>> import jax
>>> import flax.linen as nn
...
>>> x = jnp.ones((10, 50, 32)) # (batch, time, features)
>>> lstm = nn.RNN(nn.LSTMCell(64))
>>> variables = lstm.init(jax.random.key(0), x)
>>> y = lstm.apply(variables, x)
>>> y.shape # (batch, time, cell_size)
(10, 50, 64)
As shown above, RNN uses the cell_size argument to set the size argument for the cell’s initialize_carry method; in practice this is typically the number of hidden units you want for the cell. However, this may vary depending on the cell you are using. For example, the ConvLSTMCell requires a size argument of the form (kernel_height, kernel_width, features):
>>> x = jnp.ones((10, 50, 32, 32, 3)) # (batch, time, height, width, features)
>>> conv_lstm = nn.RNN(nn.ConvLSTMCell(64, kernel_size=(3, 3)))
>>> y, variables = conv_lstm.init_with_output(jax.random.key(0), x)
>>> y.shape # (batch, time, height, width, features)
(10, 50, 32, 32, 64)
By default RNN expects the time dimension after the batch dimension ((*batch, time, *features)). If you set time_major=True, RNN will instead expect the time dimension to be at the beginning ((time, *batch, *features)):
>>> x = jnp.ones((50, 10, 32)) # (time, batch, features)
>>> lstm = nn.RNN(nn.LSTMCell(64), time_major=True)
>>> variables = lstm.init(jax.random.key(0), x)
>>> y = lstm.apply(variables, x)
>>> y.shape # (time, batch, cell_size)
(50, 10, 64)
By default the output is typically an array of shape (*batch, time, *cell_size). If you set return_carry=True, it will instead return a tuple of the final carry and the output:
>>> x = jnp.ones((10, 50, 32)) # (batch, time, features)
>>> lstm = nn.RNN(nn.LSTMCell(64), return_carry=True)
>>> variables = lstm.init(jax.random.key(0), x)
>>> carry, y = lstm.apply(variables, x)
>>> jax.tree_util.tree_map(jnp.shape, carry) # ((batch, cell_size), (batch, cell_size))
((10, 64), (10, 64))
>>> y.shape # (batch, time, cell_size)
(10, 50, 64)
To support variable-length sequences, you can pass seq_lengths, an integer array of shape (*batch) where each element is the length of the corresponding sequence in the batch. For example:
>>> seq_lengths = jnp.array([3, 2, 5])
The output elements corresponding to padding elements are NOT zeroed out. If return_carry is set to True, the carry will be the state of the last valid element of each sequence.
RNN also accepts some of the arguments of flax.linen.scan(). By default they are set to work with cells like LSTMCell and GRUCell, but they can be overridden as needed. Overriding default values to scan looks like this:
>>> lstm = nn.RNN(
...   nn.LSTMCell(64),
...   unroll=1, variable_axes={}, variable_broadcast='params',
...   variable_carry=False, split_rngs={'params': False})
- cell
an instance of RNNCellBase.
- time_major
if time_major=False (default) it will expect inputs with shape (*batch, time, *features), else it will expect inputs with shape (time, *batch, *features).
- Type
bool
- return_carry
if return_carry=False (default) only the output sequence is returned, else it will return a tuple of the final carry and the output sequence.
- Type
bool
- reverse
if reverse=False (default) the sequence is processed from left to right and returned in the original order, else it will be processed from right to left and returned in reverse order. If seq_lengths is passed, padding will always remain at the end of the sequence.
- Type
bool
- keep_order
if keep_order=True and reverse=True, the output will be reversed back to the original order after processing; this is useful to align sequences in bidirectional RNNs. If keep_order=False (default), the output will remain in the order specified by reverse.
- Type
bool
- unroll
how many scan iterations to unroll within a single iteration of a loop, defaults to 1. This argument will be passed to nn.scan.
- Type
int
- variable_axes
a dictionary mapping each collection to either an integer i (meaning we scan over dimension i) or None (replicate rather than scan). This argument is forwarded to nn.scan.
- Type
Mapping[Union[bool, str, Collection[str], DenyList], Union[int, flax.core.lift.In[int], flax.core.lift.Out[int]]]
- variable_broadcast
Specifies the broadcasted variable collections. A broadcasted variable should not depend on any computation that cannot be lifted out of the loop. This is typically used to define shared parameters inside the fn. This argument is forwarded to nn.scan.
- Type
Union[bool, str, Collection[str], DenyList]
- variable_carry
Specifies the variable collections that are carried through the loop. Mutations to these variables are carried to the next iteration and will be preserved when the scan finishes. This argument is forwarded to nn.scan.
- Type
Union[bool, str, Collection[str], DenyList]
- split_rngs
a mapping from PRNGSequenceFilter to bool specifying whether a collection’s PRNG key should be split such that its values are different at each step, or replicated such that its values remain the same at each step. This argument is forwarded to nn.scan.
- Type
Mapping[Union[bool, str, Collection[str], DenyList], bool]
- __call__(inputs, *, initial_carry=None, init_key=None, seq_lengths=None, return_carry=None, time_major=None, reverse=None, keep_order=None)
Applies the RNN to the inputs.
__call__ allows you to optionally override some attributes like return_carry and time_major defined in the constructor.
- Parameters
inputs – the input sequence.
initial_carry – the initial carry, if not provided it will be initialized using the cell’s RNNCellBase.initialize_carry() method.
init_key – a PRNG key used to initialize the carry, if not provided jax.random.key(0) will be used. Most cells will ignore this argument.
seq_lengths – an optional integer array of shape (*batch) indicating the length of each sequence. Elements whose index in the time dimension is greater than or equal to the corresponding length will be considered padding and will be ignored.
return_carry – if return_carry=False (default) only the output sequence is returned, else it will return a tuple of the final carry and the output sequence.
time_major – if time_major=False (default) it will expect inputs with shape (*batch, time, *features), else it will expect inputs with shape (time, *batch, *features).
reverse – overrides the reverse attribute. If reverse=False (default) the sequence is processed from left to right and returned in the original order, else it will be processed from right to left and returned in reverse order. If seq_lengths is passed, padding will always remain at the end of the sequence.
keep_order – overrides the keep_order attribute. If keep_order=True and reverse=True, the output will be reversed back to the original order after processing; this is useful to align sequences in bidirectional RNNs. If keep_order=False (default), the output will remain in the order specified by reverse.
- Returns
if return_carry=False (default) only the output sequence is returned, else it will return a tuple of the final carry and the output sequence.
Methods