flax.linen.MultiHeadDotProductAttention

class flax.linen.MultiHeadDotProductAttention(num_heads, dtype=None, param_dtype=jax.numpy.float32, qkv_features=None, out_features=None, broadcast_dropout=True, dropout_rate=0.0, deterministic=None, precision=None, kernel_init=lecun_normal(), bias_init=zeros, use_bias=True, attention_fn=dot_product_attention, decode=False, parent=<flax.linen.module._Sentinel object>, name=None)

Multi-head dot-product attention.

Parameters
  • num_heads (int)

  • dtype (Optional[Any])

  • param_dtype (Any)

  • qkv_features (Optional[int])

  • out_features (Optional[int])

  • broadcast_dropout (bool)

  • dropout_rate (float)

  • deterministic (Optional[bool])

  • precision (Union[None, str, jax.lax.Precision, Tuple[str, str], Tuple[jax.lax.Precision, jax.lax.Precision]])

  • kernel_init (Callable[[Any, Tuple[int], Any], Any])

  • bias_init (Callable[[Any, Tuple[int], Any], Any])

  • use_bias (bool)

  • attention_fn (Callable[[Any, Any, Any], Any])

  • decode (bool)

  • parent (Optional[Union[Type[flax.linen.module.Module], Type[flax.core.scope.Scope], Type[flax.linen.module._Sentinel]]])

  • name (str)

Return type

None

num_heads

number of attention heads. The projected feature dimension, qkv_features (default: inputs_q.shape[-1]), must be divisible by the number of heads.

Type

int
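The divisibility requirement comes from splitting the projected features evenly across heads, so each head receives `qkv_features // num_heads` channels. Flax projects directly into a `(num_heads, head_dim)` shape, but conceptually this matches the reshape in the NumPy sketch below; `split_heads` is a hypothetical helper written for illustration, not part of the Flax API.

```python
import numpy as np

def split_heads(x: np.ndarray, num_heads: int) -> np.ndarray:
    """Reshape [..., length, features] into [..., length, num_heads, head_dim]."""
    features = x.shape[-1]
    if features % num_heads != 0:
        raise ValueError(
            f"features ({features}) must be divisible by num_heads ({num_heads})")
    head_dim = features // num_heads
    return x.reshape(x.shape[:-1] + (num_heads, head_dim))

x = np.zeros((2, 5, 12))          # batch=2, length=5, features=12
print(split_heads(x, 4).shape)    # (2, 5, 4, 3)
# split_heads(x, 5) raises ValueError because 12 is not divisible by 5.
```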

dtype

the dtype of the computation (default: inferred from inputs and params)

Type

Optional[Any]

param_dtype

the dtype passed to parameter initializers (default: float32)

Type

Any

qkv_features

total dimension of the query, key, and value projections, split evenly across heads (default: inputs_q.shape[-1]).

Type

Optional[int]

out_features

dimension of the final output projection (default: inputs_q.shape[-1]).

Type

Optional[int]

broadcast_dropout

whether to use broadcasted dropout, i.e. share one dropout mask across the batch and head dimensions.

Type

bool
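A minimal NumPy sketch of the difference, assuming attention weights of shape `[batch, num_heads, q_length, kv_length]`: with broadcasting, one keep-mask over the query/key axes is shared by every batch element and head; without it, each entry gets an independent draw. `dropout_mask` is a hypothetical helper, and NumPy's random generator stands in for JAX's PRNG.

```python
import numpy as np

def dropout_mask(rng, weights_shape, keep_prob, broadcast):
    """Sample a boolean keep-mask for attention weights.

    weights_shape is assumed to be [batch, num_heads, q_length, kv_length].
    With broadcast=True the mask has shape [1, 1, q_length, kv_length], so the
    same pattern is reused for every batch element and head.
    """
    if broadcast:
        shape = (1,) * (len(weights_shape) - 2) + tuple(weights_shape[-2:])
    else:
        shape = tuple(weights_shape)
    return rng.random(shape) < keep_prob

rng = np.random.default_rng(0)
weights_shape = (2, 4, 5, 5)
shared = dropout_mask(rng, weights_shape, 0.9, broadcast=True)
independent = dropout_mask(rng, weights_shape, 0.9, broadcast=False)
print(shared.shape)       # (1, 1, 5, 5)
print(independent.shape)  # (2, 4, 5, 5)
```

Broadcasting draws far fewer random numbers, at the cost of correlated dropout patterns across the batch.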

dropout_rate

dropout rate applied to the attention weights

Type

float

deterministic

if False, dropout is applied to the attention weights; if True, no dropout is applied and the attention weights are deterministic.

Type

Optional[bool]
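The effect of `deterministic` can be sketched as a switch around inverted dropout: when it is true the weights pass through untouched; otherwise surviving entries are rescaled by `1 / (1 - dropout_rate)` so their expected value is unchanged. This is an illustrative NumPy sketch with a hypothetical `maybe_dropout` helper; in Flax the randomness comes from the module's `'dropout'` RNG stream rather than a NumPy generator.

```python
import numpy as np

def maybe_dropout(weights, dropout_rate, deterministic, rng=None):
    """Inverted dropout on attention weights, skipped when deterministic."""
    if deterministic or dropout_rate == 0.0:
        return weights  # no randomness: weights pass through unchanged
    keep_prob = 1.0 - dropout_rate
    keep = rng.random(weights.shape) < keep_prob
    # Rescale survivors by 1 / keep_prob so the expected value is unchanged.
    return np.where(keep, weights / keep_prob, 0.0)

w = np.full((1, 1, 3, 3), 1.0 / 3.0)   # uniform weights over 3 keys
same = maybe_dropout(w, 0.5, deterministic=True)
noisy = maybe_dropout(w, 0.5, deterministic=False, rng=np.random.default_rng(0))
# Every entry of `noisy` is either 0.0 or (1/3) / 0.5.
```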

precision

numerical precision of the computation; see jax.lax.Precision for details.

Type

Union[None, str, jax.lax.Precision, Tuple[str, str], Tuple[jax.lax.Precision, jax.lax.Precision]]

kernel_init

initializer for the kernel of the Dense layers.

Type

Callable[[Any, Tuple[int], Any], Any]

bias_init

initializer for the bias of the Dense layers.

Type

Callable[[Any, Tuple[int], Any], Any]

use_bias

whether the pointwise query, key, value, and output (QKVO) dense projections use a bias.

Type

bool

attention_fn

dot_product_attention or a compatible function. Accepts query, key, and value, and returns output of shape [bs, dim1, dim2, …, dimN, num_heads, value_channels].

Type

Callable[[Any, Any, Any], Any]
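A sketch of a compatible function, assuming per-head inputs laid out as `[batch, length, num_heads, depth]` and the usual scaling of logits by `1 / sqrt(depth)`. The name `simple_dot_product_attention` is hypothetical. The module may also pass keyword arguments such as dropout settings, dtype, and precision; accepting and ignoring them via `**unused_kwargs` is an assumption of this sketch, which omits dropout entirely.

```python
import numpy as np

def simple_dot_product_attention(query, key, value, mask=None, **unused_kwargs):
    """Scaled dot-product attention on per-head arrays.

    query: [batch, q_length, num_heads, depth]
    key:   [batch, kv_length, num_heads, depth]
    value: [batch, kv_length, num_heads, v_depth]
    mask:  optional bool array broadcastable to [batch, num_heads, q_length, kv_length]
    returns: [batch, q_length, num_heads, v_depth]
    """
    depth = query.shape[-1]
    # Attention logits per head: [batch, num_heads, q_length, kv_length].
    logits = np.einsum("bqhd,bkhd->bhqk", query, key) / np.sqrt(depth)
    if mask is not None:
        # Masked-out positions get the most negative representable value.
        logits = np.where(mask, logits, np.finfo(logits.dtype).min)
    # Softmax over the key axis (subtract the max for numerical stability).
    weights = np.exp(logits - logits.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    # Weighted sum of values, back to [batch, q_length, num_heads, v_depth].
    return np.einsum("bhqk,bkhd->bqhd", weights, value)

rng = np.random.default_rng(0)
q = rng.normal(size=(2, 5, 4, 8))
k = rng.normal(size=(2, 7, 4, 8))
v = rng.normal(size=(2, 7, 4, 16))
print(simple_dot_product_attention(q, k, v).shape)  # (2, 5, 4, 16)
```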

decode

whether to prepare and use an autoregressive cache.

Type

bool
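The idea behind the autoregressive cache is that keys and values from earlier steps are kept in a fixed-size buffer, each new step writes one position, and attention is masked to the positions filled so far. In Flax the buffers live in the `'cache'` variable collection; the NumPy class below is a hypothetical stand-in that only mirrors this bookkeeping.

```python
import numpy as np

class KVCache:
    """Hypothetical fixed-size key/value cache for one-token-at-a-time decoding."""

    def __init__(self, batch, max_length, num_heads, depth):
        self.keys = np.zeros((batch, max_length, num_heads, depth))
        self.values = np.zeros((batch, max_length, num_heads, depth))
        self.index = 0  # next position to write

    def update(self, key_step, value_step):
        """Write one step of shape [batch, 1, num_heads, depth]; return (keys, values, mask)."""
        max_length = self.keys.shape[1]
        if self.index >= max_length:
            raise IndexError("cache is full")
        self.keys[:, self.index] = key_step[:, 0]
        self.values[:, self.index] = value_step[:, 0]
        self.index += 1
        # Only positions written so far are valid; shape [1, 1, 1, max_length]
        # broadcasts over batch, heads, and the single query position.
        mask = (np.arange(max_length) < self.index)[None, None, None, :]
        return self.keys, self.values, mask

cache = KVCache(batch=1, max_length=4, num_heads=2, depth=3)
for step in range(3):
    kv_step = np.full((1, 1, 2, 3), float(step))  # stand-in for projected k/v
    keys, values, mask = cache.update(kv_step, kv_step)
print(mask[0, 0, 0].tolist())  # [True, True, True, False]
```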

__call__(inputs_q, inputs_kv, mask=None, deterministic=None)

Applies multi-head dot product attention on the input data.

Projects the inputs into multi-headed query, key, and value vectors, applies dot-product attention, and projects the results to an output vector.
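The three stages can be sketched end to end in NumPy. The parameter names (`wq`, `bq`, ..., `wo`, `bo`), the helper functions, and the initialization scheme are illustrative assumptions, and the sketch leaves out masking, dropout, and dtype handling; it is meant to show how `qkv_features`, `num_heads`, and `out_features` determine the shapes, not to reproduce Flax's implementation.

```python
import numpy as np

def mha_forward(inputs_q, inputs_kv, params):
    """Project -> attend -> project, for inputs of shape [batch, length, features]."""
    def project(x, w, b):
        # [batch, length, features] x [features, heads, head_dim] -> [batch, length, heads, head_dim]
        return np.einsum("blf,fhd->blhd", x, w) + b

    q = project(inputs_q, params["wq"], params["bq"])
    k = project(inputs_kv, params["wk"], params["bk"])
    v = project(inputs_kv, params["wv"], params["bv"])

    # Scaled dot-product attention per head, softmax over the key axis.
    logits = np.einsum("bqhd,bkhd->bhqk", q, k) / np.sqrt(q.shape[-1])
    weights = np.exp(logits - logits.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    attended = np.einsum("bhqk,bkhd->bqhd", weights, v)

    # Output projection merges the heads: [heads, head_dim] -> out_features.
    return np.einsum("bqhd,hdo->bqo", attended, params["wo"]) + params["bo"]

def init_params(rng, features, qkv_features, out_features, num_heads):
    """Hypothetical random initialization, scaled by 1/sqrt(fan_in)."""
    head_dim = qkv_features // num_heads
    p = {}
    for name in ("q", "k", "v"):
        p["w" + name] = rng.normal(size=(features, num_heads, head_dim)) / np.sqrt(features)
        p["b" + name] = np.zeros((num_heads, head_dim))
    p["wo"] = rng.normal(size=(num_heads, head_dim, out_features)) / np.sqrt(qkv_features)
    p["bo"] = np.zeros((out_features,))
    return p

rng = np.random.default_rng(0)
params = init_params(rng, features=16, qkv_features=32, out_features=8, num_heads=4)
x_q = rng.normal(size=(2, 5, 16))    # 5 query positions
x_kv = rng.normal(size=(2, 7, 16))   # 7 key/value positions
print(mha_forward(x_q, x_kv, params).shape)  # (2, 5, 8)
```

The output keeps the query length (5) and takes its last dimension from `out_features` (8), independent of the key/value length.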

Parameters
  • inputs_q (Any) – input queries of shape [batch_sizes…, length, features].

  • inputs_kv (Any) – keys/values of shape [batch_sizes…, length, features]; length may differ from the query length.

  • mask (Optional[Any]) – attention mask of shape [batch_sizes…, num_heads, query_length, key/value_length]. Attention weights are masked out if their corresponding mask value is False.

  • deterministic (Optional[bool]) – if False, dropout is applied to the attention weights; if True, no dropout is applied and the attention weights are deterministic.
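Masks are boolean, with True meaning "may attend". In practice a size-1 head axis is typically used so that one mask broadcasts across all heads, giving shape `[batch, 1, query_length, key/value_length]`. The NumPy helpers below (`causal_mask`, `padding_mask`) are hypothetical; Flax itself provides `flax.linen.make_causal_mask` and `flax.linen.make_attention_mask` for the same purpose.

```python
import numpy as np

def causal_mask(batch, length):
    """Boolean mask of shape [batch, 1, length, length]; query i may attend keys 0..i."""
    allowed = np.tril(np.ones((length, length), dtype=bool))
    return np.broadcast_to(allowed, (batch, 1, length, length))

def padding_mask(query_valid, key_valid):
    """Combine per-position validity flags [batch, len] into [batch, 1, q_len, kv_len]."""
    return query_valid[:, None, :, None] & key_valid[:, None, None, :]

m = causal_mask(batch=1, length=3)
print(m[0, 0].astype(int).tolist())         # [[1, 0, 0], [1, 1, 0], [1, 1, 1]]

valid = np.array([[True, True, False]])     # last token is padding
combined = m & padding_mask(valid, valid)
print(combined[0, 0].astype(int).tolist())  # [[1, 0, 0], [1, 1, 0], [0, 0, 0]]
```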

Returns

output of shape [batch_sizes…, length, out_features], where length is the query length and out_features defaults to inputs_q.shape[-1].

Methods

attention_fn(query, key, value[, bias, mask, ...])

Computes dot-product attention given query, key, and value.

bias_init(key, shape[, dtype])

An initializer that returns a constant array full of zeros.

kernel_init(key, shape[, dtype])