Attention

class flax.nnx.MultiHeadAttention(*args, **kwargs)

Multi-head attention.

Example usage:

>>> from flax import nnx
>>> import jax

>>> layer = nnx.MultiHeadAttention(num_heads=8, in_features=5, qkv_features=16,
...                                decode=False, rngs=nnx.Rngs(0))
>>> key1, key2, key3 = jax.random.split(jax.random.key(0), 3)
>>> shape = (4, 3, 2, 5)
>>> q, k, v = (
...   jax.random.uniform(key1, shape),
...   jax.random.uniform(key2, shape),
...   jax.random.uniform(key3, shape),
... )

>>> # different inputs for inputs_q, inputs_k and inputs_v
>>> out = layer(q, k, v)
>>> # equivalent output when inferring v
>>> assert (layer(q, k) == layer(q, k, k)).all()
>>> # equivalent output when inferring k and v
>>> assert (layer(q) == layer(q, q)).all()
>>> assert (layer(q) == layer(q, q, q)).all()

num_heads

number of attention heads. qkv_features (which defaults to in_features when not given) should be divisible by the number of heads.

in_features

int or tuple with number of input features.

qkv_features

dimension of the key, query, and value.

out_features

dimension of the last projection.

dtype

the dtype of the computation (default: infer from inputs and params)

param_dtype

the dtype passed to parameter initializers (default: float32)

broadcast_dropout

bool: use a broadcasted dropout along batch dims.

dropout_rate

dropout rate

deterministic

if false, the attention weights are masked randomly using dropout; if true, no dropout is applied and the attention weights are deterministic.

precision

numerical precision of the computation; see jax.lax.Precision for details.

kernel_init

initializer for the kernel of the Dense layers.

out_kernel_init

optional initializer for the kernel of the output Dense layer; if None, kernel_init is used.

bias_init

initializer for the bias of the Dense layers.

out_bias_init

optional initializer for the bias of the output Dense layer; if None, bias_init is used.

use_bias

bool: whether pointwise QKVO dense transforms use bias.

attention_fn

dot_product_attention or compatible function. Accepts query, key, value, and returns output of shape [bs, dim1, dim2, …, dimN, num_heads, value_channels].

decode

whether to prepare and use an autoregressive cache.

normalize_qk

whether to apply QK normalization (arxiv.org/abs/2302.05442).

rngs

rng key.

__call__(inputs_q, inputs_k=None, inputs_v=None, *, mask=None, deterministic=None, rngs=None, sow_weights=False, decode=None)

Applies multi-head dot product attention on the input data.

Projects the inputs into multi-headed query, key, and value vectors, applies dot-product attention, and projects the results to an output vector.

If both inputs_k and inputs_v are None, they will both copy the value of inputs_q (self attention). If only inputs_v is None, it will copy the value of inputs_k.

Parameters
  • inputs_q – input queries of shape [batch_sizes…, length, features].

  • inputs_k – key of shape [batch_sizes…, length, features]. If None, inputs_k will copy the value of inputs_q.

  • inputs_v – values of shape [batch_sizes…, length, features]. If None, inputs_v will copy the value of inputs_k.

  • mask – attention mask of shape [batch_sizes…, num_heads, query_length, key/value_length]. Attention weights are masked out if their corresponding mask value is False.

  • deterministic – if false, the attention weights are masked randomly using dropout; if true, no dropout is applied and the attention weights are deterministic. The deterministic flag passed into the call method will take precedence over the deterministic flag passed into the constructor.

  • rngs – rng key. The rng key passed into the call method will take precedence over the rng key passed into the constructor.

  • sow_weights – if True, the attention weights are sowed into the ‘intermediates’ collection.

  • decode – whether to prepare and use an autoregressive cache. The decode flag passed into the call method will take precedence over the decode flag passed into the constructor.

Returns

output of shape [batch_sizes…, length, features].

init_cache(input_shape, dtype=jnp.float32)

Initializes cache for fast autoregressive decoding. When decode=True, this method must be called before performing forward inference. In decode mode, exactly one token must be passed at a time.

Example usage:

