flax.linen.OptimizedLSTMCell
- class flax.linen.OptimizedLSTMCell(*args, **kwds)
More efficient LSTM cell that concatenates state components before the matrix multiplication.
The parameters are compatible with LSTMCell. This cell is often faster than LSTMCell as long as the hidden size is roughly 2048 units or fewer.
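The speedup comes from replacing several small matrix multiplications with one larger one. The sketch below is written in plain `jax.numpy` and is not taken from the Flax source. It assumes that the four per-gate kernels (for i, f, g, o) are concatenated along the output axis, and it checks that one matmul followed by a split gives the same pre-activations as four separate matmuls. The helper names `separate_matmuls` and `fused_matmul` are hypothetical.

```python
import jax
import jax.numpy as jnp

# Hypothetical helpers for illustration; not Flax's implementation.


def separate_matmuls(x, kernels):
    """Four independent matmuls, one per gate (i, f, g, o)."""
    return [x @ k for k in kernels]


def fused_matmul(x, kernels):
    """Concatenate the per-gate kernels, do a single matmul, then split.

    Assumption for illustration: every kernel has shape (in_features, hidden).
    """
    fused = jnp.concatenate(kernels, axis=-1)   # (in_features, 4 * hidden)
    y = x @ fused                               # one matmul instead of four
    return jnp.split(y, len(kernels), axis=-1)  # four (batch, hidden) blocks


key = jax.random.PRNGKey(0)
k_x, *k_gates = jax.random.split(key, 5)
x = jax.random.normal(k_x, (8, 16))                          # batch 8, 16 features
kernels = [jax.random.normal(k, (16, 32)) for k in k_gates]  # hidden size 32

separate = separate_matmuls(x, kernels)
fused = fused_matmul(x, kernels)
```

The fused form issues one matrix multiplication per operand instead of four; this sketch only checks that the results agree, not that the fused form is faster.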
The mathematical definition of the cell is the same as that of LSTMCell:
\[\begin{array}{ll}
i = \sigma(W_{ii} x + W_{hi} h + b_{hi}) \\
f = \sigma(W_{if} x + W_{hf} h + b_{hf}) \\
g = \tanh(W_{ig} x + W_{hg} h + b_{hg}) \\
o = \sigma(W_{io} x + W_{ho} h + b_{ho}) \\
c' = f * c + i * g \\
h' = o * \tanh(c')
\end{array}\]
where x is the input, h is the output of the previous time step, c is the memory (cell state), and * denotes elementwise multiplication.
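As a reading aid, the equations can be transcribed line by line into `jax.numpy`. This is a minimal sketch, not Flax's implementation: the name `lstm_step` and the parameter dictionary layout are hypothetical. The equations use column vectors (W x); the sketch uses row-vector inputs, so each term is written `x @ W` with kernels of shape (in, hidden). Matching the equations, only the hidden-state terms carry a bias.

```python
import jax
import jax.numpy as jnp


def lstm_step(params, carry, x, gate_fn=jax.nn.sigmoid, activation_fn=jnp.tanh):
    """One LSTM time step, following the equations term by term.

    params: hypothetical dict with input kernels W_i* of shape (in, hidden),
            hidden kernels W_h* of shape (hidden, hidden), and biases b_h* of
            shape (hidden,), for each gate * in {i, f, g, o}.
    carry:  tuple (c, h) of arrays with shape (..., hidden).
    """
    c, h = carry
    i = gate_fn(x @ params["W_ii"] + h @ params["W_hi"] + params["b_hi"])
    f = gate_fn(x @ params["W_if"] + h @ params["W_hf"] + params["b_hf"])
    g = activation_fn(x @ params["W_ig"] + h @ params["W_hg"] + params["b_hg"])
    o = gate_fn(x @ params["W_io"] + h @ params["W_ho"] + params["b_ho"])
    new_c = f * c + i * g             # c' = f * c + i * g (elementwise)
    new_h = o * activation_fn(new_c)  # h' = o * tanh(c')
    return (new_c, new_h), new_h


# With all weights and biases zero, every gate pre-activation is 0, so
# i = f = o = sigmoid(0) = 0.5 and g = tanh(0) = 0:
# c' = 0.5 * c and h' = 0.5 * tanh(c').
in_features, hidden = 3, 4
params = {}
for gate in "ifgo":
    params[f"W_i{gate}"] = jnp.zeros((in_features, hidden))
    params[f"W_h{gate}"] = jnp.zeros((hidden, hidden))
    params[f"b_h{gate}"] = jnp.zeros((hidden,))

c0 = jnp.ones((2, hidden))
h0 = jnp.zeros((2, hidden))
x = jnp.ones((2, in_features))
(c1, h1), y = lstm_step(params, (c0, h0), x)
```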
- gate_fn
activation function used for gates (default: sigmoid).
- Type
Callable[[…], Any]
- activation_fn
activation function used for output and memory update (default: tanh).
- Type
Callable[[…], Any]
- kernel_init
initializer function for the kernels that transform the input (default: lecun_normal).
- Type
jax.nn.initializers.Initializer
- recurrent_kernel_init
initializer function for the kernels that transform the hidden state (default: initializers.orthogonal()).
- Type
jax.nn.initializers.Initializer
- bias_init
initializer for the bias parameters (default: initializers.zeros_init()).
- Type
jax.nn.initializers.Initializer
- dtype
the dtype of the computation (default: infer from inputs and params).
- Type
Optional[Any]
- param_dtype
the dtype passed to parameter initializers (default: float32).
- Type
Any
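The initializer and dtype attributes can be illustrated with the functions in `jax.nn.initializers`. The sketch below is hypothetical: the helper `init_gate_params` and the parameter names are not Flax's. It draws the input kernel with `lecun_normal`, the recurrent kernel with `orthogonal`, and the bias with `zeros`, all in a `param_dtype` of float32, then casts the parameters to an explicitly chosen computation dtype (bfloat16 here) before use.

```python
import jax
import jax.numpy as jnp
from jax.nn import initializers


def init_gate_params(key, in_features, hidden, param_dtype=jnp.float32,
                     kernel_init=initializers.lecun_normal(),
                     recurrent_kernel_init=initializers.orthogonal(),
                     bias_init=initializers.zeros):
    """Initialize the parameters of a single gate (hypothetical layout)."""
    k_in, k_rec, k_bias = jax.random.split(key, 3)
    return {
        # Transforms the input x.
        "W_x": kernel_init(k_in, (in_features, hidden), param_dtype),
        # Transforms the previous hidden state h.
        "W_h": recurrent_kernel_init(k_rec, (hidden, hidden), param_dtype),
        "b_h": bias_init(k_bias, (hidden,), param_dtype),
    }


params = init_gate_params(jax.random.PRNGKey(0), in_features=5, hidden=6)

# The computation dtype is chosen explicitly here: the float32 parameters are
# cast to bfloat16 for the matmuls, while the stored parameters stay float32.
compute_dtype = jnp.bfloat16
x = jnp.ones((2, 5), compute_dtype)
h = jnp.zeros((2, 6), compute_dtype)
pre_activation = (x @ params["W_x"].astype(compute_dtype)
                  + h @ params["W_h"].astype(compute_dtype)
                  + params["b_h"].astype(compute_dtype))
gate = jax.nn.sigmoid(pre_activation)  # sigmoid, the default gate_fn
```

An orthogonal recurrent kernel satisfies W^T W = I, so repeated multiplication by it neither grows nor shrinks the hidden state's norm.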
- __call__(carry, inputs)
An optimized long short-term memory (LSTM) cell.
- Parameters
carry – the hidden state of the LSTM cell, initialized using OptimizedLSTMCell.initialize_carry.
inputs – an ndarray with the input for the current time step. All dimensions except the final are considered batch dimensions.
- Returns
A tuple with the new carry and the output.
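Because a step maps `(carry, inputs)` to `(new_carry, output)`, a sequence can be processed by threading the carry through time, for example with `jax.lax.scan`. The sketch below uses a hypothetical fused step `fused_lstm_step` written in plain `jax.numpy` (not Flax's code) and assumes a zero-initialized carry `(c, h)`. As the parameter description states, every input dimension except the last is treated as a batch dimension.

```python
import jax
import jax.numpy as jnp


def fused_lstm_step(params, carry, x):
    """One step with fused kernels: one matmul for x and one for h.

    Hypothetical layout: W_x is (in, 4 * hidden), W_h is (hidden, 4 * hidden),
    b_h is (4 * hidden,), with the gate blocks ordered i, f, g, o.
    """
    c, h = carry
    z = x @ params["W_x"] + h @ params["W_h"] + params["b_h"]
    zi, zf, zg, zo = jnp.split(z, 4, axis=-1)
    i, f, o = jax.nn.sigmoid(zi), jax.nn.sigmoid(zf), jax.nn.sigmoid(zo)
    g = jnp.tanh(zg)
    new_c = f * c + i * g
    new_h = o * jnp.tanh(new_c)
    return (new_c, new_h), new_h


in_features, hidden, time_steps = 3, 4, 5
batch_dims = (2, 7)  # every leading dimension is a batch dimension
keys = jax.random.split(jax.random.PRNGKey(0), 3)
params = {
    "W_x": 0.1 * jax.random.normal(keys[0], (in_features, 4 * hidden)),
    "W_h": 0.1 * jax.random.normal(keys[1], (hidden, 4 * hidden)),
    "b_h": jnp.zeros((4 * hidden,)),
}
# Time-major sequence: scan iterates over the leading (time) axis.
xs = jax.random.normal(keys[2], (time_steps, *batch_dims, in_features))
# Assumed zero-initialized carry (c, h).
carry0 = (jnp.zeros((*batch_dims, hidden)), jnp.zeros((*batch_dims, hidden)))

(c_T, h_T), ys = jax.lax.scan(
    lambda carry, x: fused_lstm_step(params, carry, x), carry0, xs)
```

`jax.lax.scan` traces the step function once and reuses it for every time step, rather than unrolling a Python loop into one large computation.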
Methods
initialize_carry(**kwargs) – Initialize the RNN cell carry.