flax.linen.Sequential

class flax.linen.Sequential(layers, parent=<flax.linen.module._Sentinel object>, name=None)

Applies a linear chain of Modules.

Intended only for the simple case of chaining callables, where each module or op takes the output of the previous one as its input.

Modules will be applied in the order that they are passed in the constructor.

The __call__ method of Sequential accepts arbitrary positional and keyword arguments and forwards them to the first module it contains. It then passes each module's output as the input of the next module and returns the output of the final module.

Example usage:

import flax.linen as nn

class Foo(nn.Module):

  @nn.compact
  def __call__(self, x):
    return nn.Sequential([nn.Dense(4),
                          nn.relu,
                          nn.Dense(2),
                          nn.log_softmax])(x)

This combinator also supports layers that return multiple outputs, as long as they are returned as a tuple or a dictionary. If a layer's output is a tuple, it is expanded as *args for the next layer; if it is a dict, it is expanded as **kwargs.

Example usage:

import flax.linen as nn

class CrossAttentionBlock(nn.Module):
  num_heads: int = 2
  qkv_features: int = 16

  @nn.compact
  def __call__(self, query, key_value):
    output = nn.MultiHeadDotProductAttention(
      num_heads=self.num_heads, qkv_features=self.qkv_features)(query,
                                                              key_value)
    output = nn.Dense(self.qkv_features)(output)
    return dict(query=output, key_value=key_value)  # also works for tuples

class CrossAttentionNetwork(nn.Module):
  num_layers: int

  @nn.compact
  def __call__(self, query, key_value):
    # Each block returns dict(query=..., key_value=...), which Sequential
    # expands as keyword arguments for the next block.
    return nn.Sequential([CrossAttentionBlock() for _ in
                          range(self.num_layers)])(query, key_value)
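
__call__(*args, **kwargs)

Forwards all positional and keyword arguments to the first layer, then feeds each layer's output to the next layer (expanding tuples as *args and dicts as **kwargs) and returns the output of the last layer.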
