flax.linen.OptimizedLSTMCell

class flax.linen.OptimizedLSTMCell(gate_fn=<CompiledFunction of <function sigmoid>>, activation_fn=<CompiledFunction of <function jax.numpy.tanh>>, kernel_init=<function variance_scaling.<locals>.init>, recurrent_kernel_init=<function orthogonal.<locals>.init>, bias_init=<function zeros>, dtype=None, param_dtype=<class 'jax.numpy.float32'>, parent=<flax.linen.module._Sentinel object>, name=None)

More efficient LSTM Cell that concatenates state components before matmul.

The parameters are compatible with LSTMCell. This cell is often faster than LSTMCell as long as the hidden size is at most roughly 2048 units.
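The idea of concatenating before the matmul can be illustrated in plain JAX. The sketch below is not the Flax implementation; it only shows, under assumed sizes and a hypothetical per-gate kernel layout, that one matmul against concatenated gate kernels gives the same result as four separate matmuls.

```python
import jax
import jax.numpy as jnp

# Hypothetical sizes, chosen only for illustration.
batch, in_dim, hidden = 2, 3, 4
keys = jax.random.split(jax.random.PRNGKey(0), 5)

x = jax.random.normal(keys[0], (batch, in_dim))

# One input kernel per gate (i, f, g, o); this layout is an assumption.
gate_kernels = [jax.random.normal(k, (in_dim, hidden)) for k in keys[1:5]]

# Unfused: four separate matmuls, one per gate.
separate = [x @ w for w in gate_kernels]

# Fused: concatenate the kernels along the output axis,
# do a single matmul, then split the result back into four gates.
fused_kernel = jnp.concatenate(gate_kernels, axis=-1)  # (in_dim, 4 * hidden)
fused = jnp.split(x @ fused_kernel, 4, axis=-1)

# Both routes should agree up to floating-point tolerance.
fused_matches = all(
    bool(jnp.allclose(a, b, atol=1e-5)) for a, b in zip(separate, fused)
)
```

A single larger matmul tends to use the hardware better than several small ones, which is consistent with the note above that the benefit depends on the hidden size.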

The mathematical definition of the cell is the same as that of LSTMCell:

\[\begin{split}\begin{array}{ll} i = \sigma(W_{ii} x + W_{hi} h + b_{hi}) \\ f = \sigma(W_{if} x + W_{hf} h + b_{hf}) \\ g = \tanh(W_{ig} x + W_{hg} h + b_{hg}) \\ o = \sigma(W_{io} x + W_{ho} h + b_{ho}) \\ c' = f * c + i * g \\ h' = o * \tanh(c') \\ \end{array}\end{split}\]

where x is the input, h is the output of the previous time step, and c is the memory.
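The equations above can be sketched as a single step function in plain JAX. This is a minimal illustration, not the library's code: the `params` dictionary layout, the `lstm_step` name, and the `(c, h)` ordering of the carry are assumptions made for the example. Inputs are treated as row vectors, so `W x` in the equations becomes `x @ W` here.

```python
import jax
import jax.numpy as jnp

def lstm_step(params, carry, x):
    """One LSTM step following the equations above.

    `params` is a hypothetical dict with keys W_i*, W_h*, b_h* for * in i, f, g, o.
    The carry is assumed to be the pair (c, h).
    """
    c, h = carry

    def pre(gate):
        # Pre-activation W_{i*} x + W_{h*} h + b_{h*}, written for row vectors.
        return x @ params["W_i" + gate] + h @ params["W_h" + gate] + params["b_h" + gate]

    i = jax.nn.sigmoid(pre("i"))   # input gate
    f = jax.nn.sigmoid(pre("f"))   # forget gate
    g = jnp.tanh(pre("g"))         # candidate memory
    o = jax.nn.sigmoid(pre("o"))   # output gate
    new_c = f * c + i * g
    new_h = o * jnp.tanh(new_c)
    return (new_c, new_h), new_h

# With all-zero parameters every gate is sigmoid(0) = 0.5 and g = tanh(0) = 0,
# so c' = 0.5 * c and h' = 0.5 * tanh(c').
in_dim, hidden = 3, 4
params = {}
for gate in "ifgo":
    params["W_i" + gate] = jnp.zeros((in_dim, hidden))
    params["W_h" + gate] = jnp.zeros((hidden, hidden))
    params["b_h" + gate] = jnp.zeros((hidden,))

carry = (jnp.ones((2, hidden)), jnp.zeros((2, hidden)))
(new_c, new_h), out = lstm_step(params, carry, jnp.ones((2, in_dim)))
```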

gate_fn

activation function used for gates (default: sigmoid).

Type

Callable[[…], Any]

activation_fn

activation function used for output and memory update (default: tanh).

Type

Callable[[…], Any]

kernel_init

initializer function for the kernels that transform the input (default: lecun_normal).

Type

Callable[[Any, Tuple[int, …], Any], Any]

recurrent_kernel_init

initializer function for the kernels that transform the hidden state (default: orthogonal).

Type

Callable[[Any, Tuple[int, …], Any], Any]

bias_init

initializer for the bias parameters (default: zeros).

Type

Callable[[Any, Tuple[int, …], Any], Any]

dtype

the dtype of the computation (default: infer from inputs and params).

Type

Optional[Any]

param_dtype

the dtype passed to parameter initializers (default: float32).

Type

Any

__call__(carry, inputs)

An optimized long short-term memory (LSTM) cell.

Parameters
  • carry – the hidden state of the LSTM cell, initialized using LSTMCell.initialize_carry.

  • inputs – an ndarray with the input for the current time step. All dimensions except the final are considered batch dimensions.

Returns

A tuple with the new carry and the output.
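The `(carry, inputs) -> (new carry, output)` contract matches the step signature expected by `jax.lax.scan`, which makes it straightforward to run a cell over a whole sequence. The sketch below uses a hypothetical `toy_cell` as a stand-in with the same call shape; it is not the LSTM update, and the time-major layout and sizes are assumptions for illustration.

```python
import jax
import jax.numpy as jnp

def toy_cell(carry, inputs):
    # Stand-in with the same call shape as a cell's __call__.
    # This is NOT the LSTM update; it only mimics the (carry, inputs) contract.
    c, h = carry
    new_c = 0.5 * c + inputs
    new_h = jnp.tanh(new_c)
    return (new_c, new_h), new_h

# Hypothetical sizes; the sequence is laid out time-major: (steps, batch, size).
steps, batch, size = 5, 2, 3
xs = jnp.ones((steps, batch, size))
init_carry = (jnp.zeros((batch, size)), jnp.zeros((batch, size)))

# scan threads the carry through time and stacks the per-step outputs.
final_carry, outputs = jax.lax.scan(toy_cell, init_carry, xs)
```

With an actual Flax module, the parameters would also have to be threaded through the step, for example by closing over them in a function that calls `cell.apply`.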

Methods

initialize_carry(rng, batch_dims, size[, ...])

Initialize the RNN cell carry.