flax.linen.GRUCell
- class flax.linen.GRUCell(gate_fn=<CompiledFunction of <function sigmoid>>, activation_fn=<CompiledFunction of <function _one_to_one_unop.<locals>.<lambda>>>, kernel_init=<function variance_scaling.<locals>.init>, recurrent_kernel_init=<function orthogonal.<locals>.init>, bias_init=<function zeros>, dtype=None, param_dtype=<class 'jax.numpy.float32'>, parent=<flax.linen.module._Sentinel object>, name=None)
GRU cell.
The mathematical definition of the cell is as follows:
\[\begin{split}\begin{array}{ll} r = \sigma(W_{ir} x + W_{hr} h + b_{hr}) \\ z = \sigma(W_{iz} x + W_{hz} h + b_{hz}) \\ n = \tanh(W_{in} x + b_{in} + r * (W_{hn} h + b_{hn})) \\ h' = (1 - z) * n + z * h \\ \end{array}\end{split}\]
where x is the input and h is the output of the previous time step.
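The update equations above can be sketched directly in JAX. This is an illustrative re-implementation, not Flax's own code: the names `gru_step` and `init_params`, the dictionary keys, and the small random-normal kernels are assumptions made for the example. Kernels are stored as `(in, out)` matrices and applied as `x @ W`, which is the row-vector form of the `W x` products in the definition.

```python
import jax
import jax.numpy as jnp


def gru_step(params, h, x):
    """One GRU update following the definition above (illustrative sketch)."""
    # Reset gate: r = sigmoid(W_ir x + W_hr h + b_hr)
    r = jax.nn.sigmoid(x @ params["W_ir"] + h @ params["W_hr"] + params["b_hr"])
    # Update gate: z = sigmoid(W_iz x + W_hz h + b_hz)
    z = jax.nn.sigmoid(x @ params["W_iz"] + h @ params["W_hz"] + params["b_hz"])
    # Candidate state: n = tanh(W_in x + b_in + r * (W_hn h + b_hn))
    n = jnp.tanh(
        x @ params["W_in"] + params["b_in"]
        + r * (h @ params["W_hn"] + params["b_hn"])
    )
    # New hidden state: h' = (1 - z) * n + z * h
    return (1.0 - z) * n + z * h


def init_params(key, in_features, hidden):
    """Hypothetical parameter layout: small random kernels, zero biases."""
    keys = jax.random.split(key, 6)
    params = {}
    for k, name in zip(keys[:3], ["W_ir", "W_iz", "W_in"]):
        params[name] = 0.1 * jax.random.normal(k, (in_features, hidden))
    for k, name in zip(keys[3:], ["W_hr", "W_hz", "W_hn"]):
        params[name] = 0.1 * jax.random.normal(k, (hidden, hidden))
    for name in ["b_hr", "b_hz", "b_in", "b_hn"]:
        params[name] = jnp.zeros((hidden,))
    return params


params = init_params(jax.random.PRNGKey(0), in_features=3, hidden=4)
x = jnp.ones((2, 3))   # batch of 2 inputs with 3 features each
h = jnp.zeros((2, 4))  # previous hidden state, 4 units
h_new = gru_step(params, h, x)
```

With a zero previous state, the result reduces to `(1 - z) * n`, so every entry of `h_new` lies strictly between -1 and 1.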
- Parameters
gate_fn (Callable[[...], Any]) –
activation_fn (Callable[[...], Any]) –
kernel_init (Callable[[Any, Tuple[int], Any], Any]) –
recurrent_kernel_init (Callable[[Any, Tuple[int], Any], Any]) –
bias_init (Callable[[Any, Tuple[int], Any], Any]) –
dtype (Optional[Any]) –
param_dtype (Any) –
parent (Optional[Union[Type[flax.linen.module.Module], Type[flax.core.scope.Scope], Type[flax.linen.module._Sentinel]]]) –
name (str) –
- Return type
None
- gate_fn
activation function used for gates (default: sigmoid).
- Type
Callable[[…], Any]
- activation_fn
activation function used for output and memory update (default: tanh).
- Type
Callable[[…], Any]
- kernel_init
initializer function for the kernels that transform the input (default: lecun_normal).
- Type
Callable[[Any, Tuple[int], Any], Any]
- recurrent_kernel_init
initializer function for the kernels that transform the hidden state (default: orthogonal).
- Type
Callable[[Any, Tuple[int], Any], Any]
- bias_init
initializer for the bias parameters (default: zeros).
- Type
Callable[[Any, Tuple[int], Any], Any]
- dtype
the dtype of the computation (default: None).
- Type
Optional[Any]
- param_dtype
the dtype passed to parameter initializers (default: float32).
- Type
Any
- __call__(carry, inputs)
Gated recurrent unit (GRU) cell.
- Parameters
carry – the hidden state of the GRU cell, initialized using GRUCell.initialize_carry.
inputs – an ndarray with the input for the current time step. All dimensions except the final are considered batch dimensions.
- Returns
A tuple with the new carry and the output.
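The `(carry, inputs) -> (new carry, output)` calling convention described above matches the step-function shape expected by `jax.lax.scan`, so a cell of this form can be unrolled over a sequence. The sketch below uses a hand-written, bias-free GRU step rather than the Flax module itself; `make_cell` and its stacked kernel layout are assumptions chosen for brevity and do not reflect how Flax stores its parameters.

```python
import jax
import jax.numpy as jnp


def make_cell(w_i, w_h):
    """Build a step function with the (carry, inputs) -> (new_carry, output) shape.

    w_i: (in_features, 3 * hidden) and w_h: (hidden, 3 * hidden) hold the
    r, z and n kernels side by side (hypothetical layout); biases are zero.
    """
    def cell(carry, inputs):
        xr, xz, xn = jnp.split(inputs @ w_i, 3, axis=-1)
        hr, hz, hn = jnp.split(carry @ w_h, 3, axis=-1)
        r = jax.nn.sigmoid(xr + hr)   # reset gate
        z = jax.nn.sigmoid(xz + hz)   # update gate
        n = jnp.tanh(xn + r * hn)     # candidate state
        new_carry = (1.0 - z) * n + z * carry
        # In this sketch the per-step output is the new hidden state.
        return new_carry, new_carry
    return cell


batch, steps, in_features, hidden = 2, 5, 3, 4
key_i, key_h, key_x = jax.random.split(jax.random.PRNGKey(0), 3)
cell = make_cell(
    0.1 * jax.random.normal(key_i, (in_features, 3 * hidden)),
    0.1 * jax.random.normal(key_h, (hidden, 3 * hidden)),
)

xs = jax.random.normal(key_x, (steps, batch, in_features))  # time-major sequence
carry0 = jnp.zeros((batch, hidden))                         # initial hidden state

# scan threads the carry through the steps and stacks the per-step outputs.
final_carry, outputs = jax.lax.scan(cell, carry0, xs)
```

Here `outputs` has shape `(steps, batch, hidden)`, and its last time step equals `final_carry`.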
Methods
activation_fn – Compute hyperbolic tangent element-wise.
bias_init(shape[, dtype]) – An initializer that returns a constant array full of zeros.
gate_fn() – Sigmoid activation function.
initialize_carry(rng, batch_dims, size[, ...]) – Initialize the RNN cell carry.
kernel_init(shape[, dtype])
recurrent_kernel_init(shape[, dtype])