flax.linen.dot_product_attention_weights

flax.linen.dot_product_attention_weights(query, key, bias=None, mask=None, broadcast_dropout=True, dropout_rng=None, dropout_rate=0.0, deterministic=False, dtype=None, precision=None, module=None)

Computes dot-product attention weights given query and key.

This function is used by dot_product_attention(), which is what you will most likely use. But if you want access to the attention weights for introspection, you can call this function directly and combine the weights with the values via einsum yourself, as in the sketch below.
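A minimal sketch of that workflow (the einsum contraction mirrors what dot_product_attention() performs internally; the toy dimensions are illustrative):

    import jax
    import jax.numpy as jnp
    from flax.linen import dot_product_attention_weights

    # Toy shapes: [batch, length, num_heads, head_dim].
    batch, q_len, kv_len, num_heads, head_dim = 2, 4, 6, 8, 16
    k_q, k_k, k_v = jax.random.split(jax.random.PRNGKey(0), 3)
    query = jax.random.normal(k_q, (batch, q_len, num_heads, head_dim))
    key = jax.random.normal(k_k, (batch, kv_len, num_heads, head_dim))
    value = jax.random.normal(k_v, (batch, kv_len, num_heads, head_dim))

    # Weights have shape [batch..., num_heads, q_length, kv_length].
    weights = dot_product_attention_weights(query, key, deterministic=True)

    # Combine the weights with the values yourself via einsum.
    out = jnp.einsum('...hqk,...khd->...qhd', weights, value)
    print(weights.shape)  # (2, 8, 4, 6)
    print(out.shape)      # (2, 4, 8, 16)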

Parameters
  • query – queries for calculating attention with shape of [batch..., q_length, num_heads, qk_depth_per_head].

  • key – keys for calculating attention with shape of [batch..., kv_length, num_heads, qk_depth_per_head].

  • bias – bias for the attention weights. This should be broadcastable to the shape [batch..., num_heads, q_length, kv_length]. This can be used for incorporating causal masks, padding masks, proximity bias, etc.

  • mask – mask for the attention weights. This should be broadcastable to the shape [batch..., num_heads, q_length, kv_length]. This can be used for incorporating causal masks (a causal-mask sketch follows this parameter list). Attention weights are masked out if their corresponding mask value is False.

  • broadcast_dropout – bool: use broadcast dropout along the batch dimensions.

  • dropout_rng – JAX PRNGKey to be used for dropout.

  • dropout_rate – dropout rate.

  • deterministic – bool: if True, dropout is not applied.

  • dtype – the dtype of the computation (default: infer from inputs and params)

  • precision – numerical precision of the computation; see jax.lax.Precision for details.

  • module – the Module that will sow the attention weights into the ‘intermediates’ collection. Remember to mark ‘intermediates’ as mutable via mutable=['intermediates'] in order to have that collection returned (see the sketch at the end of this page). If module is None, the attention weights will not be sown.
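As a concrete illustration of the mask parameter, here is a minimal sketch that builds a lower-triangular (causal) self-attention mask by hand and lets it broadcast against the leading batch and head dimensions; flax.linen also provides make_causal_mask for this, and the toy shapes are illustrative:

    import jax
    import jax.numpy as jnp
    from flax.linen import dot_product_attention_weights

    batch, length, num_heads, head_dim = 2, 5, 4, 8
    k_q, k_k = jax.random.split(jax.random.PRNGKey(0))
    query = jax.random.normal(k_q, (batch, length, num_heads, head_dim))
    key = jax.random.normal(k_k, (batch, length, num_heads, head_dim))

    # Boolean lower-triangular mask of shape [q_length, kv_length]; it
    # broadcasts against the leading [batch..., num_heads] dimensions.
    causal_mask = jnp.tril(jnp.ones((length, length), dtype=bool))

    weights = dot_product_attention_weights(
        query, key, mask=causal_mask, deterministic=True)
    print(weights.shape)  # (2, 4, 5, 5)
    # Entries above the diagonal are masked out (approximately zero after softmax).
    print(weights[0, 0])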

Returns

Output of shape [batch..., num_heads, q_length, kv_length].
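To retrieve the sown weights via the module argument, one possible sketch is a small wrapper Module that passes module=self and is applied with ‘intermediates’ marked as mutable. The class name ProbedAttention and the toy shapes are illustrative, not part of the Flax API, and the exact entry name inside the ‘intermediates’ collection may depend on the Flax version, so the sketch simply prints the collection’s structure:

    import jax
    import jax.numpy as jnp
    import flax.linen as nn

    class ProbedAttention(nn.Module):  # hypothetical wrapper, for illustration only
        @nn.compact
        def __call__(self, query, key, value):
            # Passing module=self sows the attention weights into this
            # module's 'intermediates' collection.
            weights = nn.dot_product_attention_weights(
                query, key, deterministic=True, module=self)
            return jnp.einsum('...hqk,...khd->...qhd', weights, value)

    x = jax.random.normal(jax.random.PRNGKey(0), (2, 4, 8, 16))
    model = ProbedAttention()
    variables = model.init(jax.random.PRNGKey(1), x, x, x)
    # Mark 'intermediates' as mutable so the collection is returned.
    out, state = model.apply(variables, x, x, x, mutable=['intermediates'])
    # Inspect what was sown (entry names depend on the Flax version).
    print(jax.tree_util.tree_map(jnp.shape, state['intermediates']))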