
RNN modules for Flax.

class flax.nnx.nn.recurrent.LSTMCell(*args, **kwargs)[source]#

LSTM cell.

The mathematical definition of the cell is as follows:

\[\begin{split}\begin{array}{ll} i = \sigma(W_{ii} x + W_{hi} h + b_{hi}) \\ f = \sigma(W_{if} x + W_{hf} h + b_{hf}) \\ g = \tanh(W_{ig} x + W_{hg} h + b_{hg}) \\ o = \sigma(W_{io} x + W_{ho} h + b_{ho}) \\ c' = f * c + i * g \\ h' = o * \tanh(c') \\ \end{array}\end{split}\]

where x is the input, h is the output of the previous time step, and c is the memory.
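The equations above can be sketched as a single step function in plain JAX. This is an illustrative sketch, not the Flax implementation: the `params` dictionary layout, the `init_params` helper, and the row-vector convention (`x @ W` rather than `W x`) are assumptions made for the example.

```python
import jax
import jax.numpy as jnp


def lstm_step(params, carry, x):
    """One LSTM step following the equations above; carry is (c, h)."""
    c, h = carry

    # Each gate combines an input projection, a hidden projection, and a bias.
    def gate(name):
        return x @ params["W_i" + name] + h @ params["W_h" + name] + params["b_h" + name]

    i = jax.nn.sigmoid(gate("i"))  # input gate
    f = jax.nn.sigmoid(gate("f"))  # forget gate
    g = jnp.tanh(gate("g"))        # candidate memory
    o = jax.nn.sigmoid(gate("o"))  # output gate
    new_c = f * c + i * g
    new_h = o * jnp.tanh(new_c)
    return (new_c, new_h), new_h


def init_params(key, in_features, hidden_features):
    """Hypothetical helper: small random kernels and zero biases."""
    params = {}
    keys = jax.random.split(key, 8)
    for n, name in enumerate("ifgo"):
        params["W_i" + name] = 0.1 * jax.random.normal(keys[2 * n], (in_features, hidden_features))
        params["W_h" + name] = 0.1 * jax.random.normal(keys[2 * n + 1], (hidden_features, hidden_features))
        params["b_h" + name] = jnp.zeros((hidden_features,))
    return params


params = init_params(jax.random.PRNGKey(0), in_features=3, hidden_features=4)
x = jnp.ones((2, 3))                            # batch of 2, 3 features
carry = (jnp.zeros((2, 4)), jnp.zeros((2, 4)))  # (c, h)
(new_c, new_h), out = lstm_step(params, carry, x)
```

The step returns the new carry together with the output, and the output is the new hidden state h'.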

__call__(carry, inputs)[source]#

A long short-term memory (LSTM) cell.

  • carry – the hidden state of the LSTM cell, initialized using LSTMCell.initialize_carry.

  • inputs – an ndarray with the input for the current time step. All dimensions except the final are considered batch dimensions.


A tuple with the new carry and the output.

initialize_carry(input_shape, rngs=None)[source]#

Initialize the RNN cell carry.

  • rngs – random number generator passed to the init_fn.

  • input_shape – a tuple providing the shape of the input to the cell.


An initialized carry for the given RNN cell.
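As a rough sketch of what such a carry looks like, assume a zero-initialized carry whose shape is the batch dimensions of `input_shape` followed by the hidden size, ordered as (c, h). The helper name `zeros_lstm_carry`, the zero initialization, and the ordering are assumptions for illustration, not a statement of the library's behavior.

```python
import jax.numpy as jnp


def zeros_lstm_carry(input_shape, hidden_features):
    """Hypothetical helper: build a zero (c, h) carry from an input shape."""
    # All dimensions except the final one are treated as batch dimensions.
    batch_dims = tuple(input_shape[:-1])
    mem_shape = batch_dims + (hidden_features,)
    return jnp.zeros(mem_shape), jnp.zeros(mem_shape)


c, h = zeros_lstm_carry((2, 3), hidden_features=4)
```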



class flax.nnx.nn.recurrent.OptimizedLSTMCell(*args, **kwargs)[source]#

More efficient LSTM Cell that concatenates state components before matmul.

The parameters are compatible with LSTMCell. This cell is often faster than LSTMCell when the hidden size is at most roughly 2048 units.

The mathematical definition of the cell is the same as LSTMCell and as follows:

\[\begin{split}\begin{array}{ll} i = \sigma(W_{ii} x + W_{hi} h + b_{hi}) \\ f = \sigma(W_{if} x + W_{hf} h + b_{hf}) \\ g = \tanh(W_{ig} x + W_{hg} h + b_{hg}) \\ o = \sigma(W_{io} x + W_{ho} h + b_{ho}) \\ c' = f * c + i * g \\ h' = o * \tanh(c') \\ \end{array}\end{split}\]

where x is the input, h is the output of the previous time step, and c is the memory.
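One way to read "concatenates before matmul" is that the per-gate kernels are fused along the output axis, so a few large matrix products replace many small ones. The sketch below only checks that the fused and separate forms give the same pre-activations. The variable names and the gate order i, f, g, o are assumptions, and the actual Flax implementation may arrange the fusion differently.

```python
import jax
import jax.numpy as jnp

in_features, hidden_features, batch = 3, 4, 2
keys = jax.random.split(jax.random.PRNGKey(0), 10)

# One input kernel and one hidden kernel per gate, in the order i, f, g, o.
W_i = [jax.random.normal(k, (in_features, hidden_features)) for k in keys[:4]]
W_h = [jax.random.normal(k, (hidden_features, hidden_features)) for k in keys[4:8]]
x = jax.random.normal(keys[8], (batch, in_features))
h = jax.random.normal(keys[9], (batch, hidden_features))

# Separate form: eight small matmuls.
separate = [x @ wi + h @ wh for wi, wh in zip(W_i, W_h)]

# Fused form: concatenate the kernels along the output axis, so two matmuls
# followed by a split recover the same four pre-activations.
fused = x @ jnp.concatenate(W_i, axis=-1) + h @ jnp.concatenate(W_h, axis=-1)
fused_parts = jnp.split(fused, 4, axis=-1)
```

Because the fused kernels hold the same values as the separate ones, the parameters stay interchangeable between the two forms.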


  • gate_fn – activation function used for gates (default: sigmoid).

  • activation_fn – activation function used for output and memory update (default: tanh).

  • kernel_init – initializer function for the kernels that transform the input (default: lecun_normal).

  • recurrent_kernel_init – initializer function for the kernels that transform the hidden state (default: initializers.orthogonal()).

  • bias_init – initializer for the bias parameters (default: initializers.zeros_init()).

  • dtype – the dtype of the computation (default: infer from inputs and params).

  • param_dtype – the dtype passed to parameter initializers (default: float32).

__call__(carry, inputs)[source]#

An optimized long short-term memory (LSTM) cell.

  • carry – the hidden state of the LSTM cell, initialized using LSTMCell.initialize_carry.

  • inputs – an ndarray with the input for the current time step. All dimensions except the final are considered batch dimensions.


A tuple with the new carry and the output.

initialize_carry(input_shape, rngs=None)[source]#

Initialize the RNN cell carry.

  • rngs – random number generator passed to the init_fn.

  • input_shape – a tuple providing the shape of the input to the cell.


An initialized carry for the given RNN cell.



class flax.nnx.nn.recurrent.SimpleCell(*args, **kwargs)[source]#

Simple cell.

The mathematical definition of the cell is as follows:

\[\begin{array}{ll} h' = \tanh(W_i x + b_i + W_h h) \end{array}\]

where x is the input and h is the output of the previous time step.

If residual is True,

\[\begin{array}{ll} h' = \tanh(W_i x + b_i + W_h h + h) \end{array}\]
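
The two variants above differ only in whether h is added before the tanh. A minimal sketch in plain JAX, with hypothetical argument names and the row-vector convention (`x @ W_i`) assumed:

```python
import jax.numpy as jnp


def simple_step(W_i, b_i, W_h, h, x, residual=False):
    """One step of the simple cell; adds h inside the tanh when residual is True."""
    pre = x @ W_i + b_i + h @ W_h
    if residual:
        pre = pre + h
    new_h = jnp.tanh(pre)
    return new_h, new_h  # (new carry, output)


# Zero kernels isolate the residual term: only h can reach the tanh.
W_i = jnp.zeros((3, 4))
b_i = jnp.zeros((4,))
W_h = jnp.zeros((4, 4))
h = jnp.full((2, 4), 0.5)
x = jnp.ones((2, 3))

plain, _ = simple_step(W_i, b_i, W_h, h, x)
with_residual, _ = simple_step(W_i, b_i, W_h, h, x, residual=True)
```

With zero kernels, `plain` is all zeros while `with_residual` equals tanh(h), which shows the extra path the residual option adds.
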
__call__(carry, inputs)[source]#

Run the RNN cell.

  • carry – the hidden state of the RNN cell.

  • inputs – an ndarray with the input for the current time step. All dimensions except the final are considered batch dimensions.


A tuple with the new carry and the output.

initialize_carry(input_shape, rngs=None)[source]#

Initialize the RNN cell carry.

  • rngs – random number generator passed to the init_fn.

  • input_shape – a tuple providing the shape of the input to the cell.


An initialized carry for the given RNN cell.



class flax.nnx.nn.recurrent.GRUCell(*args, **kwargs)[source]#

GRU cell.

The mathematical definition of the cell is as follows:

\[\begin{split}\begin{array}{ll} r = \sigma(W_{ir} x + b_{ir} + W_{hr} h) \\ z = \sigma(W_{iz} x + b_{iz} + W_{hz} h) \\ n = \tanh(W_{in} x + b_{in} + r * (W_{hn} h + b_{hn})) \\ h' = (1 - z) * n + z * h \\ \end{array}\end{split}\]

where x is the input and h is the output of the previous time step.
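The GRU equations above translate into a short step function. The sketch below is written in plain JAX under assumed names (a `p` dictionary keyed like the symbols in the equations, and a hypothetical `zero_params` helper). It is meant to illustrate the update rule, not to mirror the Flax implementation.

```python
import jax
import jax.numpy as jnp


def gru_step(p, h, x):
    """One GRU step following the equations above."""
    r = jax.nn.sigmoid(x @ p["W_ir"] + p["b_ir"] + h @ p["W_hr"])  # reset gate
    z = jax.nn.sigmoid(x @ p["W_iz"] + p["b_iz"] + h @ p["W_hz"])  # update gate
    n = jnp.tanh(x @ p["W_in"] + p["b_in"] + r * (h @ p["W_hn"] + p["b_hn"]))
    new_h = (1.0 - z) * n + z * h
    return new_h, new_h  # (new carry, output)


def zero_params(in_features, hidden_features):
    """Hypothetical helper: all-zero kernels and biases."""
    p = {}
    for name in "rzn":
        p["W_i" + name] = jnp.zeros((in_features, hidden_features))
        p["W_h" + name] = jnp.zeros((hidden_features, hidden_features))
        p["b_i" + name] = jnp.zeros((hidden_features,))
    p["b_hn"] = jnp.zeros((hidden_features,))
    return p


p = zero_params(3, 4)
h = jnp.ones((2, 4))
x = jnp.ones((2, 3))
half_kept, _ = gru_step(p, h, x)  # z = 0.5 and n = 0, so h' = 0.5 * h

# A large update-gate bias drives z toward 1, so the old state is kept.
p_keep = dict(p, b_iz=jnp.full((4,), 100.0))
kept, _ = gru_step(p_keep, h, x)
```

The second call shows the role of z in the last equation: as z approaches 1, h' approaches the previous h.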


  • in_features – number of input features.

  • hidden_features – number of output features.

  • gate_fn – activation function used for gates (default: sigmoid).

  • activation_fn – activation function used for output and memory update (default: tanh).

  • kernel_init – initializer function for the kernels that transform the input (default: lecun_normal).

  • recurrent_kernel_init – initializer function for the kernels that transform the hidden state (default: initializers.orthogonal()).

  • bias_init – initializer for the bias parameters (default: initializers.zeros_init()).

  • dtype – the dtype of the computation (default: None).

  • param_dtype – the dtype passed to parameter initializers (default: float32).

__call__(carry, inputs)[source]#

Gated recurrent unit (GRU) cell.

  • carry – the hidden state of the GRU cell, initialized using GRUCell.initialize_carry.

  • inputs – an ndarray with the input for the current time step. All dimensions except the final are considered batch dimensions.


A tuple with the new carry and the output.

initialize_carry(input_shape, rngs=None)[source]#

Initialize the RNN cell carry.

  • rngs – random number generator passed to the init_fn.

  • input_shape – a tuple providing the shape of the input to the cell.


An initialized carry for the given RNN cell.



class flax.nnx.nn.recurrent.RNN(*args, **kwargs)[source]#

The RNN module takes any RNNCellBase instance and applies it over a sequence using flax.nnx.scan().
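Applying a cell over a sequence with a scan can be sketched in plain JAX. The example below uses `jax.lax.scan` in place of flax.nnx.scan(), with a hand-written tanh cell. The names `cell_step` and `run_rnn` and the batch-major input layout are assumptions for illustration.

```python
import jax
import jax.numpy as jnp


def cell_step(params, h, x):
    """A minimal tanh cell: returns (new carry, output)."""
    W_i, W_h, b = params
    new_h = jnp.tanh(x @ W_i + h @ W_h + b)
    return new_h, new_h


def run_rnn(params, inputs, initial_carry):
    """inputs: [batch, time, features]; returns (final carry, outputs)."""
    # lax.scan iterates over the leading axis, so move time to the front.
    xs = jnp.swapaxes(inputs, 0, 1)
    final_carry, ys = jax.lax.scan(
        lambda h, x: cell_step(params, h, x), initial_carry, xs
    )
    # Move time back to axis 1 so outputs are batch-major again.
    return final_carry, jnp.swapaxes(ys, 0, 1)


k1, k2 = jax.random.split(jax.random.PRNGKey(0))
params = (
    0.1 * jax.random.normal(k1, (3, 4)),
    0.1 * jax.random.normal(k2, (4, 4)),
    jnp.zeros((4,)),
)
inputs = jnp.ones((2, 5, 3))  # batch of 2, 5 time steps, 3 features
final_carry, outputs = run_rnn(params, inputs, jnp.zeros((2, 4)))
```

The scan threads the carry from one time step to the next and stacks the per-step outputs, so `outputs` has one entry per time step and `final_carry` matches the last of them for this cell.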

__call__(inputs, *, initial_carry=None, seq_lengths=None, return_carry=None, time_major=None, reverse=None, keep_order=None, rngs=None)[source]#

Call self as a function.


class flax.nnx.nn.recurrent.Bidirectional(*args, **kwargs)[source]#

Processes the input in both directions and merges the results.

Example usage:

>>> from flax import nnx
>>> from flax.nnx.nn.recurrent import RNN, GRUCell, Bidirectional
>>> import jax
>>> import jax.numpy as jnp

>>> # Define forward and backward RNNs
>>> forward_rnn = RNN(GRUCell(in_features=3, hidden_features=4, rngs=nnx.Rngs(0)))
>>> backward_rnn = RNN(GRUCell(in_features=3, hidden_features=4, rngs=nnx.Rngs(0)))

>>> # Create Bidirectional layer
>>> layer = Bidirectional(forward_rnn=forward_rnn, backward_rnn=backward_rnn)

>>> # Input data
>>> x = jnp.ones((2, 3, 3))

>>> # Apply the layer
>>> out = layer(x)
>>> print(out.shape)
(2, 3, 8)
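
The last dimension of 8 in the example is consistent with concatenating two 4-feature outputs. The idea can be sketched in plain JAX, assuming a hand-written tanh cell, `jax.lax.scan`, and concatenation along the feature axis as the merge step. The function `run_direction` is hypothetical and not part of Flax.

```python
import jax
import jax.numpy as jnp


def run_direction(params, inputs, reverse=False):
    """Scan a tanh cell over [batch, time, features] inputs in one direction."""
    W_i, W_h = params

    def step(h, x):
        new_h = jnp.tanh(x @ W_i + h @ W_h)
        return new_h, new_h

    h0 = jnp.zeros((inputs.shape[0], W_h.shape[0]))
    xs = jnp.swapaxes(inputs, 0, 1)
    # reverse=True walks the time axis backwards but keeps outputs in input order.
    _, ys = jax.lax.scan(step, h0, xs, reverse=reverse)
    return jnp.swapaxes(ys, 0, 1)


keys = jax.random.split(jax.random.PRNGKey(0), 4)
fwd = (0.1 * jax.random.normal(keys[0], (3, 4)), 0.1 * jax.random.normal(keys[1], (4, 4)))
bwd = (0.1 * jax.random.normal(keys[2], (3, 4)), 0.1 * jax.random.normal(keys[3], (4, 4)))
x = jnp.ones((2, 3, 3))

# Merge the two directions by concatenating along the feature axis.
out = jnp.concatenate(
    [run_direction(fwd, x), run_direction(bwd, x, reverse=True)], axis=-1
)
```
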
__call__(inputs, *, initial_carry=None, rngs=None, seq_lengths=None, return_carry=None, time_major=None, reverse=None, keep_order=None)[source]#

Call self as a function.


flax.nnx.nn.recurrent.flip_sequences(inputs, seq_lengths, num_batch_dims, time_major)[source]#

Flips a sequence of inputs along the time axis.

This function can be used to prepare inputs for the reverse direction of a bidirectional LSTM. It solves the issue that, when naively flipping multiple padded sequences stored in a matrix, the first elements would be padding values for those sequences that were padded. This function keeps the padding at the end, while flipping the rest of the elements.


>>> from flax.nnx.nn.recurrent import flip_sequences
>>> from jax import numpy as jnp
>>> inputs = jnp.array([[1, 0, 0], [2, 3, 0], [4, 5, 6]])
>>> lengths = jnp.array([1, 2, 3])
>>> flip_sequences(inputs, lengths, 1, False)
Array([[1, 0, 0],
       [3, 2, 0],
       [6, 5, 4]], dtype=int32)
  • inputs – An array of input IDs <int>[batch_size, seq_length].

  • seq_lengths – The length of each sequence <int>[batch_size].


An ndarray with the flipped inputs.
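
The padding-aware flip can be sketched with an index computation in plain JAX. This is one possible formulation for 2-D, batch-major inputs and not necessarily how flip_sequences is implemented. The name `flip_padded` is hypothetical.

```python
import jax.numpy as jnp


def flip_padded(inputs, seq_lengths):
    """Reverse each row's first seq_lengths[b] entries; padding stays at the end."""
    max_len = inputs.shape[1]
    t = jnp.arange(max_len)[None, :]    # [1, seq_length]
    lengths = seq_lengths[:, None]      # [batch_size, 1]
    # Valid positions read from index length - 1 - t; padded positions keep index t.
    src = jnp.where(t < lengths, lengths - 1 - t, t)
    return jnp.take_along_axis(inputs, src, axis=1)


inputs = jnp.array([[1, 0, 0], [2, 3, 0], [4, 5, 6]])
lengths = jnp.array([1, 2, 3])
flipped = flip_padded(inputs, lengths)
```

On the same inputs as the example above, `flipped` reverses only the valid prefix of each row and leaves the trailing padding in place.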