flax.linen.dot_product_attention¶
- flax.linen.dot_product_attention(query, key, value, bias=None, mask=None, broadcast_dropout=True, dropout_rng=None, dropout_rate=0.0, deterministic=False, dtype=None, precision=None)[source]¶
Computes dot-product attention given query, key, and value.
This is the core function for applying attention based on https://arxiv.org/abs/1706.03762. It calculates the attention weights from the query and key, then combines the values using those weights.
Note: query, key, value needn’t have any batch dimensions.
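Below is a minimal sketch of the computation this function performs, an illustrative reimplementation rather than the library code: it omits bias, mask, dropout, dtype, and precision handling, and assumes the input shapes documented in the parameter list.

```python
import jax
import jax.numpy as jnp

def dot_product_attention_sketch(query, key, value):
    # query: [batch..., q_length, num_heads, qk_depth_per_head]
    # key:   [batch..., kv_length, num_heads, qk_depth_per_head]
    # value: [batch..., kv_length, num_heads, v_depth_per_head]
    depth = query.shape[-1]
    # Attention logits, scaled by 1/sqrt(depth):
    # shape [batch..., num_heads, q_length, kv_length]
    logits = jnp.einsum('...qhd,...khd->...hqk', query, key) / jnp.sqrt(depth)
    weights = jax.nn.softmax(logits, axis=-1)
    # Weighted sum of values: [batch..., q_length, num_heads, v_depth_per_head]
    return jnp.einsum('...hqk,...khd->...qhd', weights, value)
```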
- Parameters
query (Any) – queries for calculating attention with shape of [batch…, q_length, num_heads, qk_depth_per_head].
key (Any) – keys for calculating attention with shape of [batch…, kv_length, num_heads, qk_depth_per_head].
value (Any) – values to be used in attention with shape of [batch…, kv_length, num_heads, v_depth_per_head].
bias (Optional[Any]) – bias for the attention weights. This should be broadcastable to the shape [batch…, num_heads, q_length, kv_length]. This can be used for incorporating causal masks, padding masks, proximity bias, etc.
mask (Optional[Any]) – mask for the attention weights. This should be broadcastable to the shape [batch…, num_heads, q_length, kv_length]. This can be used for incorporating causal masks. Attention weights are masked out if their corresponding mask value is False.
broadcast_dropout (bool) – use a broadcasted dropout along the batch dimensions.
dropout_rng (Optional[Any]) – JAX PRNGKey to be used for dropout.
dropout_rate (float) – dropout rate applied to the attention weights.
deterministic (bool) – if True, dropout is not applied (deterministic behavior).
dtype (Optional[Any]) – the dtype of the computation (default: infer from inputs)
precision (Union[None, str, jax._src.lax.lax.Precision, Tuple[str, str], Tuple[jax._src.lax.lax.Precision, jax._src.lax.lax.Precision]]) – numerical precision of the computation; see jax.lax.Precision for details.
- Returns
Output of shape [batch…, q_length, num_heads, v_depth_per_head].
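A usage sketch follows, with shapes and values chosen only for illustration, showing the expected input and output shapes.

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

batch, q_length, kv_length, num_heads, head_dim = 2, 4, 6, 8, 16
rngs = jax.random.split(jax.random.PRNGKey(0), 3)

query = jax.random.normal(rngs[0], (batch, q_length, num_heads, head_dim))
key = jax.random.normal(rngs[1], (batch, kv_length, num_heads, head_dim))
value = jax.random.normal(rngs[2], (batch, kv_length, num_heads, head_dim))

# A mask broadcastable to [batch..., num_heads, q_length, kv_length];
# True entries are kept, False entries are masked out.
mask = jnp.ones((batch, 1, q_length, kv_length), dtype=bool)

out = nn.dot_product_attention(query, key, value, mask=mask, deterministic=True)
print(out.shape)  # (2, 4, 8, 16), i.e. [batch, q_length, num_heads, v_depth_per_head]
```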