flax.linen.GRUCell

class flax.linen.GRUCell(*args, **kwds)

GRU cell.

The mathematical definition of the cell is as follows:

\[
\begin{array}{l}
r = \sigma(W_{ir} x + W_{hr} h + b_{hr}) \\
z = \sigma(W_{iz} x + W_{hz} h + b_{hz}) \\
n = \tanh(W_{in} x + b_{in} + r * (W_{hn} h + b_{hn})) \\
h' = (1 - z) * n + z * h
\end{array}
\]

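where x is the input, h is the output of the previous time step, h' is the new hidden state, σ and tanh stand for the default gate_fn and activation_fn, and * denotes element-wise multiplication.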
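The equations can be checked with a small sketch that uses plain JAX arrays rather than Flax modules. The parameter dictionary layout and the helper names `gru_step` and `zero_params` are hypothetical, chosen only for this illustration. The code uses the row-vector convention `x @ W`, which is the transpose of the `W x` notation above. With every weight and bias set to zero, r = z = 0.5 and n = tanh(0) = 0, so the update reduces to h' = 0.5 h.

```python
import jax
import jax.numpy as jnp


def gru_step(params, h, x, gate_fn=jax.nn.sigmoid, activation_fn=jnp.tanh):
    """One GRU step following the equations above.

    `params` is a hypothetical plain dict: input kernels W_i* have shape
    (in_features, hidden), recurrent kernels W_h* have shape
    (hidden, hidden), and biases have shape (hidden,).
    """
    # Reset and update gates: the input projections carry no bias, only b_h*.
    r = gate_fn(x @ params["W_ir"] + h @ params["W_hr"] + params["b_hr"])
    z = gate_fn(x @ params["W_iz"] + h @ params["W_hz"] + params["b_hz"])
    # Candidate state: the reset gate scales only the recurrent term.
    n = activation_fn(
        x @ params["W_in"] + params["b_in"]
        + r * (h @ params["W_hn"] + params["b_hn"])
    )
    # Interpolate between the candidate and the previous hidden state.
    return (1.0 - z) * n + z * h


def zero_params(in_features, hidden):
    """All-zero parameters (hypothetical helper) for checking the rule by hand."""
    p = {}
    for g in ("r", "z", "n"):
        p[f"W_i{g}"] = jnp.zeros((in_features, hidden))
        p[f"W_h{g}"] = jnp.zeros((hidden, hidden))
        p[f"b_h{g}"] = jnp.zeros((hidden,))
    p["b_in"] = jnp.zeros((hidden,))
    return p


h = jnp.ones((2, 3))  # batch of 2, hidden size 3
x = jnp.ones((2, 5))  # batch of 2, input size 5
h_new = gru_step(zero_params(5, 3), h, x)
# All-zero parameters: r = z = 0.5 and n = 0, so h' = 0.5 * h.
print(h_new)
```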

gate_fn

activation function used for gates (default: sigmoid).

Type

Callable[..., Any]

activation_fn

activation function used for output and memory update (default: tanh).

Type

Callable[..., Any]

kernel_init

initializer function for the kernels that transform the input (default: lecun_normal).

Type

jax.nn.initializers.Initializer

recurrent_kernel_init

initializer function for the kernels that transform the hidden state (default: initializers.orthogonal()).

Type

jax.nn.initializers.Initializer

bias_init

initializer for the bias parameters (default: initializers.zeros_init()).

Type

jax.nn.initializers.Initializer

dtype

the dtype of the computation (default: None, meaning the dtype is inferred from the inputs and parameters).

Type

Optional[Any]

param_dtype

the dtype passed to parameter initializers (default: float32).

Type

Any

__call__(carry, inputs)

Gated recurrent unit (GRU) cell.

Parameters
  • carry – the hidden state of the GRU cell, initialized using GRUCell.initialize_carry.

  • inputs – an ndarray with the input for the current time step. All dimensions except the final are considered batch dimensions.

Returns

A tuple with the new carry and the output. For the GRU cell, both are the new hidden state h'.

Methods

initialize_carry(**kwargs)

Initialize the RNN cell carry.