flax.linen.SelfAttention

class flax.linen.SelfAttention(num_heads, dtype=None, param_dtype=<class 'jax.numpy.float32'>, qkv_features=None, out_features=None, broadcast_dropout=True, dropout_rate=0.0, deterministic=None, precision=None, kernel_init=<function variance_scaling.<locals>.init>, bias_init=<function zeros>, use_bias=True, attention_fn=<function dot_product_attention>, decode=False, normalize_qk=False, qkv_dot_general=None, out_dot_general=None, qkv_dot_general_cls=None, out_dot_general_cls=None, parent=<flax.linen.module._Sentinel object>, name=None)

Self-attention special case of multi-head dot-product attention. This layer is deprecated in favor of MultiHeadDotProductAttention.

Example usage:
>>> import flax.linen as nn
>>> import jax, jax.numpy as jnp
>>> layer = nn.MultiHeadDotProductAttention(num_heads=8, qkv_features=16)
>>> variables = layer.init(jax.random.key(0), jnp.ones((4, 3, 2, 5)))
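Continuing the example, the initialized layer can be applied with layer.apply (a minimal sketch; the all-ones input is just a placeholder, and because out_features is left unset the output keeps the input feature size):
>>> out = layer.apply(variables, jnp.ones((4, 3, 2, 5)))
>>> out.shape
(4, 3, 2, 5)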
__call__(inputs_q, mask=None, deterministic=None, dropout_rng=None, sow_weights=False)

Applies multi-head dot product self-attention on the input data.

Projects the inputs into multi-headed query, key, and value vectors, applies dot-product attention, and projects the results to an output vector.
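For intuition, here is a minimal sketch of the scaled dot-product attention step for a single head on a single example (illustrative only; the layer itself uses learned projection kernels, multiple heads, and batched inputs):

>>> import jax, jax.numpy as jnp
>>> q = jnp.ones((3, 2)); k = jnp.ones((3, 2)); v = jnp.ones((3, 2))  # [length, head_dim]
>>> scores = q @ k.T / jnp.sqrt(q.shape[-1])   # [query_length, key_length]
>>> weights = jax.nn.softmax(scores, axis=-1)  # attention weights
>>> out = weights @ v                          # [query_length, head_dim]
>>> out.shape
(3, 2)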

Parameters
  • inputs_q – input queries of shape [batch_sizes..., length, features].

  • mask – attention mask of shape [batch_sizes..., num_heads, query_length, key/value_length]. Attention weights are masked out if their corresponding mask value is False (see the mask-construction sketch after this list).

  • deterministic – if false, the attention weights are masked randomly using dropout; if true, the attention weights are deterministic.

  • dropout_rng – optional PRNG key to use for the attention dropout; if not provided, self.make_rng('dropout') is used.

  • sow_weights – if True, the attention weights are sowed into the 'intermediates' collection.
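A minimal sketch of building a mask in the expected shape, here using nn.make_causal_mask (one of the mask helpers in flax.linen; the batch of ones stands in for real token positions):

>>> import flax.linen as nn
>>> import jax.numpy as jnp
>>> causal_mask = nn.make_causal_mask(jnp.ones((4, 3)))  # [batch, 1, query_length, key_length]
>>> causal_mask.shape
(4, 1, 3, 3)

The singleton heads dimension broadcasts across all num_heads when the mask is passed to __call__.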

Returns

output of shape [batch_sizes..., length, features].

Methods

  __call__(inputs_q[, mask, deterministic, ...]) – Applies multi-head dot product self-attention on the input data.