flax.linen.LSTMCell
- class flax.linen.LSTMCell(*args, **kwds)
LSTM cell.
The mathematical definition of the cell is as follows:
\[
\begin{array}{ll}
i = \sigma(W_{ii} x + W_{hi} h + b_{hi}) \\
f = \sigma(W_{if} x + W_{hf} h + b_{hf}) \\
g = \tanh(W_{ig} x + W_{hg} h + b_{hg}) \\
o = \sigma(W_{io} x + W_{ho} h + b_{ho}) \\
c' = f * c + i * g \\
h' = o * \tanh(c') \\
\end{array}
\]
where x is the input, h is the output of the previous time step, c is the memory (cell state), and * denotes element-wise multiplication. σ and tanh are the default gate_fn and activation_fn, respectively.
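The equations translate almost line by line into array code. The sketch below is a NumPy transcription of one step, not Flax's implementation. It assumes a row-vector convention (W x is written `x @ W`, the same layout as a Flax `Dense` kernel), and the parameter dict with keys such as `"W_ii"` and `"b_hi"` is a hypothetical layout chosen to mirror the symbols, not Flax's parameter tree. The `gate_fn` and `activation_fn` arguments play the role of the attributes of the same name.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_step(params, c, h, x, gate_fn=sigmoid, activation_fn=np.tanh):
    """One LSTM step transcribed from the equations (illustrative sketch).

    Row-vector convention: W x is written x @ W, so input kernels W_i* have
    shape (in_features, features) and recurrent kernels W_h* have shape
    (features, features). Leading dimensions of x, c and h act as batch
    dimensions. Only the hidden-state terms carry a bias b_h*.
    """
    def pre(gate):
        # W_{i,gate} x + W_{h,gate} h + b_{h,gate}
        return x @ params["W_i" + gate] + h @ params["W_h" + gate] + params["b_h" + gate]

    i = gate_fn(pre("i"))        # input gate
    f = gate_fn(pre("f"))        # forget gate
    g = activation_fn(pre("g"))  # candidate memory
    o = gate_fn(pre("o"))        # output gate
    new_c = f * c + i * g              # c' = f * c + i * g (element-wise)
    new_h = o * activation_fn(new_c)   # h' = o * tanh(c')
    return new_c, new_h

# Usage: random weights, zero initial state, one input vector.
rng = np.random.default_rng(0)
in_features, features = 3, 4
params = {}
for gate in "ifgo":
    params["W_i" + gate] = rng.normal(size=(in_features, features))
    params["W_h" + gate] = rng.normal(size=(features, features))
    params["b_h" + gate] = np.zeros(features)

c, h = lstm_step(params, np.zeros(features), np.zeros(features), rng.normal(size=in_features))
print(c.shape, h.shape)  # (4,) (4,)
```

With all weights at zero every gate evaluates to σ(0) = 0.5 and g = tanh(0) = 0, so c' = 0.5 c. This is a quick sanity check of the transcription.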
- features
number of output features.
- Type
int
- gate_fn
activation function used for gates (default: sigmoid).
- Type
Callable[..., Any]
- activation_fn
activation function used for output and memory update (default: tanh).
- Type
Callable[..., Any]
- kernel_init
initializer function for the kernels that transform the input (default: lecun_normal).
- Type
jax.nn.initializers.Initializer
- recurrent_kernel_init
initializer function for the kernels that transform the hidden state (default: initializers.orthogonal()).
- Type
jax.nn.initializers.Initializer
- bias_init
initializer for the bias parameters b_{hi}, b_{hf}, b_{hg}, b_{ho}, which appear only in the hidden-state terms (default: initializers.zeros_init()).
- Type
jax.nn.initializers.Initializer
- dtype
the dtype of the computation (default: infer from inputs and params).
- Type
Optional[Any]
- param_dtype
the dtype passed to parameter initializers (default: float32).
- Type
Any
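The three initializer attributes act on different parameter groups: kernel_init on the input kernels W_{i*}, recurrent_kernel_init on the hidden-state kernels W_{h*}, and bias_init on the biases b_{h*}. The NumPy sketch below illustrates the defaults under stated simplifications. `lecun_normal_approx` uses an untruncated normal with variance 1/fan_in, whereas JAX's `lecun_normal` samples from a truncated normal. `orthogonal` is a textbook QR-based sampler, not JAX's code. The helper names and the per-gate dict layout are hypothetical.

```python
import numpy as np

def lecun_normal_approx(rng, fan_in, fan_out):
    # Variance 1 / fan_in. JAX's lecun_normal draws from a *truncated*
    # normal instead; this untruncated version is a simplification.
    return rng.normal(scale=np.sqrt(1.0 / fan_in), size=(fan_in, fan_out))

def orthogonal(rng, n):
    # QR of a Gaussian matrix gives an orthogonal Q; flipping column signs by
    # sign(diag(R)) makes the sample uniformly distributed over orthogonal matrices.
    q, r = np.linalg.qr(rng.normal(size=(n, n)))
    return q * np.sign(np.diag(r))

def init_lstm_params(rng, in_features, features):
    """Hypothetical per-gate parameter dict mirroring the three init attributes."""
    params = {}
    for gate in "ifgo":
        params["W_i" + gate] = lecun_normal_approx(rng, in_features, features)  # kernel_init
        params["W_h" + gate] = orthogonal(rng, features)                        # recurrent_kernel_init
        params["b_h" + gate] = np.zeros(features)                               # bias_init
    return params

params = init_lstm_params(np.random.default_rng(0), in_features=3, features=4)
W = params["W_hf"]
print(np.allclose(W.T @ W, np.eye(4)))  # True
```

An orthogonal recurrent kernel preserves the norm of h under the linear part of the update. This is a common motivation for using a different initializer for the recurrent kernels than for the input kernels.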
- __call__(carry, inputs)
A long short-term memory (LSTM) cell.
- Parameters
carry – the LSTM state, a tuple (c, h) of memory and hidden state, typically created with LSTMCell.initialize_carry.
inputs – an ndarray with the input for the current time step. All dimensions except the last are treated as batch dimensions.
- Returns
A tuple (new_carry, output), where new_carry is the updated (c', h') pair and output is the new hidden state h'.
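The (carry, inputs) → (new_carry, output) contract is what allows a cell to be unrolled over a sequence: the carry is threaded from step to step, and the per-step outputs are collected. The NumPy sketch below shows this with a plain Python loop standing in for a scan. It is not Flax code. The fused weight layout (`W_x`, `W_h`, `b` holding all four gates side by side) and the helper names `lstm_cell` and `unroll` are hypothetical choices made for brevity.

```python
import numpy as np

def _sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

def lstm_cell(params, carry, x):
    """Mimics the (carry, inputs) -> (new_carry, output) contract.

    The four gate pre-activations come from one fused matmul: W_x is
    (in_features, 4 * features) and W_h is (features, 4 * features).
    This fused layout is an illustrative choice, not Flax's parameter layout.
    """
    c, h = carry
    z = x @ params["W_x"] + h @ params["W_h"] + params["b"]
    zi, zf, zg, zo = np.split(z, 4, axis=-1)
    new_c = _sigmoid(zf) * c + _sigmoid(zi) * np.tanh(zg)
    new_h = _sigmoid(zo) * np.tanh(new_c)
    return (new_c, new_h), new_h  # the output is the new hidden state

def unroll(params, carry, xs):
    """Apply the cell along axis 1 of xs, shaped (batch, time, in_features)."""
    outputs = []
    for t in range(xs.shape[1]):
        carry, y = lstm_cell(params, carry, xs[:, t])
        outputs.append(y)
    return carry, np.stack(outputs, axis=1)

rng = np.random.default_rng(0)
batch, time, in_features, features = 2, 5, 3, 4
params = {
    "W_x": rng.normal(size=(in_features, 4 * features)),
    "W_h": rng.normal(size=(features, 4 * features)),
    "b": np.zeros(4 * features),
}
carry = (np.zeros((batch, features)), np.zeros((batch, features)))
xs = rng.normal(size=(batch, time, in_features))
(c, h), ys = unroll(params, carry, xs)
print(ys.shape)  # (2, 5, 4)
```

Because the output equals the new hidden state, the last entry along the time axis of `ys` is the final h in the returned carry.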
Methods
- initialize_carry(**kwargs)
Initialize the RNN cell carry.
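As far as we know, in recent Flax releases initialize_carry takes an RNG key and the shape of one time step's input. By default, it returns a (c, h) pair of zero arrays whose shape is the input's batch dimensions followed by features, roughly `nn.LSTMCell(features=32).initialize_carry(jax.random.PRNGKey(0), (4, 3))`. The signature has changed across Flax versions, so please check it against your installed release. The NumPy helper below, `zero_lstm_carry`, is a hypothetical stand-in that reproduces only the shape logic.

```python
import numpy as np

def zero_lstm_carry(input_shape, features, dtype=np.float32):
    """Hypothetical stand-in for the default LSTM carry: a (c, h) pair of zeros.

    Every dimension of input_shape except the last is a batch dimension,
    so the memory shape is batch_dims + (features,).
    """
    mem_shape = tuple(input_shape[:-1]) + (features,)
    c = np.zeros(mem_shape, dtype=dtype)
    h = np.zeros(mem_shape, dtype=dtype)
    return c, h

c, h = zero_lstm_carry(input_shape=(4, 3), features=32)
print(c.shape, h.shape)  # (4, 32) (4, 32)
```

The feature size of the input (3 here) does not affect the carry shape; only the batch dimensions and features do.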