flax.linen.RNNCellBase

class flax.linen.RNNCellBase(parent=<flax.linen.module._Sentinel object>, name=None)

RNN cell base class.

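__call__(**kwargs)

Call self as a function. Concrete cells such as LSTMCell and GRUCell override this to perform a single time step: they take (carry, inputs) and return a (new_carry, output) tuple.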

initialize_carry(rng, input_shape)

Initialize the RNN cell carry.

Parameters
  • rng – a JAX PRNG key, passed to the init_fn.

  • input_shape – a tuple giving the shape of a single input to the cell, including any leading batch dimensions.

Returns

An initialized carry for the given RNN cell.
