flax.linen.MultiHeadDotProductAttention#

class flax.linen.MultiHeadDotProductAttention(num_heads, dtype=None, param_dtype=<class 'jax.numpy.float32'>, qkv_features=None, out_features=None, broadcast_dropout=True, dropout_rate=0.0, deterministic=None, precision=None, kernel_init=<function variance_scaling.<locals>.init>, bias_init=<function zeros>, use_bias=True, attention_fn=<function dot_product_attention>, decode=False, normalize_qk=False, qkv_dot_general=None, out_dot_general=None, qkv_dot_general_cls=None, out_dot_general_cls=None, parent=<flax.linen.module._Sentinel object>, name=None)[source]#

Multi-head dot-product attention.

Example usage:

>>> import flax.linen as nn
>>> import jax

>>> layer = nn.MultiHeadDotProductAttention(num_heads=8, qkv_features=16)
>>> key1, key2, key3, key4, key5, key6 = jax.random.split(jax.random.key(0), 6)
>>> shape = (4, 3, 2, 5)
>>> q, k, v = jax.random.uniform(key1, shape), jax.random.uniform(key2, shape), jax.random.uniform(key3, shape)
>>> variables = layer.init(jax.random.key(0), q)

>>> # different inputs for inputs_q, inputs_k and inputs_v
>>> out = layer.apply(variables, q, k, v)
>>> # equivalent to layer.apply(variables, inputs_q=q, inputs_k=k, inputs_v=k)
>>> out = layer.apply(variables, q, k)
>>> # equivalent to layer.apply(variables, inputs_q=q, inputs_k=q) and layer.apply(variables, inputs_q=q, inputs_k=q, inputs_v=q)
>>> out = layer.apply(variables, q)

>>> attention_kwargs = dict(
...     num_heads=8,
...     qkv_features=16,
...     kernel_init=nn.initializers.ones,
...     bias_init=nn.initializers.zeros,
...     dropout_rate=0.5,
...     deterministic=False,
...     )
>>> class Module(nn.Module):
...   attention_kwargs: dict
...
...   @nn.compact
...   def __call__(self, x, dropout_rng=None):
...     out1 = nn.MultiHeadDotProductAttention(**self.attention_kwargs)(x, dropout_rng=dropout_rng)
...     out2 = nn.MultiHeadDotProductAttention(**self.attention_kwargs)(x, dropout_rng=dropout_rng)
...     return out1, out2
>>> module = Module(attention_kwargs)
>>> variables = module.init({'params': key1, 'dropout': key2}, q)

>>> # out1 and out2 are different.
>>> out1, out2 = module.apply(variables, q, rngs={'dropout': key3})
>>> # out3 and out4 are different.
>>> # out1 and out3 are different. out2 and out4 are different.
>>> out3, out4 = module.apply(variables, q, rngs={'dropout': key4})
>>> # out1 and out2 are the same.
>>> out1, out2 = module.apply(variables, q, dropout_rng=key5)
>>> # out1 and out2 are the same as out3 and out4.
>>> # providing a `dropout_rng` arg will take precedence over the `rngs` arg in `.apply`
>>> out3, out4 = module.apply(variables, q, rngs={'dropout': key6}, dropout_rng=key5)
num_heads#

number of attention heads. Features (i.e. inputs_q.shape[-1]) should be divisible by the number of heads.

Type

int

dtype#

the dtype of the computation (default: infer from inputs and params)

Type

Optional[Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, Any]]

param_dtype#

the dtype passed to parameter initializers (default: float32)

Type

Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, Any]

qkv_features#

dimension of the key, query, and value.

Type

Optional[int]

out_features#

dimension of the last projection

Type

Optional[int]

broadcast_dropout#

whether to use a broadcasted dropout along batch dimensions.

Type

bool

dropout_rate#

dropout rate

Type

float

deterministic#

if false, the attention weights are masked randomly using dropout; if true, the attention weights are deterministic.

Type

Optional[bool]

precision#

numerical precision of the computation; see jax.lax.Precision for details.

Type

Union[None, str, jax._src.lax.lax.Precision, Tuple[str, str], Tuple[jax._src.lax.lax.Precision, jax._src.lax.lax.Precision]]

kernel_init#

initializer for the kernel of the Dense layers.

Type

Union[jax.nn.initializers.Initializer, Callable[[…], Any]]

bias_init#

initializer for the bias of the Dense layers.

Type

Union[jax.nn.initializers.Initializer, Callable[[…], Any]]

use_bias#

whether the pointwise QKVO dense transforms use a bias.

Type

bool

attention_fn#

dot_product_attention or a compatible function. Accepts query, key, value, and returns output of shape [bs, dim1, dim2, ..., dimN, num_heads, value_channels].

Type

Callable[[…], Union[jax.Array, Any]]

decode#

whether to prepare and use an autoregressive cache (a decoding sketch follows this attribute list).

Type

bool

normalize_qk#

whether to apply QK normalization (arxiv.org/abs/2302.05442).

Type

bool
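
When decode=True, the layer caches projected keys and values so tokens can be generated one step at a time. A minimal sketch, assuming illustrative shapes and hyperparameters (not part of the API): initialize with a full-length input to allocate the cache, then apply to one token per step with the 'cache' collection marked mutable.

>>> import jax
>>> import jax.numpy as jnp
>>> import flax.linen as nn

>>> decoder = nn.MultiHeadDotProductAttention(num_heads=4, qkv_features=16, decode=True)
>>> x = jnp.ones((1, 8, 16))  # (batch, max_length, features); illustrative shape
>>> variables = decoder.init(jax.random.key(0), x)  # also creates the 'cache' collection
>>> params, cache = variables['params'], variables['cache']
>>> for i in range(x.shape[1]):
...   # feed one token per step; the updated cache is returned because it is mutable
...   out, mutated = decoder.apply(
...       {'params': params, 'cache': cache}, x[:, i:i+1, :], mutable=['cache'])
...   cache = mutated['cache']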

__call__(inputs_q, inputs_k=None, inputs_v=None, *, inputs_kv=None, mask=None, deterministic=None, dropout_rng=None, sow_weights=False)[source]#

Applies multi-head dot product attention on the input data.

Projects the inputs into multi-headed query, key, and value vectors, applies dot-product attention, and projects the results to an output vector.

If both inputs_k and inputs_v are None, they will both copy the value of inputs_q (self attention). If only inputs_v is None, it will copy the value of inputs_k.

Parameters
  • inputs_q – input queries of shape [batch_sizes..., length, features].

  • inputs_k – key of shape [batch_sizes..., length, features]. If None, inputs_k will copy the value of inputs_q.

  • inputs_v – values of shape [batch_sizes..., length, features]. If None, inputs_v will copy the value of inputs_k.

  • inputs_kv – key/values of shape [batch_sizes..., length, features]. If None, inputs_kv will copy the value of inputs_q. This arg will be deprecated soon. Use inputs_k and inputs_v instead.

  • mask – attention mask of shape [batch_sizes..., num_heads, query_length, key/value_length]. Attention weights are masked out if their corresponding mask value is False.

  • deterministic – if false, the attention weights are masked randomly using dropout; if true, the attention weights are deterministic.

  • dropout_rng – optional rng key for the attention layer's dropout mask. If None, self.make_rng('dropout') is used instead.

  • sow_weights – if True, the attention weights are sowed into the 'intermediates' collection. Remember to mark 'intermediates' as mutable via mutable=['intermediates'] in order to have that collection returned; see the sketch after the return description below.

Returns

output of shape [batch_sizes..., length, features].
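
The following sketch (illustrative shapes and hyperparameters) passes a causal mask built with nn.make_causal_mask and retrieves the sowed attention weights; in recent Flax versions they are stored as a tuple under the 'attention_weights' key of the 'intermediates' collection.

>>> import jax
>>> import jax.numpy as jnp
>>> import flax.linen as nn

>>> layer = nn.MultiHeadDotProductAttention(num_heads=4, qkv_features=16)
>>> x = jnp.ones((2, 5, 16))  # (batch, length, features); illustrative shape
>>> variables = layer.init(jax.random.key(0), x)
>>> # causal mask of shape [batch, 1, length, length]; broadcasts over heads
>>> mask = nn.make_causal_mask(jnp.ones((2, 5)))
>>> out, state = layer.apply(
...     variables, x, mask=mask, sow_weights=True, mutable=['intermediates'])
>>> # sow() appends into a tuple; key name as used by recent Flax versions
>>> attn_weights = state['intermediates']['attention_weights'][0]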
