
class flax.experimental.nnx.MultiHeadAttention(*args, **kwargs)

Multi-head attention.

Example usage (this example is written against the flax.linen API):

>>> import flax.linen as nn
>>> import jax

>>> layer = nn.MultiHeadAttention(num_heads=8, qkv_features=16)
>>> key1, key2, key3, key4, key5, key6 = jax.random.split(jax.random.key(0), 6)
>>> shape = (4, 3, 2, 5)
>>> q, k, v = jax.random.uniform(key1, shape), jax.random.uniform(key2, shape), jax.random.uniform(key3, shape)
>>> variables = layer.init(jax.random.key(0), q)

>>> # different inputs for inputs_q, inputs_k and inputs_v
>>> out = layer.apply(variables, q, k, v)
>>> # equivalent to layer.apply(variables, inputs_q=q, inputs_k=k, inputs_v=k)
>>> out = layer.apply(variables, q, k)
>>> # equivalent to layer.apply(variables, inputs_q=q, inputs_k=q) and layer.apply(variables, inputs_q=q, inputs_k=q, inputs_v=q)
>>> out = layer.apply(variables, q)

>>> attention_kwargs = dict(
...     num_heads=8,
...     qkv_features=16,
...     kernel_init=nn.initializers.ones,
...     bias_init=nn.initializers.zeros,
...     dropout_rate=0.5,
...     deterministic=False,
...     )
>>> class Module(nn.Module):
...   attention_kwargs: dict
...   @nn.compact
...   def __call__(self, x, dropout_rng=None):
...     out1 = nn.MultiHeadAttention(**self.attention_kwargs)(x, dropout_rng=dropout_rng)
...     out2 = nn.MultiHeadAttention(**self.attention_kwargs)(x, dropout_rng=dropout_rng)
...     return out1, out2
>>> module = Module(attention_kwargs)
>>> variables = module.init({'params': key1, 'dropout': key2}, q)

>>> # out1 and out2 are different.
>>> out1, out2 = module.apply(variables, q, rngs={'dropout': key3})
>>> # out3 and out4 are different.
>>> # out1 and out3 are different. out2 and out4 are different.
>>> out3, out4 = module.apply(variables, q, rngs={'dropout': key4})
>>> # out1 and out2 are the same.
>>> out1, out2 = module.apply(variables, q, dropout_rng=key5)
>>> # out1 and out2 are the same as out3 and out4.
>>> # providing a `dropout_rng` arg will take precedence over the `rngs` arg in `.apply`
>>> out3, out4 = module.apply(variables, q, rngs={'dropout': key6}, dropout_rng=key5)

  • num_heads – number of attention heads. Features (i.e. inputs_q.shape[-1]) should be divisible by the number of heads.

  • dtype – the dtype of the computation (default: infer from inputs and params).

  • param_dtype – the dtype passed to parameter initializers (default: float32).

  • qkv_features – dimension of the key, query, and value.

  • out_features – dimension of the last projection.

  • broadcast_dropout – bool: use a broadcasted dropout along batch dims.

  • dropout_rate – dropout rate.

  • deterministic – if False, the attention weights are masked randomly using dropout; if True, the attention weights are deterministic.

  • precision – numerical precision of the computation; see jax.lax.Precision for details.

  • kernel_init – initializer for the kernel of the Dense layers.

  • bias_init – initializer for the bias of the Dense layers.

  • use_bias – bool: whether pointwise QKVO dense transforms use bias.

  • attention_fn – dot_product_attention or a compatible function. Accepts query, key, value, and returns output of shape [bs, dim1, dim2, …, dimN, num_heads, value_channels].

  • decode – whether to prepare and use an autoregressive cache.

  • normalize_qk – whether QK normalization should be applied.
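The divisibility requirement on the number of heads can be illustrated with a small NumPy sketch. It assumes that the projected feature dimension (qkv_features in the example above) is the axis that gets split across heads; split_heads_sketch is a hypothetical helper written for illustration, not part of Flax.

```python
import numpy as np

def split_heads_sketch(x, num_heads):
    """Hypothetical helper: reshape [..., features] to [..., num_heads, head_dim]."""
    features = x.shape[-1]
    if features % num_heads != 0:
        raise ValueError(
            f"features ({features}) must be divisible by num_heads ({num_heads})"
        )
    head_dim = features // num_heads
    # Split the last axis into (num_heads, head_dim); leading axes are untouched.
    return x.reshape(x.shape[:-1] + (num_heads, head_dim))

# Assumed projected activations with 16 features, split across 8 heads.
projected = np.zeros((4, 3, 2, 16))
heads = split_heads_sketch(projected, num_heads=8)
print(heads.shape)  # (4, 3, 2, 8, 2)
```

With 16 features and 8 heads, each head sees 2 features; a head count that does not divide the feature dimension is rejected up front.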

init_cache(input_shape, dtype=jax.numpy.float32)

Initializes cache for fast autoregressive decoding. When decode=True, this method must be called first before performing forward inference.

Example usage:

>>> from flax.experimental import nnx
>>> import jax.numpy as jnp
>>> rngs = nnx.Rngs(42)
>>> x = jnp.ones((1, 3))
>>> model_nnx = nnx.MultiHeadAttention(
...   num_heads=2,
...   in_features=3,
...   qkv_features=6,
...   out_features=6,
...   decode=True,
...   rngs=rngs,
... )
>>> # out_nnx = model_nnx(x)  <-- throws an error because cache isn't initialized
>>> model_nnx.init_cache(x.shape)
>>> out_nnx = model_nnx(x)

flax.experimental.nnx.combine_masks(*masks, dtype=jax.numpy.float32)

Combine attention masks.

  • *masks – set of attention mask arguments to combine, some can be None.

  • dtype – dtype for the returned mask.


Returns the combined mask, reduced by logical AND, or None if no masks are given.
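The behavior described above can be sketched in plain NumPy. This is an illustrative reimplementation under the stated semantics (skip None entries, reduce with logical AND, cast to dtype), not the Flax source; combine_masks_sketch is a hypothetical name.

```python
import numpy as np

def combine_masks_sketch(*masks, dtype=np.float32):
    """Illustrative stand-in: AND together every mask that is not None."""
    present = [m for m in masks if m is not None]
    if not present:
        return None
    combined = present[0]
    for m in present[1:]:
        # Broadcasting lets a [1, 1, 1, len] padding mask combine with
        # a [1, 1, len, len] causal mask.
        combined = np.logical_and(combined, m)
    return combined.astype(dtype)

# Lower-triangular causal mask over 3 positions.
causal = np.tril(np.ones((1, 1, 3, 3)))
# Padding mask: the last key position is padding.
padding = np.array([1, 1, 0]).reshape(1, 1, 1, 3)

combined = combine_masks_sketch(causal, None, padding)
print(combined.shape)  # (1, 1, 3, 3)
```

In the combined mask, a position is kept only if both the causal mask and the padding mask allow it.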

flax.experimental.nnx.dot_product_attention(query, key, value, bias=None, mask=None, broadcast_dropout=True, dropout_rng=None, dropout_rate=0.0, deterministic=False, dtype=None, precision=None, module=None)

Computes dot-product attention given query, key, and value.

This is the core function for applying attention. It calculates the attention weights given query and key and combines the values using the attention weights.


Note: query, key, and value needn’t have any batch dimensions.

  • query – queries for calculating attention with shape of [batch..., q_length, num_heads, qk_depth_per_head].

  • key – keys for calculating attention with shape of [batch..., kv_length, num_heads, qk_depth_per_head].

  • value – values to be used in attention with shape of [batch..., kv_length, num_heads, v_depth_per_head].

  • bias – bias for the attention weights. This should be broadcastable to the shape [batch…, num_heads, q_length, kv_length]. This can be used for incorporating causal masks, padding masks, proximity bias, etc.

  • mask – mask for the attention weights. This should be broadcastable to the shape [batch…, num_heads, q_length, kv_length]. This can be used for incorporating causal masks. Attention weights are masked out if their corresponding mask value is False.

  • broadcast_dropout – bool: use a broadcasted dropout along batch dims.

  • dropout_rng – JAX PRNGKey: to be used for dropout

  • dropout_rate – dropout rate

  • deterministic – bool: if False, dropout is applied to the attention weights; if True, no dropout is applied.

  • dtype – the dtype of the computation (default: infer from inputs)

  • precision – numerical precision of the computation; see jax.lax.Precision for details.

  • module – the Module that will sow the attention weights into the nnx.Intermediate collection. If module is None, the attention weights will not be sowed.


Returns an output of shape [batch…, q_length, num_heads, v_depth_per_head].
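To make the shapes concrete, here is a minimal NumPy sketch of dot-product attention that follows the argument layout documented above. It is not the Flax implementation: dropout, dtype promotion, and precision handling are omitted, the usual 1/sqrt(depth) scaling of the logits is an assumption, and dot_product_attention_sketch is a hypothetical name.

```python
import numpy as np

def dot_product_attention_sketch(query, key, value, bias=None, mask=None):
    """Illustrative attention without dropout.

    query: [batch..., q_length, num_heads, qk_depth_per_head]
    key:   [batch..., kv_length, num_heads, qk_depth_per_head]
    value: [batch..., kv_length, num_heads, v_depth_per_head]
    """
    depth = query.shape[-1]
    # Attention logits of shape [batch..., num_heads, q_length, kv_length].
    logits = np.einsum("...qhd,...khd->...hqk", query / np.sqrt(depth), key)
    if bias is not None:
        logits = logits + bias
    if mask is not None:
        # Positions where the mask is False get a very large negative logit.
        logits = np.where(mask, logits, np.finfo(logits.dtype).min)
    # Numerically stable softmax over the kv_length axis.
    logits = logits - logits.max(axis=-1, keepdims=True)
    weights = np.exp(logits)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    # Weighted sum of values: [batch..., q_length, num_heads, v_depth_per_head].
    return np.einsum("...hqk,...khd->...qhd", weights, value)

rng = np.random.default_rng(0)
q = rng.normal(size=(2, 4, 2, 8))  # batch=2, q_length=4, heads=2, depth=8
k = rng.normal(size=(2, 5, 2, 8))  # kv_length=5
v = rng.normal(size=(2, 5, 2, 3))  # v_depth_per_head=3

out = dot_product_attention_sketch(q, k, v)
print(out.shape)  # (2, 4, 2, 3)
```

The output keeps the query length and head count, while its last axis takes the per-head depth of value.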

flax.experimental.nnx.make_attention_mask(query_input, key_input, pairwise_fn=jax.numpy.multiply, extra_batch_dims=0, dtype=jax.numpy.float32)

Mask-making helper for attention weights.

In case of 1d inputs (i.e., [batch…, len_q], [batch…, len_kv]), the attention weights will be [batch…, heads, len_q, len_kv] and this function will produce [batch…, 1, len_q, len_kv].

  • query_input – a batched, flat input of query_length size

  • key_input – a batched, flat input of key_length size

  • pairwise_fn – broadcasting elementwise comparison function

  • extra_batch_dims – number of extra batch dims to add singleton axes for, none by default

  • dtype – mask return dtype


Returns a [batch…, 1, len_q, len_kv] shaped mask for 1d attention.
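The shape bookkeeping above can be sketched in NumPy as follows. This is an illustration of the documented behavior rather than the Flax code; make_attention_mask_sketch is a hypothetical name, and the token ids are made-up example data in which 0 is assumed to mark padding.

```python
import numpy as np

def make_attention_mask_sketch(query_input, key_input, pairwise_fn=np.multiply,
                               extra_batch_dims=0, dtype=np.float32):
    """Illustrative stand-in: pairwise comparison of query and key positions."""
    # [batch..., len_q, 1] against [batch..., 1, len_kv] -> [batch..., len_q, len_kv]
    mask = pairwise_fn(query_input[..., :, None], key_input[..., None, :])
    # Singleton heads axis -> [batch..., 1, len_q, len_kv]
    mask = mask[..., None, :, :]
    # Optional leading singleton batch axes.
    for _ in range(extra_batch_dims):
        mask = mask[None]
    return mask.astype(dtype)

# Assumed example: token id 0 marks padding.
tokens = np.array([[5, 7, 0, 0]])
valid = tokens > 0

padding_mask = make_attention_mask_sketch(valid, valid)
print(padding_mask.shape)  # (1, 1, 4, 4)
```

With the default multiply as the pairwise function, an entry is nonzero only when both the query position and the key position are valid tokens.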

flax.experimental.nnx.make_causal_mask(x, extra_batch_dims=0, dtype=jax.numpy.float32)

Make a causal mask for self-attention.

In case of 1d inputs (i.e., [batch…, len]), the self-attention weights will be [batch…, heads, len, len] and this function will produce a causal mask of shape [batch…, 1, len, len].

  • x – input array of shape [batch…, len]

  • extra_batch_dims – number of batch dims to add singleton axes for, none by default

  • dtype – mask return dtype


Returns a [batch…, 1, len, len] shaped causal mask for 1d attention.
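A causal mask of the documented shape can be sketched in NumPy by comparing position indices, so that each query position may attend only to itself and earlier positions. This is an illustrative reimplementation under that assumption, not the Flax source; make_causal_mask_sketch is a hypothetical name.

```python
import numpy as np

def make_causal_mask_sketch(x, extra_batch_dims=0, dtype=np.float32):
    """Illustrative stand-in: lower-triangular mask from position indices."""
    # Position indices broadcast to the input shape [batch..., len].
    idxs = np.broadcast_to(np.arange(x.shape[-1]), x.shape)
    # Query position i may attend to key position j when i >= j.
    mask = np.greater_equal(idxs[..., :, None], idxs[..., None, :])
    # Singleton heads axis -> [batch..., 1, len, len]
    mask = mask[..., None, :, :]
    # Optional leading singleton batch axes.
    for _ in range(extra_batch_dims):
        mask = mask[None]
    return mask.astype(dtype)

x = np.ones((2, 3))  # batch=2, len=3; only the shape of x matters here
causal = make_causal_mask_sketch(x)
print(causal.shape)  # (2, 1, 3, 3)
```

Only the shape of x is used: the values of the input do not affect the mask.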