>>> from flax import nnx
>>> import jax.numpy as jnp
...
>>> batch_size = 5
>>> embed_dim = 3
>>> x = jnp.ones((batch_size, 1, embed_dim)) # single token
...
>>> model_nnx = nnx.MultiHeadAttention(
...   num_heads=2,
...   in_features=3,
...   qkv_features=6,
...   out_features=6,
...   decode=True,
...   rngs=nnx.Rngs(42),
... )
...
>>> # out_nnx = model_nnx(x)  <-- throws an error because cache isn't initialized
...
>>> model_nnx.init_cache(x.shape)
>>> out_nnx = model_nnx(x)

Methods

init_cache(input_shape[, dtype])

Initializes cache for fast autoregressive decoding.

flax.nnx.combine_masks(*masks, dtype=jnp.float32)

Combine attention masks.

Parameters
  • *masks – set of attention mask arguments to combine, some can be None.

  • dtype – dtype for the returned mask.

Returns

Combined mask, reduced by logical AND. Returns None if no masks are given.

flax.nnx.dot_product_attention(query, key, value, bias=None, mask=None, broadcast_dropout=True, dropout_rng=None, dropout_rate=0.0, deterministic=False, dtype=None, precision=None, module=None)

Computes dot-product attention given query, key, and value.

This is the core function for applying attention based on https://arxiv.org/abs/1706.03762. It calculates the attention weights given query and key and combines the values using the attention weights.

Note

query, key, value needn’t have any batch dimensions.

Parameters
  • query – queries for calculating attention with shape of [batch..., q_length, num_heads, qk_depth_per_head].

  • key – keys for calculating attention with shape of [batch..., kv_length, num_heads, qk_depth_per_head].

  • value – values to be used in attention with shape of [batch..., kv_length, num_heads, v_depth_per_head].

  • bias – bias for the attention weights. This should be broadcastable to the shape [batch…, num_heads, q_length, kv_length]. This can be used for incorporating causal masks, padding masks, proximity bias, etc.

  • mask – mask for the attention weights. This should be broadcastable to the shape [batch…, num_heads, q_length, kv_length]. This can be used for incorporating causal masks. Attention weights are masked out if their corresponding mask value is False.

  • broadcast_dropout – bool: use a broadcasted dropout along batch dims.

  • dropout_rng – JAX PRNGKey to be used for dropout.

  • dropout_rate – dropout rate

  • deterministic – bool, deterministic or not (to apply dropout)

  • dtype – the dtype of the computation (default: infer from inputs)

  • precision – numerical precision of the computation; see jax.lax.Precision for details.

  • module – the Module that will sow the attention weights into the nnx.Intermediate collection. If module is None, the attention weights will not be sowed.

Returns

Output of shape [batch…, q_length, num_heads, v_depth_per_head].

flax.nnx.make_attention_mask(query_input, key_input, pairwise_fn=jnp.multiply, extra_batch_dims=0, dtype=jnp.float32)

Mask-making helper for attention weights.

In case of 1d inputs (i.e., [batch…, len_q], [batch…, len_kv]), the attention weights will be [batch…, heads, len_q, len_kv] and this function will produce [batch…, 1, len_q, len_kv].

Parameters
  • query_input – a batched, flat input of query_length size

  • key_input – a batched, flat input of key_length size

  • pairwise_fn – broadcasting elementwise comparison function

  • extra_batch_dims – number of extra batch dims to add singleton axes for, none by default

  • dtype – mask return dtype

Returns

A [batch…, 1, len_q, len_kv] shaped mask for 1d attention.

flax.nnx.make_causal_mask(x, extra_batch_dims=0, dtype=jnp.float32)

Make a causal mask for self-attention.

In case of 1d inputs (i.e., [batch…, len]), the self-attention weights will be [batch…, heads, len, len] and this function will produce a causal mask of shape [batch…, 1, len, len].

Parameters
  • x – input array of shape [batch…, len]

  • extra_batch_dims – number of batch dims to add singleton axes for, none by default

  • dtype – mask return dtype

Returns

A [batch…, 1, len, len] shaped causal mask for 1d attention.