Recurrent
RNN modules for Flax.
- class flax.nnx.nn.recurrent.LSTMCell(in_features, hidden_features, *, gate_fn=sigmoid, activation_fn=tanh, kernel_init=lecun_normal(), recurrent_kernel_init=modified_orthogonal, bias_init=zeros, dtype=None, param_dtype=jnp.float32, carry_init=zeros, rngs)
LSTM cell.
The mathematical definition of the cell is as follows:
\[
\begin{array}{ll}
i = \sigma(W_{ii} x + W_{hi} h + b_{hi}) \\
f = \sigma(W_{if} x + W_{hf} h + b_{hf}) \\
g = \tanh(W_{ig} x + W_{hg} h + b_{hg}) \\
o = \sigma(W_{io} x + W_{ho} h + b_{ho}) \\
c' = f * c + i * g \\
h' = o * \tanh(c')
\end{array}
\]
where x is the input, h is the output of the previous time step, c is the memory, and * denotes elementwise multiplication. Here σ is gate_fn (sigmoid by default) and tanh stands for activation_fn (tanh by default).
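The equations above can be sketched as a plain jax.numpy function. This is a minimal illustration, not Flax's implementation: the dictionary layout of the weights and the helper names `init_params` and `lstm_step` are hypothetical, inputs are row vectors (so W x is written `x @ W`), and a single bias `b_k` per gate plays the role of b_{hk}.

```python
import jax
import jax.numpy as jnp

GATES = "ifgo"  # input, forget, candidate (g), output


def init_params(key, in_features, hidden):
    # Hypothetical layout: one input kernel, one recurrent kernel and one bias per gate.
    keys = jax.random.split(key, 2 * len(GATES))
    params = {}
    for n, k in enumerate(GATES):
        params[f"W_i{k}"] = 0.1 * jax.random.normal(keys[2 * n], (in_features, hidden))
        params[f"W_h{k}"] = 0.1 * jax.random.normal(keys[2 * n + 1], (hidden, hidden))
        params[f"b_{k}"] = jnp.zeros((hidden,))
    return params


def lstm_step(params, carry, x):
    c, h = carry

    def pre(k):
        # W_{ik} x + W_{hk} h + b_k, written with row vectors (x @ W)
        return x @ params[f"W_i{k}"] + h @ params[f"W_h{k}"] + params[f"b_{k}"]

    i = jax.nn.sigmoid(pre("i"))  # input gate
    f = jax.nn.sigmoid(pre("f"))  # forget gate
    g = jnp.tanh(pre("g"))        # candidate memory
    o = jax.nn.sigmoid(pre("o"))  # output gate
    new_c = f * c + i * g         # c' = f * c + i * g
    new_h = o * jnp.tanh(new_c)   # h' = o * tanh(c')
    return (new_c, new_h), new_h


params = init_params(jax.random.PRNGKey(0), in_features=3, hidden=4)
x = jnp.ones((2, 3))                            # (batch, in_features)
carry = (jnp.zeros((2, 4)), jnp.zeros((2, 4)))  # (c, h), each (batch, hidden)
(c1, h1), y1 = lstm_step(params, carry, x)

# With all-zero parameters every gate is sigmoid(0) = 0.5 and g = 0,
# so c' = 0.5 * c and h' = 0.5 * tanh(c').
zero_params = {name: jnp.zeros_like(v) for name, v in params.items()}
(c2, h2), _ = lstm_step(zero_params, (jnp.ones((2, 4)), jnp.zeros((2, 4))), x)
```

The zero-parameter case is a quick sanity check: it isolates the forget path, so the new memory is exactly half the old one.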
- __call__(carry, inputs)
A long short-term memory (LSTM) cell.
- Parameters
carry – the hidden state of the LSTM cell, a (c, h) tuple initialized using LSTMCell.initialize_carry.
inputs – an ndarray with the input for the current time step. All dimensions except the final are considered batch dimensions.
- Returns
A tuple with the new carry and the output.
- initialize_carry(input_shape, rngs=None)
Initialize the RNN cell carry.
- Parameters
rngs – random number generator passed to the init_fn.
input_shape – a tuple providing the shape of the input to the cell.
- Returns
An initialized carry for the given RNN cell.
Methods
initialize_carry(input_shape[, rngs]) – Initialize the RNN cell carry.
- class flax.nnx.nn.recurrent.OptimizedLSTMCell(in_features, hidden_features, *, gate_fn=sigmoid, activation_fn=tanh, kernel_init=lecun_normal(), recurrent_kernel_init=orthogonal(), bias_init=zeros, dtype=None, param_dtype=jnp.float32, carry_init=zeros, rngs)
More efficient LSTM cell that concatenates state components before the matrix multiplication.
The parameters are compatible with LSTMCell. Note that this cell is often faster than LSTMCell as long as the hidden size is roughly <= 2048 units.
The mathematical definition of the cell is the same as that of LSTMCell:
\[
\begin{array}{ll}
i = \sigma(W_{ii} x + W_{hi} h + b_{hi}) \\
f = \sigma(W_{if} x + W_{hf} h + b_{hf}) \\
g = \tanh(W_{ig} x + W_{hg} h + b_{hg}) \\
o = \sigma(W_{io} x + W_{ho} h + b_{ho}) \\
c' = f * c + i * g \\
h' = o * \tanh(c')
\end{array}
\]
where x is the input, h is the output of the previous time step, and c is the memory.
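One way to read "concatenates state components before the matrix multiplication" is to stack the four per-gate kernels side by side, so that one wide matmul yields all gate pre-activations, which are then split. The sketch below works under that assumption with hypothetical names (`fused_step`, `per_gate_step`, `finish`) and checks that the fused form computes the same c' and h' as four separate per-gate products. It is not meant to mirror how OptimizedLSTMCell stores its parameters.

```python
import jax
import jax.numpy as jnp

IN, HID, BATCH = 3, 4, 2
k_wi, k_wh, k_x, k_h, k_c = jax.random.split(jax.random.PRNGKey(0), 5)

# Hypothetical fused kernels: the i, f, g, o kernels stacked along the output axis.
W_i = jax.random.normal(k_wi, (IN, 4 * HID))
W_h = jax.random.normal(k_wh, (HID, 4 * HID))
b = jnp.zeros((4 * HID,))

x = jax.random.normal(k_x, (BATCH, IN))
h = jnp.tanh(jax.random.normal(k_h, (BATCH, HID)))
c = jax.random.normal(k_c, (BATCH, HID))


def finish(zi, zf, zg, zo, c):
    # Shared tail of the LSTM equations once the four pre-activations are known.
    i, f, o = jax.nn.sigmoid(zi), jax.nn.sigmoid(zf), jax.nn.sigmoid(zo)
    new_c = f * c + i * jnp.tanh(zg)
    return new_c, o * jnp.tanh(new_c)


def fused_step(c, h, x):
    # One wide matmul per operand, then split into the four gates (i, f, g, o).
    z = x @ W_i + h @ W_h + b
    return finish(*jnp.split(z, 4, axis=-1), c)


def per_gate_step(c, h, x):
    # Four narrow matmuls per operand, one kernel slice per gate.
    def pre(n):
        sl = slice(n * HID, (n + 1) * HID)
        return x @ W_i[:, sl] + h @ W_h[:, sl] + b[sl]

    return finish(pre(0), pre(1), pre(2), pre(3), c)


c_fused, h_fused = fused_step(c, h, x)
c_sep, h_sep = per_gate_step(c, h, x)
```

Fewer, larger matmuls generally map better onto accelerators, which is the usual motivation for this kind of fusion.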
- Parameters
in_features – number of input features.
hidden_features – number of output features.
gate_fn – activation function used for gates (default: sigmoid).
activation_fn – activation function used for output and memory update (default: tanh).
kernel_init – initializer function for the kernels that transform the input (default: lecun_normal).
recurrent_kernel_init – initializer function for the kernels that transform the hidden state (default: initializers.orthogonal()).
bias_init – initializer for the bias parameters (default: initializers.zeros_init()).
dtype – the dtype of the computation (default: infer from inputs and params).
param_dtype – the dtype passed to parameter initializers (default: float32).
carry_init – initializer for the carry (default: zeros).
rngs – random number generators used to initialize the parameters.
- __call__(carry, inputs)
An optimized long short-term memory (LSTM) cell.
- Parameters
carry – the hidden state of the LSTM cell, a (c, h) tuple initialized using OptimizedLSTMCell.initialize_carry.
inputs – an ndarray with the input for the current time step. All dimensions except the final are considered batch dimensions.
- Returns
A tuple with the new carry and the output.
- initialize_carry(input_shape, rngs=None)
Initialize the RNN cell carry.
- Parameters
rngs – random number generator passed to the init_fn.
input_shape – a tuple providing the shape of the input to the cell.
- Returns
An initialized carry for the given RNN cell.
Methods
initialize_carry(input_shape[, rngs]) – Initialize the RNN cell carry.
- class flax.nnx.nn.recurrent.SimpleCell(in_features, hidden_features, *, dtype=jnp.float32, param_dtype=jnp.float32, carry_init=zeros, residual=False, activation_fn=tanh, kernel_init=lecun_normal(), recurrent_kernel_init=orthogonal(), bias_init=zeros, rngs)
Simple cell.
The mathematical definition of the cell is as follows:
\[
h' = \tanh(W_i x + b_i + W_h h)
\]
where x is the input, h is the output of the previous time step, and tanh stands for activation_fn (tanh by default).
If residual is True, the previous state is added before the activation:
\[
h' = \tanh(W_i x + b_i + W_h h + h)
\]
- __call__(carry, inputs)
Run the RNN cell.
- Parameters
carry – the hidden state of the RNN cell.
inputs – an ndarray with the input for the current time step. All dimensions except the final are considered batch dimensions.
- Returns
A tuple with the new carry and the output.
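The two SimpleCell equations differ only in the extra + h term. A minimal jax.numpy sketch, with a hypothetical `simple_step` helper and row-vector inputs (`x @ W`), uses zero weights so that the residual term is the only nonzero contribution:

```python
import jax.numpy as jnp


def simple_step(W_i, b_i, W_h, h, x, residual=False):
    # Pre-activation W_i x + b_i + W_h h, written with row vectors (x @ W_i).
    z = x @ W_i + b_i + h @ W_h
    if residual:
        z = z + h  # residual variant: add the previous state before tanh
    new_h = jnp.tanh(z)
    return new_h, new_h  # the output is the new carry


# Zero weights isolate the residual term.
W_i, b_i, W_h = jnp.zeros((3, 4)), jnp.zeros((4,)), jnp.zeros((4, 4))
h0 = jnp.full((2, 4), 0.5)  # previous state, (batch, hidden)
x = jnp.ones((2, 3))        # input, (batch, in_features)
plain, _ = simple_step(W_i, b_i, W_h, h0, x)                # tanh(0) everywhere
res, _ = simple_step(W_i, b_i, W_h, h0, x, residual=True)   # tanh(0.5) everywhere
```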
- initialize_carry(input_shape, rngs=None)
Initialize the RNN cell carry.
- Parameters
rngs – random number generator passed to the init_fn.
input_shape – a tuple providing the shape of the input to the cell.
- Returns
An initialized carry for the given RNN cell.
Methods
initialize_carry(input_shape[, rngs]) – Initialize the RNN cell carry.
- class flax.nnx.nn.recurrent.GRUCell(in_features, hidden_features, *, gate_fn=sigmoid, activation_fn=tanh, kernel_init=lecun_normal(), recurrent_kernel_init=orthogonal(), bias_init=zeros, dtype=None, param_dtype=jnp.float32, carry_init=zeros, rngs)
GRU cell.
The mathematical definition of the cell is as follows:
\[
\begin{array}{ll}
r = \sigma(W_{ir} x + b_{ir} + W_{hr} h) \\
z = \sigma(W_{iz} x + b_{iz} + W_{hz} h) \\
n = \tanh(W_{in} x + b_{in} + r * (W_{hn} h + b_{hn})) \\
h' = (1 - z) * n + z * h
\end{array}
\]
where x is the input and h is the output of the previous time step.
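A minimal jax.numpy sketch of the GRU equations above, assuming a hypothetical dictionary of per-gate weights (`SHAPES`, `gru_step` are illustrative names) and row-vector inputs (`x @ W`). Note that the reset gate r multiplies the whole recurrent term, including b_{hn}.

```python
import jax
import jax.numpy as jnp

IN, HID = 3, 4
SHAPES = {
    "W_ir": (IN, HID), "W_iz": (IN, HID), "W_in": (IN, HID),
    "W_hr": (HID, HID), "W_hz": (HID, HID), "W_hn": (HID, HID),
    "b_ir": (HID,), "b_iz": (HID,), "b_in": (HID,), "b_hn": (HID,),
}


def gru_step(p, h, x):
    r = jax.nn.sigmoid(x @ p["W_ir"] + p["b_ir"] + h @ p["W_hr"])  # reset gate
    z = jax.nn.sigmoid(x @ p["W_iz"] + p["b_iz"] + h @ p["W_hz"])  # update gate
    # r scales the whole recurrent term, including its bias b_hn
    n = jnp.tanh(x @ p["W_in"] + p["b_in"] + r * (h @ p["W_hn"] + p["b_hn"]))
    new_h = (1.0 - z) * n + z * h  # interpolate between candidate and old state
    return new_h, new_h


x = jnp.ones((2, IN))
h0 = jnp.ones((2, HID))

# With all-zero parameters r = z = 0.5 and n = 0, so h' = 0.5 * h.
zero_params = {k: jnp.zeros(s) for k, s in SHAPES.items()}
h_zero, _ = gru_step(zero_params, h0, x)

# Random parameters: h' is a convex combination of n in (-1, 1) and h = 1.
keys = jax.random.split(jax.random.PRNGKey(0), len(SHAPES))
rand_params = {k: jax.random.normal(key, s) for (k, s), key in zip(SHAPES.items(), keys)}
h_rand, _ = gru_step(rand_params, h0, x)
```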
- Parameters
in_features – number of input features.
hidden_features – number of output features.
gate_fn – activation function used for gates (default: sigmoid).
activation_fn – activation function used for output and memory update (default: tanh).
kernel_init – initializer function for the kernels that transform the input (default: lecun_normal).
recurrent_kernel_init – initializer function for the kernels that transform the hidden state (default: initializers.orthogonal()).
bias_init – initializer for the bias parameters (default: initializers.zeros_init()).
dtype – the dtype of the computation (default: None).
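param_dtype – the dtype passed to parameter initializers (default: float32).
carry_init – initializer for the carry (default: zeros).
rngs – random number generators used to initialize the parameters.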
- __call__(carry, inputs)
Gated recurrent unit (GRU) cell.
- Parameters
carry – the hidden state of the GRU cell, a single array initialized using GRUCell.initialize_carry.
inputs – an ndarray with the input for the current time step. All dimensions except the final are considered batch dimensions.
- Returns
A tuple with the new carry and the output.
- initialize_carry(input_shape, rngs=None)
Initialize the RNN cell carry.
- Parameters
rngs – random number generator passed to the init_fn.
input_shape – a tuple providing the shape of the input to the cell.
- Returns
An initialized carry for the given RNN cell.
Methods
initialize_carry(input_shape[, rngs]) – Initialize the RNN cell carry.
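- class flax.nnx.nn.recurrent.RNN(cell, time_major=False, return_carry=False, reverse=False, keep_order=False, unroll=1, rngs=None, state_axes=None, broadcast_rngs=None)
The RNN module takes any RNNCellBase instance and applies it over a sequence using flax.nnx.scan().
- __call__(inputs, *, initial_carry=None, seq_lengths=None, return_carry=None, time_major=None, reverse=None, keep_order=None, rngs=None)
Apply the cell over the time axis of inputs. With time_major=False, inputs has shape (*batch, time, features); with time_major=True, the time axis comes first. If initial_carry is None, the carry is created with the cell's initialize_carry. Arguments left as None (return_carry, time_major, reverse, keep_order) fall back to the values passed to the constructor. If return_carry is True, the call returns a (carry, outputs) tuple; otherwise it returns only the outputs.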
- class flax.nnx.nn.recurrent.Bidirectional(forward_rnn, backward_rnn, *, merge_fn=_concatenate, time_major=False, return_carry=False, rngs=None)
Processes the input in both directions and merges the results.
Example usage:
>>> from flax import nnx
>>> import jax
>>> import jax.numpy as jnp
>>> from flax.nnx.nn.recurrent import RNN, GRUCell, Bidirectional
>>> # Define forward and backward RNNs
>>> forward_rnn = RNN(GRUCell(in_features=3, hidden_features=4, rngs=nnx.Rngs(0)))
>>> backward_rnn = RNN(GRUCell(in_features=3, hidden_features=4, rngs=nnx.Rngs(0)))
>>> # Create Bidirectional layer
>>> layer = Bidirectional(forward_rnn=forward_rnn, backward_rnn=backward_rnn)
>>> # Input data: (batch, time, features)
>>> x = jnp.ones((2, 3, 3))
>>> # Apply the layer
>>> out = layer(x)
>>> print(out.shape)
(2, 3, 8)
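- __call__(inputs, *, initial_carry=None, rngs=None, seq_lengths=None, return_carry=None, time_major=None, reverse=None, keep_order=None)
Run forward_rnn over the sequence and backward_rnn over it in reverse, then combine their outputs with merge_fn. The default merge_fn concatenates the two outputs along the last axis, which is why the example above produces 4 + 4 = 8 features.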
- flax.nnx.nn.recurrent.flip_sequences(inputs, seq_lengths, num_batch_dims, time_major)
Flips a sequence of inputs along the time axis.
This function can be used to prepare inputs for the reverse direction of a bidirectional LSTM. Naively flipping a matrix of padded sequences along the time axis would move the padding to the front of every sequence that is shorter than the maximum length. This function instead reverses only the valid elements of each sequence and keeps the padding at the end.
Example:
>>> from flax.nnx.nn.recurrent import flip_sequences
>>> from jax import numpy as jnp
>>> inputs = jnp.array([[1, 0, 0], [2, 3, 0], [4, 5, 6]])
>>> lengths = jnp.array([1, 2, 3])
>>> flip_sequences(inputs, lengths, 1, False)
Array([[1, 0, 0],
       [3, 2, 0],
       [6, 5, 4]], dtype=int32)
- Parameters
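inputs – An array of input IDs <int>[batch_size, seq_length].
seq_lengths – The length of each sequence <int>[batch_size].
num_batch_dims – The number of batch dimensions of inputs.
time_major – Whether the time axis comes first in inputs; if False, it follows the batch dimensions.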
- Returns
An ndarray with the flipped inputs.
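The index arithmetic behind this flip can be sketched directly for batch-major 2-D inputs: position t of a sequence with length L reads from index L - 1 - t when t < L, and stays in place otherwise. `flip_padded` is a hypothetical helper that reproduces the example above. It is not necessarily how flip_sequences is implemented, and it ignores the num_batch_dims and time_major options.

```python
import jax.numpy as jnp


def flip_padded(inputs, seq_lengths):
    """Reverse the first seq_lengths[b] steps of each row; padding stays at the end.

    Sketch for inputs of shape (batch, time) only.
    """
    max_steps = inputs.shape[1]
    t = jnp.arange(max_steps)[None, :]  # (1, time)
    lengths = seq_lengths[:, None]      # (batch, 1)
    # Valid positions read from length - 1 - t; padded positions keep their index.
    idx = jnp.where(t < lengths, lengths - 1 - t, t)
    return jnp.take_along_axis(inputs, idx, axis=1)


inputs = jnp.array([[1, 0, 0], [2, 3, 0], [4, 5, 6]])
lengths = jnp.array([1, 2, 3])
flipped = flip_padded(inputs, lengths)
```

Applying the flip twice restores the original inputs, which is what lets a reverse-direction RNN map its outputs back to the original time order.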