Attention#
- class flax.experimental.nnx.MultiHeadAttention(*args, **kwargs)[source]#
Multi-head attention.
Example usage:
>>> import flax.linen as nn
>>> import jax

>>> layer = nn.MultiHeadAttention(num_heads=8, qkv_features=16)
>>> key1, key2, key3, key4, key5, key6 = jax.random.split(jax.random.key(0), 6)
>>> shape = (4, 3, 2, 5)
>>> q, k, v = jax.random.uniform(key1, shape), jax.random.uniform(key2, shape), jax.random.uniform(key3, shape)
>>> variables = layer.init(jax.random.key(0), q)

>>> # different inputs for inputs_q, inputs_k and inputs_v
>>> out = layer.apply(variables, q, k, v)
>>> # equivalent to layer.apply(variables, inputs_q=q, inputs_k=k, inputs_v=k)
>>> out = layer.apply(variables, q, k)
>>> # equivalent to layer.apply(variables, inputs_q=q, inputs_k=q) and
>>> # layer.apply(variables, inputs_q=q, inputs_k=q, inputs_v=q)
>>> out = layer.apply(variables, q)

>>> attention_kwargs = dict(
...     num_heads=8,
...     qkv_features=16,
...     kernel_init=nn.initializers.ones,
...     bias_init=nn.initializers.zeros,
...     dropout_rate=0.5,
...     deterministic=False,
... )
>>> class Module(nn.Module):
...     attention_kwargs: dict
...
...     @nn.compact
...     def __call__(self, x, dropout_rng=None):
...         out1 = nn.MultiHeadAttention(**self.attention_kwargs)(x, dropout_rng=dropout_rng)
...         out2 = nn.MultiHeadAttention(**self.attention_kwargs)(x, dropout_rng=dropout_rng)
...         return out1, out2
>>> module = Module(attention_kwargs)
>>> variables = module.init({'params': key1, 'dropout': key2}, q)

>>> # out1 and out2 are different.
>>> out1, out2 = module.apply(variables, q, rngs={'dropout': key3})
>>> # out3 and out4 are different.
>>> # out1 and out3 are different. out2 and out4 are different.
>>> out3, out4 = module.apply(variables, q, rngs={'dropout': key4})
>>> # out1 and out2 are the same.
>>> out1, out2 = module.apply(variables, q, dropout_rng=key5)
>>> # out1 and out2 are the same as out3 and out4.
>>> # providing a `dropout_rng` arg will take precedence over the `rngs` arg in `.apply`
>>> out3, out4 = module.apply(variables, q, rngs={'dropout': key6}, dropout_rng=key5)
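Since this class lives under flax.experimental.nnx rather than flax.linen, the analogous NNX-style usage looks roughly like the following minimal sketch (it assumes the in_features and rngs constructor arguments shown in the init_cache example below):
>>> from flax.experimental import nnx
>>> import jax

>>> layer = nnx.MultiHeadAttention(
...     num_heads=8, in_features=5, qkv_features=16, rngs=nnx.Rngs(0),
... )
>>> shape = (4, 3, 2, 5)
>>> k1, k2, k3 = jax.random.split(jax.random.key(0), 3)
>>> q, k, v = jax.random.uniform(k1, shape), jax.random.uniform(k2, shape), jax.random.uniform(k3, shape)
>>> # different inputs for inputs_q, inputs_k and inputs_v
>>> out = layer(q, k, v)
>>> # self-attention: inputs_k and inputs_v default to inputs_q
>>> out = layer(q)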
- num_heads#
number of attention heads. Features (i.e. inputs_q.shape[-1]) should be divisible by the number of heads.
- dtype#
the dtype of the computation (default: infer from inputs and params)
- param_dtype#
the dtype passed to parameter initializers (default: float32)
- qkv_features#
dimension of the key, query, and value.
- out_features#
dimension of the last projection
- broadcast_dropout#
bool: use a broadcasted dropout along batch dims.
- dropout_rate#
dropout rate
- deterministic#
if false, the attention weights are randomly masked using dropout; if true, the attention weights are deterministic.
- precision#
numerical precision of the computation; see jax.lax.Precision for details.
- kernel_init#
initializer for the kernel of the Dense layers.
- bias_init#
initializer for the bias of the Dense layers.
- use_bias#
bool: whether pointwise QKVO dense transforms use bias.
- attention_fn#
dot_product_attention or compatible function. Accepts query, key, value, and returns output of shape [bs, dim1, dim2, …, dimN, num_heads, value_channels]. See the sketch after this attribute list.
- decode#
whether to prepare and use an autoregressive cache.
- normalize_qk#
whether to apply QK normalization (arxiv.org/abs/2302.05442).
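As an illustration of the attention_fn hook referenced above, the following minimal sketch swaps in a hypothetical wrapper around dot_product_attention; it assumes the layer forwards its masking and dropout keyword arguments to attention_fn, so the wrapper accepts them via **kwargs:
>>> from flax.experimental import nnx
>>> import jax.numpy as jnp

>>> def scaled_attention(query, key, value, **kwargs):
...     # hypothetical wrapper: sharpen the attention distribution by scaling
...     # the query before delegating to the standard implementation
...     return nnx.dot_product_attention(query * 2.0, key, value, **kwargs)

>>> layer = nnx.MultiHeadAttention(
...     num_heads=2, in_features=4, qkv_features=8,
...     attention_fn=scaled_attention, rngs=nnx.Rngs(0),
... )
>>> x = jnp.ones((1, 3, 4))  # [batch, length, in_features]
>>> y = layer(x)             # output has the same shape as x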
- init_cache(input_shape, dtype=<class 'jax.numpy.float32'>)[source]#
Initializes cache for fast autoregressive decoding. When decode=True, this method must be called first before performing forward inference.
Example usage:
>>> from flax.experimental import nnx
>>> import jax.numpy as jnp

>>> rngs = nnx.Rngs(42)
>>> x = jnp.ones((1, 3))
>>> model_nnx = nnx.MultiHeadAttention(
...     num_heads=2,
...     in_features=3,
...     qkv_features=6,
...     out_features=6,
...     decode=True,
...     rngs=rngs,
... )
>>> # out_nnx = model_nnx(x)  <-- throws an error because the cache isn't initialized
>>> model_nnx.init_cache(x.shape)
>>> out_nnx = model_nnx(x)
- flax.experimental.nnx.combine_masks(*masks, dtype=<class 'jax.numpy.float32'>)[source]#
Combine attention masks.
- Parameters
*masks – set of attention mask arguments to combine, some can be None.
dtype – dtype for the returned mask.
- Returns
Combined mask, reduced by logical and; returns None if no masks are given.
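For illustration, a minimal sketch combining a hypothetical padding mask with a causal mask (the mask values here are made up):
>>> import jax.numpy as jnp
>>> from flax.experimental import nnx

>>> # hypothetical padding mask, [batch, 1, q_len, kv_len]; the last key position is padding
>>> padding_mask = jnp.array([[[[True, True, False],
...                             [True, True, False],
...                             [True, True, False]]]])
>>> causal_mask = nnx.make_causal_mask(jnp.ones((1, 3)))
>>> combined = nnx.combine_masks(padding_mask, causal_mask, None)  # None entries are ignored
>>> combined.shape
(1, 1, 3, 3)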
- flax.experimental.nnx.dot_product_attention(query, key, value, bias=None, mask=None, broadcast_dropout=True, dropout_rng=None, dropout_rate=0.0, deterministic=False, dtype=None, precision=None, module=None)[source]#
Computes dot-product attention given query, key, and value.
This is the core function for applying attention based on https://arxiv.org/abs/1706.03762. It calculates the attention weights given query and key and combines the values using the attention weights.
Note: query, key, value needn't have any batch dimensions.
- Parameters
query – queries for calculating attention with shape of [batch..., q_length, num_heads, qk_depth_per_head].
key – keys for calculating attention with shape of [batch..., kv_length, num_heads, qk_depth_per_head].
value – values to be used in attention with shape of [batch..., kv_length, num_heads, v_depth_per_head].
bias – bias for the attention weights. This should be broadcastable to the shape [batch…, num_heads, q_length, kv_length]. This can be used for incorporating causal masks, padding masks, proximity bias, etc.
mask – mask for the attention weights. This should be broadcastable to the shape [batch…, num_heads, q_length, kv_length]. This can be used for incorporating causal masks. Attention weights are masked out if their corresponding mask value is False.
broadcast_dropout – bool: use a broadcasted dropout along batch dims.
dropout_rng – JAX PRNGKey to be used for dropout
dropout_rate – dropout rate
deterministic – bool: if True, no dropout is applied; if False, dropout is applied using dropout_rng.
dtype – the dtype of the computation (default: infer from inputs)
precision – numerical precision of the computation; see jax.lax.Precision for details.
module – the Module that will sow the attention weights into the nnx.Intermediate collection. If module is None, the attention weights will not be sowed.
- Returns
Output of shape [batch…, q_length, num_heads, v_depth_per_head].
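For illustration, a minimal sketch calling this function directly on randomly generated inputs with a causal mask (shapes chosen arbitrarily):
>>> import jax
>>> import jax.numpy as jnp
>>> from flax.experimental import nnx

>>> k1, k2, k3 = jax.random.split(jax.random.key(0), 3)
>>> # [batch, length, num_heads, depth_per_head]
>>> query = jax.random.normal(k1, (2, 4, 8, 16))
>>> key = jax.random.normal(k2, (2, 4, 8, 16))
>>> value = jax.random.normal(k3, (2, 4, 8, 16))
>>> mask = nnx.make_causal_mask(jnp.ones((2, 4)))  # [batch, 1, q_length, kv_length]
>>> out = nnx.dot_product_attention(query, key, value, mask=mask)
>>> out.shape
(2, 4, 8, 16)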
- flax.experimental.nnx.make_attention_mask(query_input, key_input, pairwise_fn=<PjitFunction of <function jax.numpy.multiply>>, extra_batch_dims=0, dtype=<class 'jax.numpy.float32'>)[source]#
Mask-making helper for attention weights.
In the case of 1d inputs (i.e., [batch…, len_q], [batch…, len_kv]), the attention weights will be [batch…, heads, len_q, len_kv] and this function will produce [batch…, 1, len_q, len_kv].
- Parameters
query_input – a batched, flat input of query_length size
key_input – a batched, flat input of key_length size
pairwise_fn – broadcasting elementwise comparison function
extra_batch_dims – number of extra batch dims to add singleton axes for, none by default
dtype – mask return dtype
- Returns
A [batch…, 1, len_q, len_kv] shaped mask for 1d attention.
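For illustration, a minimal sketch building a padding mask from hypothetical token ids (0 is assumed to mark padding):
>>> import jax.numpy as jnp
>>> from flax.experimental import nnx

>>> # hypothetical token ids, where 0 is assumed to mark padding
>>> query_tokens = jnp.array([[5, 7, 0]])
>>> key_tokens = jnp.array([[3, 0, 0]])
>>> # attend only where both the query and the key positions are non-padding
>>> mask = nnx.make_attention_mask(query_tokens > 0, key_tokens > 0)
>>> mask.shape
(1, 1, 3, 3)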
- flax.experimental.nnx.make_causal_mask(x, extra_batch_dims=0, dtype=<class 'jax.numpy.float32'>)[source]#
Make a causal mask for self-attention.
In the case of 1d inputs (i.e., [batch…, len]), the self-attention weights will be [batch…, heads, len, len] and this function will produce a causal mask of shape [batch…, 1, len, len].
- Parameters
x – input array of shape [batch…, len]
extra_batch_dims – number of batch dims to add singleton axes for, none by default
dtype – mask return dtype
- Returns
A [batch…, 1, len, len] shaped causal mask for 1d attention.
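For illustration, a minimal sketch (shapes chosen arbitrarily):
>>> import jax.numpy as jnp
>>> from flax.experimental import nnx

>>> x = jnp.ones((1, 3))  # [batch, len]
>>> mask = nnx.make_causal_mask(x)
>>> mask.shape
(1, 1, 3, 3)
>>> mask[0, 0]  # each position may only attend to itself and earlier positions
Array([[1., 0., 0.],
       [1., 1., 0.],
       [1., 1., 1.]], dtype=float32)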