Attention#

class flax.experimental.nnx.MultiHeadAttention(*args, **kwargs)[source]#

Multi-head attention.

Example usage (the snippet below is written against the flax.linen API, as its imports show):

>>> import flax.linen as nn
>>> import jax

>>> layer = nn.MultiHeadAttention(num_heads=8, qkv_features=16)
>>> key1, key2, key3, key4, key5, key6 = jax.random.split(jax.random.key(0), 6)
>>> shape = (4, 3, 2, 5)
>>> q, k, v = jax.random.uniform(key1, shape), jax.random.uniform(key2, shape), jax.random.uniform(key3, shape)
>>> variables = layer.init(jax.random.key(0), q)

>>> # different inputs for inputs_q, inputs_k and inputs_v
>>> out = layer.apply(variables, q, k, v)
>>> # equivalent to layer.apply(variables, inputs_q=q, inputs_k=k, inputs_v=k)
>>> out = layer.apply(variables, q, k)
>>> # equivalent to layer.apply(variables, inputs_q=q, inputs_k=q) and layer.apply(variables, inputs_q=q, inputs_k=q, inputs_v=q)
>>> out = layer.apply(variables, q)

>>> attention_kwargs = dict(
...     num_heads=8,
...     qkv_features=16,
...     kernel_init=nn.initializers.ones,
...     bias_init=nn.initializers.zeros,
...     dropout_rate=0.5,
...     deterministic=False,
...     )
>>> class Module(nn.Module):
...   attention_kwargs: dict
...
...   @nn.compact
...   def __call__(self, x, dropout_rng=None):
...     out1 = nn.MultiHeadAttention(**self.attention_kwargs)(x, dropout_rng=dropout_rng)
...     out2 = nn.MultiHeadAttention(**self.attention_kwargs)(x, dropout_rng=dropout_rng)
...     return out1, out2
>>> module = Module(attention_kwargs)
>>> variables = module.init({'params': key1, 'dropout': key2}, q)

>>> # out1 and out2 are different.
>>> out1, out2 = module.apply(variables, q, rngs={'dropout': key3})
>>> # out3 and out4 are different.
>>> # out1 and out3 are different. out2 and out4 are different.
>>> out3, out4 = module.apply(variables, q, rngs={'dropout': key4})
>>> # out1 and out2 are the same.
>>> out1, out2 = module.apply(variables, q, dropout_rng=key5)
>>> # out1 and out2 are the same as out3 and out4.
>>> # providing a `dropout_rng` arg will take precedence over the `rngs` arg in `.apply`
>>> out3, out4 = module.apply(variables, q, rngs={'dropout': key6}, dropout_rng=key5)

num_heads#

number of attention heads. The projected feature dimension (qkv_features, or inputs_q.shape[-1] when qkv_features is not set) should be divisible by the number of heads.
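
As a rough illustration of why the divisibility requirement exists, the sketch below splits a feature axis evenly across heads using plain jax.numpy. The helper name split_heads is hypothetical and is not part of the nnx API; it only mirrors the reshape that multi-head attention conceptually performs.

```python
import jax.numpy as jnp

def split_heads(x, num_heads):
  """Reshape [..., length, features] into [..., length, num_heads, head_dim]."""
  features = x.shape[-1]
  if features % num_heads != 0:
    raise ValueError(
        f"features ({features}) must be divisible by num_heads ({num_heads})")
  head_dim = features // num_heads
  return x.reshape(x.shape[:-1] + (num_heads, head_dim))

projected = jnp.ones((4, 3, 16))  # [batch, length, features]
heads = split_heads(projected, num_heads=8)
print(heads.shape)  # (4, 3, 8, 2)
```

With 16 features and 8 heads, each head receives 2 features; a feature size of 5 would raise the ValueError in this sketch.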

dtype#

the dtype of the computation (default: infer from inputs and params)

param_dtype#

the dtype passed to parameter initializers (default: float32)

qkv_features#

dimension of the key, query, and value.

out_features#

dimension of the last projection

broadcast_dropout#

bool: use a broadcasted dropout along batch dims.

dropout_rate#

dropout rate

deterministic#

if False, dropout randomly masks the attention weights; if True, no dropout is applied and the attention weights are deterministic.
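
The effect of dropout_rate and deterministic can be pictured with a small jax.numpy sketch. The function dropout_attention_weights is a hypothetical helper written for this illustration, assuming standard inverted dropout; it is not the module's own code.

```python
import jax
import jax.numpy as jnp

def dropout_attention_weights(weights, rate, rng, deterministic):
  """Illustrative dropout on attention weights (not the library's code)."""
  if deterministic or rate == 0.0:
    return weights
  keep_prob = 1.0 - rate
  keep = jax.random.bernoulli(rng, keep_prob, weights.shape)
  # Rescale the kept entries so the expected value is unchanged.
  return jnp.where(keep, weights / keep_prob, 0.0)

weights = jnp.full((2, 4, 4), 0.25)  # [num_heads, q_length, kv_length]
rng = jax.random.key(0)
same = dropout_attention_weights(weights, 0.5, rng, deterministic=True)
dropped = dropout_attention_weights(weights, 0.5, rng, deterministic=False)
```

Here same is identical to weights, while each entry of dropped is either zeroed or rescaled by 1 / (1 - rate).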

precision#

numerical precision of the computation; see jax.lax.Precision for details.

kernel_init#

initializer for the kernel of the Dense layers.

bias_init#

initializer for the bias of the Dense layers.

use_bias#

bool: whether pointwise QKVO dense transforms use bias.

attention_fn#

dot_product_attention or compatible function. Accepts query, key, value, and returns output of shape [bs, dim1, dim2, …, dimN, num_heads, value_channels].

decode#

whether to prepare and use an autoregressive cache.

normalize_qk#

whether QK normalization should be applied (arxiv.org/abs/2302.05442).

init_cache(input_shape, dtype=<class 'jax.numpy.float32'>)[source]#

Initializes the cache for fast autoregressive decoding. When decode=True, this method must be called before performing forward inference.

Example usage:

>>> from flax.experimental import nnx
>>> import jax.numpy as jnp

>>> rngs = nnx.Rngs(42)

>>> x = jnp.ones((1, 3))
>>> model_nnx = nnx.MultiHeadAttention(
...   num_heads=2,
...   in_features=3,
...   qkv_features=6,
...   out_features=6,
...   decode=True,
...   rngs=rngs,
... )

>>> # out_nnx = model_nnx(x) <-- throws an error because cache isn't initialized

>>> model_nnx.init_cache(x.shape)
>>> out_nnx = model_nnx(x)

flax.experimental.nnx.combine_masks(*masks, dtype=<class 'jax.numpy.float32'>)[source]#

Combine attention masks.

Parameters
  • *masks – set of attention mask arguments to combine, some can be None.

  • dtype – dtype for the returned mask.

Returns

Combined mask, reduced by logical and. Returns None if no masks are given.
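
The behavior described above can be sketched in a few lines of jax.numpy. The function combine_masks_sketch is a hypothetical re-implementation for illustration only; the library function may perform additional checks (for example on mask ranks) that are omitted here.

```python
import jax.numpy as jnp

def combine_masks_sketch(*masks, dtype=jnp.float32):
  """Logical-and all non-None masks; return None if none are given."""
  present = [m for m in masks if m is not None]
  if not present:
    return None
  combined = present[0]
  for other in present[1:]:
    combined = jnp.logical_and(combined, other)
  return combined.astype(dtype)

causal = jnp.tril(jnp.ones((1, 3, 3), dtype=bool))
padding = jnp.array([[[True, True, False]]])  # last key position is padding
mask = combine_masks_sketch(causal, None, padding)
print(mask)
```

A position stays unmasked only if every supplied mask allows it, so the padded key column is zero in every row of the result.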

flax.experimental.nnx.dot_product_attention(query, key, value, bias=None, mask=None, broadcast_dropout=True, dropout_rng=None, dropout_rate=0.0, deterministic=False, dtype=None, precision=None, module=None)[source]#

Computes dot-product attention given query, key, and value.

This is the core function for applying attention based on https://arxiv.org/abs/1706.03762. It calculates the attention weights given query and key and combines the values using the attention weights.

Note

query, key, and value needn’t have any batch dimensions.

Parameters
  • query – queries for calculating attention with shape of [batch..., q_length, num_heads, qk_depth_per_head].

  • key – keys for calculating attention with shape of [batch..., kv_length, num_heads, qk_depth_per_head].

  • value – values to be used in attention with shape of [batch..., kv_length, num_heads, v_depth_per_head].

  • bias – bias for the attention weights. This should be broadcastable to the shape [batch…, num_heads, q_length, kv_length]. This can be used for incorporating causal masks, padding masks, proximity bias, etc.

  • mask – mask for the attention weights. This should be broadcastable to the shape [batch…, num_heads, q_length, kv_length]. This can be used for incorporating causal masks. Attention weights are masked out if their corresponding mask value is False.

  • broadcast_dropout – bool: use a broadcasted dropout along batch dims.

  • dropout_rng – JAX PRNGKey: to be used for dropout

  • dropout_rate – dropout rate

  • deterministic – bool: if True, dropout is not applied; if False, dropout is applied to the attention weights.

  • dtype – the dtype of the computation (default: infer from inputs)

  • precision – numerical precision of the computation; see jax.lax.Precision for details.

  • module – the Module that will sow the attention weights into the nnx.Intermediate collection. If module is None, the attention weights will not be sown.

Returns

Output of shape [batch…, q_length, num_heads, v_depth_per_head].
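
To make the shapes above concrete, here is a minimal sketch of scaled dot-product attention in jax.numpy. The function attention_sketch is a hypothetical helper; it omits bias, dropout, dtype handling, and precision, and it assumes a boolean mask. It is not the library's implementation.

```python
import jax
import jax.numpy as jnp

def attention_sketch(query, key, value, mask=None):
  """Scaled dot-product attention without bias or dropout (illustrative)."""
  depth = query.shape[-1]
  # Logits have shape [batch..., num_heads, q_length, kv_length].
  logits = jnp.einsum('...qhd,...khd->...hqk', query, key) / jnp.sqrt(depth)
  if mask is not None:
    # Positions where the mask is False get a very negative logit.
    logits = jnp.where(mask, logits, jnp.finfo(logits.dtype).min)
  weights = jax.nn.softmax(logits, axis=-1)
  # Output has shape [batch..., q_length, num_heads, v_depth_per_head].
  return jnp.einsum('...hqk,...khd->...qhd', weights, value)

q = jnp.ones((2, 5, 4, 8))  # [batch, q_length, num_heads, qk_depth_per_head]
k = jnp.ones((2, 7, 4, 8))  # [batch, kv_length, num_heads, qk_depth_per_head]
v = jnp.ones((2, 7, 4, 3))  # [batch, kv_length, num_heads, v_depth_per_head]
out = attention_sketch(q, k, v)
print(out.shape)  # (2, 5, 4, 3)
```

Because the einsum patterns use an ellipsis for the leading axes, the same sketch also works for inputs without any batch dimensions.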

flax.experimental.nnx.make_attention_mask(query_input, key_input, pairwise_fn=<PjitFunction of <function jax.numpy.multiply>>, extra_batch_dims=0, dtype=<class 'jax.numpy.float32'>)[source]#

Mask-making helper for attention weights.

In case of 1d inputs (i.e., [batch…, len_q], [batch…, len_kv]), the attention weights will be [batch…, heads, len_q, len_kv] and this function will produce [batch…, 1, len_q, len_kv].

Parameters
  • query_input – a batched, flat input of query_length size

  • key_input – a batched, flat input of key_length size

  • pairwise_fn – broadcasting elementwise comparison function

  • extra_batch_dims – number of extra batch dims to add singleton axes for, none by default

  • dtype – mask return dtype

Returns

A [batch…, 1, len_q, len_kv] shaped mask for 1d attention.
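
A common use is building a padding mask from token ids. The sketch below reproduces the described shapes with jax.numpy; attention_mask_sketch is a hypothetical helper that ignores extra_batch_dims, and treating token id 0 as padding is an assumption made for this example.

```python
import jax.numpy as jnp

def attention_mask_sketch(query_input, key_input, pairwise_fn=jnp.multiply,
                          dtype=jnp.float32):
  """Build a [batch..., 1, len_q, len_kv] mask from 1d inputs (illustrative)."""
  # Compare every query position against every key position.
  pairwise = pairwise_fn(query_input[..., :, None], key_input[..., None, :])
  # Add a singleton heads axis so the mask broadcasts over all heads.
  return jnp.expand_dims(pairwise, axis=-3).astype(dtype)

# Token ids with 0 as padding (an assumption made for this example).
tokens = jnp.array([[5, 7, 9, 0]])
valid = tokens > 0
padding_mask = attention_mask_sketch(valid, valid)
print(padding_mask.shape)  # (1, 1, 4, 4)
```

With the default multiply as pairwise_fn, an entry is nonzero only when both the query position and the key position are valid tokens.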

flax.experimental.nnx.make_causal_mask(x, extra_batch_dims=0, dtype=<class 'jax.numpy.float32'>)[source]#

Make a causal mask for self-attention.

In case of 1d inputs (i.e., [batch…, len]), the self-attention weights will be [batch…, heads, len, len] and this function will produce a causal mask of shape [batch…, 1, len, len].

Parameters
  • x – input array of shape [batch…, len]

  • extra_batch_dims – number of batch dims to add singleton axes for, none by default

  • dtype – mask return dtype

Returns

A [batch…, 1, len, len] shaped causal mask for 1d attention.
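
The causal structure can be sketched with jax.numpy as follows. The function causal_mask_sketch is a hypothetical helper that ignores extra_batch_dims; it is meant to show the resulting lower-triangular pattern, not to stand in for the library function.

```python
import jax.numpy as jnp

def causal_mask_sketch(x, dtype=jnp.float32):
  """Build a [batch..., 1, len, len] causal mask (illustrative)."""
  idxs = jnp.broadcast_to(jnp.arange(x.shape[-1]), x.shape)
  # Query position i may attend to key position j only when i >= j.
  allowed = idxs[..., :, None] >= idxs[..., None, :]
  return jnp.expand_dims(allowed, axis=-3).astype(dtype)

x = jnp.ones((2, 3))  # [batch, len]; only the shape matters here
mask = causal_mask_sketch(x)
print(mask.shape)  # (2, 1, 3, 3)
print(mask[0, 0])  # lower-triangular matrix of ones
```

Only the shape of x is used, so the values of the input do not affect the mask in this sketch.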