flax.linen.GRUCell
- class flax.linen.GRUCell(features, gate_fn=sigmoid, activation_fn=tanh, kernel_init=lecun_normal(), recurrent_kernel_init=orthogonal(), bias_init=zeros_init(), dtype=None, param_dtype=jnp.float32, carry_init=zeros_init(), parent=<flax.linen.module._Sentinel object>, name=None)
GRU cell.
The mathematical definition of the cell is as follows:
\[
\begin{array}{ll}
r = \sigma(W_{ir} x + b_{ir} + W_{hr} h) \\
z = \sigma(W_{iz} x + b_{iz} + W_{hz} h) \\
n = \tanh(W_{in} x + b_{in} + r * (W_{hn} h + b_{hn})) \\
h' = (1 - z) * n + z * h
\end{array}
\]
where x is the input, h is the output of the previous time step (the carry), and * denotes elementwise multiplication. With the default arguments, σ is gate_fn (sigmoid) and tanh is activation_fn.
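To make the update rule concrete, here is a direct transcription of the four equations into plain JAX, independent of Flax. It is a sketch under stated assumptions: the parameter dictionary, its key names (W_ir, b_ir and so on) and the helpers gru_step and init_params are hypothetical, vectors are treated as rows so W x is written x @ W, and σ and tanh are hard-coded to their defaults. It does not reflect how GRUCell itself stores its parameters.

```python
import jax
import jax.numpy as jnp


def gru_step(params, h, x):
    """One GRU update transcribed from the equations (hypothetical layout).

    Assumed shapes: W_i* is (in_features, features), W_h* is
    (features, features), biases are (features,). Rows are vectors,
    so the matrix-vector product W x is written x @ W.
    """
    r = jax.nn.sigmoid(x @ params["W_ir"] + params["b_ir"] + h @ params["W_hr"])
    z = jax.nn.sigmoid(x @ params["W_iz"] + params["b_iz"] + h @ params["W_hz"])
    # The reset gate r scales the recurrent term, including its bias b_hn.
    n = jnp.tanh(
        x @ params["W_in"] + params["b_in"] + r * (h @ params["W_hn"] + params["b_hn"])
    )
    # Interpolate between the candidate n and the previous state h.
    return (1.0 - z) * n + z * h


def init_params(key, in_features, features):
    """Hypothetical initializer: small normal kernels, zero biases."""
    keys = jax.random.split(key, 6)
    params = {}
    for k, gate in zip(keys[:3], ("r", "z", "n")):
        params[f"W_i{gate}"] = 0.1 * jax.random.normal(k, (in_features, features))
        params[f"b_i{gate}"] = jnp.zeros((features,))
    for k, gate in zip(keys[3:], ("r", "z", "n")):
        params[f"W_h{gate}"] = 0.1 * jax.random.normal(k, (features, features))
    params["b_hn"] = jnp.zeros((features,))
    return params


x = jax.random.normal(jax.random.key(0), (2, 3))  # batch of 2, 3 input features
h = jnp.zeros((2, 4))                              # previous state, 4 features
params = init_params(jax.random.key(1), 3, 4)
h_new = gru_step(params, h, x)
print(h_new.shape)  # (2, 4)
```

As a sanity check on the transcription: with every weight and bias set to zero, r = z = 0.5 and n = tanh(0) = 0, so the update reduces to h' = 0.5 h.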
Example usage:
>>> import flax.linen as nn
>>> import jax, jax.numpy as jnp
>>> x = jax.random.normal(jax.random.key(0), (2, 3))
>>> layer = nn.GRUCell(features=4)
>>> carry = layer.initialize_carry(jax.random.key(1), x.shape)
>>> variables = layer.init(jax.random.key(2), carry, x)
>>> new_carry, out = layer.apply(variables, carry, x)
- features
number of output features, which is also the size of the hidden state h.
- Type
int
- gate_fn
activation function used for the reset and update gates r and z (default: sigmoid).
- Type
Callable[[…], Any]
- activation_fn
activation function applied to the candidate state n, which drives the output and memory update (default: tanh).
- Type
Callable[[…], Any]
- kernel_init
initializer function for the kernels that transform the input, W_ir, W_iz and W_in (default: lecun_normal).
- Type
Union[jax.nn.initializers.Initializer, Callable[[…], Any]]
- recurrent_kernel_init
initializer function for the kernels that transform the hidden state, W_hr, W_hz and W_hn (default: initializers.orthogonal()).
- Type
Union[jax.nn.initializers.Initializer, Callable[[…], Any]]
- bias_init
initializer for the bias parameters b_ir, b_iz, b_in and b_hn (default: initializers.zeros_init()).
- Type
Union[jax.nn.initializers.Initializer, Callable[[…], Any]]
- dtype
the dtype of the computation (default: None, in which case it is inferred from the inputs and parameters).
- Type
Optional[Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, Any]]
- param_dtype
the dtype passed to parameter initializers (default: float32).
- Type
Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, Any]
- __call__(carry, inputs)
Gated recurrent unit (GRU) cell.
- Parameters
carry – the hidden state of the GRU cell, initialized using GRUCell.initialize_carry.
inputs – an ndarray with the input for the current time step. All dimensions except the last one are treated as batch dimensions.
- Returns
A tuple (new_carry, output). For GRUCell the output is the new hidden state h', so both elements hold the same values.
- initialize_carry(rng, input_shape)
Initialize the RNN cell carry.
- Parameters
rng – a JAX PRNG key passed to the carry initializer (carry_init).
input_shape – a tuple providing the shape of the input to the cell.
- Returns
An initialized carry for the given RNN cell.
Methods
initialize_carry(rng, input_shape) – Initialize the RNN cell carry.