flax.linen.MultiHeadDotProductAttention¶
- class flax.linen.MultiHeadDotProductAttention(num_heads, dtype=None, param_dtype=<class 'jax.numpy.float32'>, qkv_features=None, out_features=None, broadcast_dropout=True, dropout_rate=0.0, deterministic=None, precision=None, kernel_init=<function variance_scaling.<locals>.init>, bias_init=<function zeros>, use_bias=True, attention_fn=<function dot_product_attention>, decode=False, parent=<flax.linen.module._Sentinel object>, name=None)[source]¶
Multi-head dot-product attention.
- Parameters
num_heads (int) –
dtype (Optional[Any]) –
param_dtype (Any) –
qkv_features (Optional[int]) –
out_features (Optional[int]) –
broadcast_dropout (bool) –
dropout_rate (float) –
deterministic (Optional[bool]) –
precision (Union[None, str, jax._src.lax.lax.Precision, Tuple[str, str], Tuple[jax._src.lax.lax.Precision, jax._src.lax.lax.Precision]]) –
kernel_init (Callable[[Any, Tuple[int], Any], Any]) –
bias_init (Callable[[Any, Tuple[int], Any], Any]) –
use_bias (bool) –
attention_fn (Callable[[Any, Any, Any], Any]) –
decode (bool) –
parent (Optional[Union[Type[flax.linen.module.Module], Type[flax.core.scope.Scope], Type[flax.linen.module._Sentinel]]]) –
name (str) –
- Return type
None
- num_heads¶
number of attention heads. Features (i.e. inputs_q.shape[-1]) should be divisible by the number of heads.
- Type
int
- dtype¶
the dtype of the computation (default: infer from inputs and params)
- Type
Optional[Any]
- param_dtype¶
the dtype passed to parameter initializers (default: float32)
- Type
Any
- qkv_features¶
dimension of the key, query, and value.
- Type
Optional[int]
- out_features¶
dimension of the last projection.
- Type
Optional[int]
- broadcast_dropout¶
whether to use a broadcasted dropout along batch dimensions.
- Type
bool
- dropout_rate¶
dropout rate
- Type
float
- deterministic¶
if false, the attention weights are masked randomly using dropout; if true, the attention weights are deterministic.
- Type
Optional[bool]
- precision¶
numerical precision of the computation; see jax.lax.Precision for details.
- Type
Union[None, str, jax._src.lax.lax.Precision, Tuple[str, str], Tuple[jax._src.lax.lax.Precision, jax._src.lax.lax.Precision]]
- kernel_init¶
initializer for the kernel of the Dense layers.
- Type
Callable[[Any, Tuple[int], Any], Any]
- bias_init¶
initializer for the bias of the Dense layers.
- Type
Callable[[Any, Tuple[int], Any], Any]
- use_bias¶
whether the pointwise QKVO dense transforms use a bias.
- Type
bool
- attention_fn¶
dot_product_attention or a compatible function. Accepts query, key, value, and returns output of shape [bs, dim1, dim2, ..., dimN, num_heads, value_channels].
- Type
Callable[[Any, Any, Any], Any]
- decode¶
whether to prepare and use an autoregressive cache.
- Type
bool
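A minimal construction and initialization sketch. The shapes and hyperparameters below (batch 4, query length 16, key/value length 32, 256 features, 8 heads, 0.1 dropout) are illustrative choices, not values from this reference:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

# Illustrative hyperparameters: 256 features split across 8 heads (32 per head).
attn = nn.MultiHeadDotProductAttention(
    num_heads=8,
    qkv_features=256,
    dropout_rate=0.1,
)

inputs_q = jnp.ones((4, 16, 256))   # [batch, query_length, features]
inputs_kv = jnp.ones((4, 32, 256))  # [batch, kv_length, features]

# Initialize with dropout disabled so no 'dropout' RNG stream is needed here.
variables = attn.init(jax.random.PRNGKey(0), inputs_q, inputs_kv, deterministic=True)
```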
- __call__(inputs_q, inputs_kv, mask=None, deterministic=None)[source]¶
Applies multi-head dot product attention on the input data.
Projects the inputs into multi-headed query, key, and value vectors, applies dot-product attention, and projects the results to an output vector.
- Parameters
inputs_q (Any) – input queries of shape [batch_sizes…, length, features].
inputs_kv (Any) – key/values of shape [batch_sizes…, length, features].
mask (Optional[Any]) – attention mask of shape [batch_sizes…, num_heads, query_length, key/value_length]. Attention weights are masked out if their corresponding mask value is False.
deterministic (Optional[bool]) – if false, the attention weights are masked randomly using dropout; if true, the attention weights are deterministic.
- Returns
output of shape [batch_sizes…, length, features].
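A sketch of applying the module, continuing the illustrative shapes and variables from the construction example above. nn.make_attention_mask is used here only as one convenient way to build a broadcastable mask:

```python
# Mask of shape [batch, 1, query_length, kv_length]; the singleton head axis
# broadcasts against num_heads. Here every key/value position is attendable.
mask = nn.make_attention_mask(
    jnp.ones((4, 16)),   # query positions (non-zero = valid)
    jnp.ones((4, 32)),   # key/value positions
)

# Training-style call: deterministic=False applies dropout, so a 'dropout'
# RNG stream must be provided.
out = attn.apply(
    variables, inputs_q, inputs_kv,
    mask=mask,
    deterministic=False,
    rngs={'dropout': jax.random.PRNGKey(1)},
)
assert out.shape == (4, 16, 256)  # [batch, query_length, features]
```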
Methods
- attention_fn(query, key, value[, bias, mask, ...]) – Computes dot-product attention given query, key, and value.
- bias_init(key, shape[, dtype]) – An initializer that returns a constant array full of zeros.
- kernel_init(key, shape[, dtype])
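The constructor fields above can be overridden. The sketch below swaps in non-default initializers, a higher matmul precision, and a wrapper around the default dot_product_attention as a stand-in for a custom attention_fn; the wrapper my_attention_fn is hypothetical, not part of the Flax API:

```python
import jax
import flax.linen as nn

def my_attention_fn(query, key, value, **kwargs):
    # Hypothetical drop-in replacement: the module forwards mask, dropout
    # settings, dtype, and precision as keyword arguments, so a replacement
    # should accept and pass them along (or handle them itself).
    return nn.dot_product_attention(query, key, value, **kwargs)

custom_attn = nn.MultiHeadDotProductAttention(
    num_heads=8,
    qkv_features=256,
    precision=jax.lax.Precision.HIGHEST,
    kernel_init=nn.initializers.xavier_uniform(),
    bias_init=nn.initializers.normal(stddev=1e-6),
    attention_fn=my_attention_fn,
)
```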