class flax.linen.SelfAttention(num_heads, dtype=None, param_dtype=<class 'jax.numpy.float32'>, qkv_features=None, out_features=None, broadcast_dropout=True, dropout_rate=0.0, deterministic=None, precision=None, kernel_init=<function variance_scaling.<locals>.init>, bias_init=<function zeros>, use_bias=True, attention_fn=<function dot_product_attention>, decode=False, parent=<flax.linen.module._Sentinel object>, name=None)[source]

Self-attention special case of multi-head dot-product attention.

  • num_heads (int) – number of attention heads. Features (i.e. inputs_q.shape[-1]) should be divisible by the number of heads.

  • dtype (Optional[Any]) – the dtype of the computation (default: infer from inputs and params).

  • param_dtype (Any) – the dtype passed to parameter initializers (default: float32).

  • qkv_features (Optional[int]) – dimension of the key, query, and value.

  • out_features (Optional[int]) – dimension of the last projection.

  • broadcast_dropout (bool) – use a broadcasted dropout along batch dims.

  • dropout_rate (float) – dropout rate.

  • deterministic (Optional[bool]) – if false, the attention weight is masked randomly using dropout, whereas if true, the attention weights are deterministic.

  • precision (Union[None, str, jax._src.lax.lax.Precision, Tuple[str, str], Tuple[jax._src.lax.lax.Precision, jax._src.lax.lax.Precision]]) – numerical precision of the computation; see jax.lax.Precision for details.

  • kernel_init (Callable[[Any, Tuple[int], Any], Any]) – initializer for the kernel of the Dense layers.

  • bias_init (Callable[[Any, Tuple[int], Any], Any]) – initializer for the bias of the Dense layers.

  • use_bias (bool) – whether pointwise QKVO dense transforms use bias.

  • attention_fn (Callable[[Any, Any, Any], Any]) – dot_product_attention or a compatible function. Accepts query, key, value, and returns output of shape [bs, dim1, dim2, ..., dimN, num_heads, value_channels].

  • decode (bool) – whether to prepare and use an autoregressive cache.

  • parent (Optional[Union[Type[flax.linen.module.Module], Type[flax.core.scope.Scope], Type[flax.linen.module._Sentinel]]]) –

  • name (str) –

Return type

None
__call__(inputs_q, mask=None, deterministic=None)[source]

Applies multi-head dot product self-attention on the input data.

Projects the inputs into multi-headed query, key, and value vectors, applies dot-product attention, and projects the results to an output vector.

  • inputs_q (Any) – input queries of shape [batch_sizes…, length, features].

  • mask (Optional[Any]) – attention mask of shape [batch_sizes…, num_heads, query_length, key/value_length]. Attention weights are masked out if their corresponding mask value is False.

  • deterministic (Optional[bool]) – if false, the attention weight is masked randomly using dropout, whereas if true, the attention weights are deterministic.


Returns

output of shape [batch_sizes…, length, features].