flax.linen.Bidirectional

class flax.linen.Bidirectional(forward_rnn, backward_rnn, merge_fn=<function _concatenate>, time_major=False, return_carry=False, parent=<flax.linen.module._Sentinel object>, name=None)

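Processes the input sequence in both directions and merges the results: forward_rnn scans the sequence from first to last time step, backward_rnn scans it from last to first, and merge_fn combines the two output sequences (by default, concatenation along the last axis).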

Example usage:

>>> import flax.linen as nn
>>> import jax, jax.numpy as jnp

>>> layer = nn.Bidirectional(nn.RNN(nn.GRUCell(4)), nn.RNN(nn.GRUCell(4)))
>>> x = jnp.ones((2, 3))
>>> variables = layer.init(jax.random.key(0), x)
>>> out = layer.apply(variables, x)

__call__(inputs, *, initial_carry=None, init_key=None, seq_lengths=None, return_carry=None, time_major=None, reverse=None, keep_order=None)

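Runs forward_rnn and backward_rnn over inputs and merges their outputs with merge_fn. The backward RNN scans the sequence in reverse and returns its outputs in the original time order, so both directions are aligned step by step before merging. If initial_carry is given, it must be a (forward_carry, backward_carry) pair; init_key, if given, is split into one key per direction; seq_lengths is passed to both RNNs. time_major and return_carry override the corresponding module attributes when they are not None. Returns the merged outputs, or a (carry, outputs) tuple with carry = (forward_carry, backward_carry) when return_carry is true.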

Methods