flax.linen.LSTMCell

class flax.linen.LSTMCell(gate_fn=sigmoid, activation_fn=tanh, kernel_init=lecun_normal(), recurrent_kernel_init=orthogonal(), bias_init=zeros, dtype=None, param_dtype=jax.numpy.float32, parent=<flax.linen.module._Sentinel object>, name=None)

LSTM cell.

The mathematical definition of the cell is as follows:

\[\begin{array}{ll}
i = \sigma(W_{ii} x + W_{hi} h + b_{hi}) \\
f = \sigma(W_{if} x + W_{hf} h + b_{hf}) \\
g = \tanh(W_{ig} x + W_{hg} h + b_{hg}) \\
o = \sigma(W_{io} x + W_{ho} h + b_{ho}) \\
c' = f * c + i * g \\
h' = o * \tanh(c')
\end{array}\]

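where x is the input, h is the output (hidden state) of the previous time step, c is the memory (cell state), σ is gate_fn, tanh is activation_fn, and * denotes element-wise multiplication.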
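To make the equations concrete, here is a minimal from-scratch sketch of one step in plain JAX. It is not the Flax implementation: the parameter dictionary, its keys such as `"W_ii"`, and the helper `lstm_step` are hypothetical names chosen to mirror the symbols above. Vectors are treated as rows (`x @ W`), and, following the equations, only the hidden-state terms carry biases.

```python
import jax
import jax.numpy as jnp


def lstm_step(params, carry, x):
    """One LSTM step following the equations above (hypothetical helper).

    params: input kernels W_i* (in_features, hidden), hidden kernels
            W_h* (hidden, hidden) and biases b_h* (hidden,).
    carry:  tuple (c, h) of memory and previous hidden state.
    x:      input of shape (..., in_features); leading dims are batch dims.
    """
    c, h = carry

    def pre(gate):
        # W_{i?} x + W_{h?} h + b_{h?}, written with row vectors (x @ W).
        return x @ params[f"W_i{gate}"] + h @ params[f"W_h{gate}"] + params[f"b_h{gate}"]

    i = jax.nn.sigmoid(pre("i"))  # input gate
    f = jax.nn.sigmoid(pre("f"))  # forget gate
    g = jnp.tanh(pre("g"))        # candidate memory
    o = jax.nn.sigmoid(pre("o"))  # output gate
    new_c = f * c + i * g         # c' = f * c + i * g
    new_h = o * jnp.tanh(new_c)   # h' = o * tanh(c')
    return (new_c, new_h), new_h


# With all weights and biases at zero: i = f = o = 0.5 and g = 0,
# so c' = 0.5 * c and h' = 0.5 * tanh(0.5 * c).
in_features, hidden = 3, 4
zero_params = {}
for gate in "ifgo":
    zero_params[f"W_i{gate}"] = jnp.zeros((in_features, hidden))
    zero_params[f"W_h{gate}"] = jnp.zeros((hidden, hidden))
    zero_params[f"b_h{gate}"] = jnp.zeros((hidden,))

carry = (jnp.ones((2, hidden)), jnp.zeros((2, hidden)))
x = jnp.ones((2, in_features))
(new_c, new_h), out = lstm_step(zero_params, carry, x)
```

The zero-weight case is a quick sanity check: every gate sits at 0.5, so the memory is halved and nothing new is written.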

Parameters
  • gate_fn (Callable[[...], Any])

  • activation_fn (Callable[[...], Any])

  • kernel_init (Callable[[Any, Tuple[int], Any], Any])

  • recurrent_kernel_init (Callable[[Any, Tuple[int], Any], Any])

  • bias_init (Callable[[Any, Tuple[int], Any], Any])

  • dtype (Optional[Any])

  • param_dtype (Any)

  • parent (Optional[Union[Type[flax.linen.module.Module], Type[flax.core.scope.Scope], Type[flax.linen.module._Sentinel]]]) – the parent Module or Scope; normally set automatically by Flax rather than passed by hand.

  • name (str) – optional name of this submodule.
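The callables listed above slot into the computation roughly as sketched below, in plain JAX rather than through Flax. The defaults mirror the documented ones (sigmoid gates, tanh activation, lecun_normal input kernels, orthogonal recurrent kernels, zero biases), taken from `jax.nn.initializers`; the helpers `init_params` and `make_step` and the parameter names are hypothetical, and no claim is made about how Flax itself names or groups its parameters.

```python
import jax
import jax.numpy as jnp
from jax.nn import initializers


def init_params(key, in_features, hidden,
                kernel_init=initializers.lecun_normal(),
                recurrent_kernel_init=initializers.orthogonal(),
                bias_init=initializers.zeros):
    """Hypothetical helper: one input kernel, recurrent kernel and bias per gate."""
    keys = jax.random.split(key, 12)
    params = {}
    for n, gate in enumerate("ifgo"):
        k_in, k_rec, k_b = keys[3 * n], keys[3 * n + 1], keys[3 * n + 2]
        # Every initializer is called as init(key, shape, dtype).
        params[f"W_i{gate}"] = kernel_init(k_in, (in_features, hidden), jnp.float32)
        params[f"W_h{gate}"] = recurrent_kernel_init(k_rec, (hidden, hidden), jnp.float32)
        params[f"b_h{gate}"] = bias_init(k_b, (hidden,), jnp.float32)
    return params


def make_step(gate_fn=jax.nn.sigmoid, activation_fn=jnp.tanh):
    """Hypothetical helper: gate_fn drives i, f, o; activation_fn drives g and h'."""
    def step(params, carry, x):
        c, h = carry
        pre = {k: x @ params[f"W_i{k}"] + h @ params[f"W_h{k}"] + params[f"b_h{k}"]
               for k in "ifgo"}
        i, f, o = gate_fn(pre["i"]), gate_fn(pre["f"]), gate_fn(pre["o"])
        g = activation_fn(pre["g"])
        new_c = f * c + i * g
        new_h = o * activation_fn(new_c)
        return (new_c, new_h), new_h
    return step


params = init_params(jax.random.PRNGKey(0), in_features=3, hidden=4)
# Example: swap in a hard sigmoid for the gates, keeping tanh elsewhere.
step = make_step(gate_fn=jax.nn.hard_sigmoid)
carry = (jnp.zeros((2, 4)), jnp.zeros((2, 4)))
(c1, h1), y = step(params, carry, jnp.ones((2, 3)))
```

Passing the activations and initializers as plain callables is what makes such swaps a one-argument change.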


gate_fn

activation function used for the gates i, f and o (default: sigmoid).

Type

Callable[[…], Any]

activation_fn

activation function used for the candidate memory g and the output h' (default: tanh).

Type

Callable[[…], Any]

kernel_init

initializer function for the kernels that transform the input (default: lecun_normal).

Type

Callable[[Any, Tuple[int], Any], Any]

recurrent_kernel_init

initializer function for the kernels that transform the hidden state (default: orthogonal).

Type

Callable[[Any, Tuple[int], Any], Any]

bias_init

initializer for the bias parameters (default: zeros).

Type

Callable[[Any, Tuple[int], Any], Any]

dtype

the dtype of the computation (default: infer from inputs and params).

Type

Optional[Any]

param_dtype

the dtype passed to parameter initializers (default: float32).

Type

Any
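The split between `dtype` and `param_dtype` can be sketched as follows: parameters are created in `param_dtype`, then cast to `dtype` just before the computation, so storage precision and compute precision are chosen separately. The helper `cast_for_compute` is hypothetical, and `jnp.result_type` is only a stand-in for the "infer from inputs and params" default; the exact promotion rules Flax applies are not reproduced here.

```python
import jax
import jax.numpy as jnp
from jax.nn import initializers

param_dtype = jnp.float32   # dtype handed to the initializer (storage precision)
dtype = jnp.bfloat16        # dtype the computation runs in

kernel = initializers.lecun_normal()(jax.random.PRNGKey(0), (3, 4), param_dtype)
x = jnp.ones((2, 3), jnp.float32)


def cast_for_compute(x, kernel, dtype):
    """Hypothetical helper: cast input and parameter to the compute dtype."""
    return x.astype(dtype), kernel.astype(dtype)


y = jnp.matmul(*cast_for_compute(x, kernel, dtype))

# Stand-in for dtype=None: promote the input and parameter dtypes.
inferred = jnp.result_type(x, kernel)
```

Keeping parameters in float32 while computing in bfloat16 is a common mixed-precision setup; the stored kernel is untouched by the cast.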

__call__(carry, inputs)

Apply one time step of the long short-term memory (LSTM) cell.

Parameters
  • carry – the hidden state of the LSTM cell, a tuple (c, h) of the memory and the previous output, initialized using LSTMCell.initialize_carry.

  • inputs – an ndarray with the input for the current time step. All dimensions except the final are considered batch dimensions.

Returns

A tuple (new_carry, output), where new_carry is (c', h') and output is the new hidden state h'.
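Because `__call__` takes a carry and returns `(new_carry, output)`, a cell step has exactly the shape `jax.lax.scan` expects, so a sequence can be unrolled by scanning over the time axis. The sketch below uses a hypothetical plain-JAX `cell` (with the four gate matrices stacked into one, a layout chosen only for brevity) in place of an actual `LSTMCell` instance; it also shows extra leading dimensions of the inputs acting as batch dimensions.

```python
import jax
import jax.numpy as jnp


def cell(params, carry, x):
    """Hypothetical stand-in for one LSTM step; returns ((c', h'), h')."""
    c, h = carry
    # Pre-activations for all four gates, split along the last axis.
    z = x @ params["W_x"] + h @ params["W_h"] + params["b"]
    i, f, g, o = jnp.split(z, 4, axis=-1)
    new_c = jax.nn.sigmoid(f) * c + jax.nn.sigmoid(i) * jnp.tanh(g)
    new_h = jax.nn.sigmoid(o) * jnp.tanh(new_c)
    return (new_c, new_h), new_h


in_features, hidden, T = 3, 4, 5
kx, kh, kin = jax.random.split(jax.random.PRNGKey(0), 3)
params = {
    "W_x": 0.1 * jax.random.normal(kx, (in_features, 4 * hidden)),
    "W_h": 0.1 * jax.random.normal(kh, (hidden, 4 * hidden)),
    "b": jnp.zeros((4 * hidden,)),
}

# Inputs laid out as (time, batch..., features); here two batch dims (2, 6).
xs = jax.random.normal(kin, (T, 2, 6, in_features))
carry0 = (jnp.zeros((2, 6, hidden)), jnp.zeros((2, 6, hidden)))


def scan_fn(carry, x):
    # scan threads the carry through time and stacks the per-step outputs.
    return cell(params, carry, x)


(final_c, final_h), ys = jax.lax.scan(scan_fn, carry0, xs)
```

The stacked outputs `ys` have shape `(T, 2, 6, hidden)`, and the final hidden state equals the last output.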

Methods

activation_fn()

Compute hyperbolic tangent element-wise.

bias_init(shape[, dtype])

An initializer that returns a constant array full of zeros.

gate_fn()

Sigmoid activation function.

initialize_carry(rng, batch_dims, size[, ...])

Initialize the RNN cell carry.

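kernel_init(shape[, dtype])

Initializer for the kernels that transform the input (default: lecun_normal).

recurrent_kernel_init(shape[, dtype])

Initializer for the kernels that transform the hidden state (default: orthogonal).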
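The method summary shows `initialize_carry` building the starting carry from `rng`, `batch_dims` and `size`. A plain-JAX sketch of such a carry follows, assuming (consistent with how `__call__` unpacks it) a tuple `(c, h)` whose arrays have shape `batch_dims + (size,)` and are zero-filled by default. The helper `zero_carry` and its `init_fn` default are hypothetical, and the further arguments elided as `[, ...]` above are not reproduced.

```python
import jax
import jax.numpy as jnp
from jax.nn import initializers


def zero_carry(rng, batch_dims, size, init_fn=initializers.zeros):
    """Hypothetical helper: return (c, h), each of shape batch_dims + (size,)."""
    key_c, key_h = jax.random.split(rng)  # separate keys in case init_fn is random
    shape = tuple(batch_dims) + (size,)
    return init_fn(key_c, shape), init_fn(key_h, shape)


c0, h0 = zero_carry(jax.random.PRNGKey(0), batch_dims=(8,), size=16)
```

Here `size` plays the role of the hidden width, so the carry for a batch of 8 with 16 hidden units is two `(8, 16)` arrays.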