flax.linen.OptimizedLSTMCell

class flax.linen.OptimizedLSTMCell(gate_fn=sigmoid, activation_fn=tanh, kernel_init=lecun_normal, recurrent_kernel_init=orthogonal, bias_init=zeros, dtype=None, param_dtype=float32, parent=<flax.linen.module._Sentinel object>, name=None)

More efficient LSTM Cell that concatenates state components before matmul.

The parameters are compatible with LSTMCell. This cell is often faster than LSTMCell when the hidden size is at most roughly 2048 units.
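The idea of fusing several per-gate matmuls into one can be illustrated in plain JAX. This is only a sketch of the general technique, not Flax's actual implementation; the kernel list and sizes below are hypothetical, chosen for illustration.

```python
import jax
import jax.numpy as jnp

hidden, batch = 8, 3
keys = jax.random.split(jax.random.PRNGKey(0), 5)

# Hypothetical per-gate recurrent kernels for the i, f, g and o gates.
kernels = [jax.random.normal(k, (hidden, hidden)) for k in keys[:4]]
h = jax.random.normal(keys[4], (batch, hidden))

# Four separate matmuls, one per gate.
separate = [h @ w for w in kernels]

# One fused matmul against the concatenated kernel, then split per gate.
fused_kernel = jnp.concatenate(kernels, axis=-1)  # shape (hidden, 4 * hidden)
fused = jnp.split(h @ fused_kernel, 4, axis=-1)

# Both routes give the same per-gate pre-activations up to float rounding.
max_diff = max(float(jnp.max(jnp.abs(a - b))) for a, b in zip(separate, fused))
```

A single larger matmul tends to use the hardware better than several small ones, which is consistent with the speedup being most visible at moderate hidden sizes.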

The mathematical definition of the cell is the same as that of LSTMCell:

\[\begin{split}\begin{array}{ll} i = \sigma(W_{ii} x + W_{hi} h + b_{hi}) \\ f = \sigma(W_{if} x + W_{hf} h + b_{hf}) \\ g = \tanh(W_{ig} x + W_{hg} h + b_{hg}) \\ o = \sigma(W_{io} x + W_{ho} h + b_{ho}) \\ c' = f * c + i * g \\ h' = o * \tanh(c') \\ \end{array}\end{split}\]

where x is the input, h is the output of the previous time step, and c is the memory.
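The equations above can be sketched as a single step function in plain JAX. This is an illustration under stated assumptions, not the library's code: the `lstm_step` function, the `params` dictionary layout, and the `(c, h)` carry ordering are all hypothetical names and conventions chosen here.

```python
import jax
import jax.numpy as jnp

def lstm_step(params, carry, x):
    """One LSTM step written directly from the equations.

    `params` is a hypothetical dict with, for each gate in "ifgo",
    an input kernel "W_i<gate>", a recurrent kernel "W_h<gate>"
    and a bias "b_h<gate>".
    """
    c, h = carry  # assumed carry layout: (memory, previous output)

    def pre(gate):
        return x @ params["W_i" + gate] + h @ params["W_h" + gate] + params["b_h" + gate]

    i = jax.nn.sigmoid(pre("i"))
    f = jax.nn.sigmoid(pre("f"))
    g = jnp.tanh(pre("g"))
    o = jax.nn.sigmoid(pre("o"))
    new_c = f * c + i * g
    new_h = o * jnp.tanh(new_c)
    return (new_c, new_h), new_h

in_features, hidden, batch = 4, 3, 2
params = {}
for gate in "ifgo":
    params["W_i" + gate] = jnp.zeros((in_features, hidden))
    params["W_h" + gate] = jnp.zeros((hidden, hidden))
    params["b_h" + gate] = jnp.zeros((hidden,))

x = jnp.ones((batch, in_features))
carry = (jnp.ones((batch, hidden)), jnp.zeros((batch, hidden)))
(new_c, new_h), out = lstm_step(params, carry, x)
# With all-zero weights every gate is 0.5 and g is 0,
# so c' = 0.5 * c and h' = 0.5 * tanh(c').
```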

Parameters
  • gate_fn (Callable[[...], Any])

  • activation_fn (Callable[[...], Any])

  • kernel_init (Callable[[Any, Tuple[int], Any], Any])

  • recurrent_kernel_init (Callable[[Any, Tuple[int], Any], Any])

  • bias_init (Callable[[Any, Tuple[int], Any], Any])

  • dtype (Optional[Any])

  • param_dtype (Any)

  • parent (Optional[Union[Type[flax.linen.module.Module], Type[flax.core.scope.Scope], Type[flax.linen.module._Sentinel]]])

  • name (str)

Return type

None

gate_fn

activation function used for gates (default: sigmoid).

Type

Callable[[...], Any]

activation_fn

activation function used for output and memory update (default: tanh).

Type

Callable[[...], Any]

kernel_init

initializer function for the kernels that transform the input (default: lecun_normal).

Type

Callable[[Any, Tuple[int], Any], Any]

recurrent_kernel_init

initializer function for the kernels that transform the hidden state (default: orthogonal).

Type

Callable[[Any, Tuple[int], Any], Any]

bias_init

initializer for the bias parameters (default: zeros).

Type

Callable[[Any, Tuple[int], Any], Any]
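The three default initializers named above (lecun_normal, orthogonal, zeros) can be tried directly through `jax.nn.initializers`. The sketch below assumes these JAX functions behave like the defaults listed here; the shapes are hypothetical and chosen only to make the properties easy to check.

```python
import jax
import jax.numpy as jnp
from jax.nn import initializers

key_in, key_rec, key_bias = jax.random.split(jax.random.PRNGKey(0), 3)
in_features, hidden = 5, 4  # hypothetical sizes

# Each initializer is called as init(key, shape[, dtype]).
input_kernel = initializers.lecun_normal()(key_in, (in_features, hidden))
recurrent_kernel = initializers.orthogonal()(key_rec, (hidden, hidden))
bias = initializers.zeros(key_bias, (hidden,))

# A square orthogonal matrix Q satisfies Q^T Q = I.
gram = recurrent_kernel.T @ recurrent_kernel
```

An orthogonal recurrent kernel preserves the norm of the hidden state under the linear map, which is a common motivation for using it on the hidden-to-hidden weights.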

dtype

the dtype of the computation (default: infer from inputs and params).

Type

Optional[Any]

param_dtype

the dtype passed to parameter initializers (default: float32).

Type

Any

__call__(carry, inputs)

An optimized long short-term memory (LSTM) cell.

Parameters
  • carry (Tuple[Any, Any]) – the hidden state of the LSTM cell, initialized using LSTMCell.initialize_carry.

  • inputs (Any) – an ndarray with the input for the current time step. All dimensions except the final are considered batch dimensions.

Returns

A tuple with the new carry and the output.

Return type

Tuple[Tuple[Any, Any], Any]
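Because the call takes a carry and an input and returns a new carry and an output, it has the shape of a step function for `jax.lax.scan`. The sketch below uses a hand-written stand-in `cell` function with hypothetical fused parameters (`W_x`, `W_h`, `b`) rather than the Flax module itself, and assumes a `(c, h)` carry layout, so it only illustrates the calling convention.

```python
import jax
import jax.numpy as jnp

def cell(params, carry, x):
    """Stand-in step function with the (carry, inputs) -> (carry, output) contract."""
    c, h = carry
    z = x @ params["W_x"] + h @ params["W_h"] + params["b"]
    i, f, g, o = jnp.split(z, 4, axis=-1)
    new_c = jax.nn.sigmoid(f) * c + jax.nn.sigmoid(i) * jnp.tanh(g)
    new_h = jax.nn.sigmoid(o) * jnp.tanh(new_c)
    return (new_c, new_h), new_h

batch, steps, in_features, hidden = 2, 5, 3, 4
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
params = {
    "W_x": 0.1 * jax.random.normal(k1, (in_features, 4 * hidden)),
    "W_h": 0.1 * jax.random.normal(k2, (hidden, 4 * hidden)),
    "b": jnp.zeros((4 * hidden,)),
}

# Time-major inputs: scan iterates over the leading axis.
xs = jax.random.normal(k3, (steps, batch, in_features))
init_carry = (jnp.zeros((batch, hidden)), jnp.zeros((batch, hidden)))

final_carry, outputs = jax.lax.scan(
    lambda carry, x: cell(params, carry, x), init_carry, xs
)
```

Here `outputs` stacks the per-step outputs along the time axis, and the second element of `final_carry` equals the last output.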

Methods

activation_fn()

Compute hyperbolic tangent element-wise.

bias_init(shape[, dtype])

An initializer that returns a constant array full of zeros.

gate_fn()

Sigmoid activation function.

initialize_carry(rng, batch_dims, size[, ...])

Initialize the RNN cell carry.

kernel_init(shape[, dtype])

recurrent_kernel_init(shape[, dtype])