flax.linen.Bidirectional
- class flax.linen.Bidirectional(forward_rnn, backward_rnn, merge_fn=<function _concatenate>, time_major=False, return_carry=False, parent=<flax.linen.module._Sentinel object>, name=None)
Processes the input in both directions and merges the results.
Example usage:
>>> import flax.linen as nn
>>> import jax, jax.numpy as jnp
>>> layer = nn.Bidirectional(nn.RNN(nn.GRUCell(4)), nn.RNN(nn.GRUCell(4)))
>>> x = jnp.ones((2, 3))
>>> variables = layer.init(jax.random.key(0), x)
>>> out = layer.apply(variables, x)
- __call__(inputs, *, initial_carry=None, init_key=None, seq_lengths=None, return_carry=None, time_major=None, reverse=None, keep_order=None)
Call self as a function: runs forward_rnn and backward_rnn over inputs and merges their outputs with merge_fn.
Methods