flax.linen.SelfAttention
- class flax.linen.SelfAttention(num_heads, dtype=None, param_dtype=<class 'jax.numpy.float32'>, qkv_features=None, out_features=None, broadcast_dropout=True, dropout_rate=0.0, deterministic=None, precision=None, kernel_init=<function variance_scaling.<locals>.init>, bias_init=<function zeros>, use_bias=True, attention_fn=<function dot_product_attention>, decode=False, normalize_qk=False, qkv_dot_general=None, out_dot_general=None, qkv_dot_general_cls=None, out_dot_general_cls=None, parent=<flax.linen.module._Sentinel object>, name=None)
Self-attention special case of multi-head dot-product attention. This layer is deprecated in favor of MultiHeadDotProductAttention.

Example usage:

    >>> import flax.linen as nn
    >>> import jax, jax.numpy as jnp
    >>> layer = nn.MultiHeadDotProductAttention(num_heads=8, qkv_features=16)
    >>> variables = layer.init(jax.random.key(0), jnp.ones((4, 3, 2, 5)))
- __call__(inputs_q, mask=None, deterministic=None, dropout_rng=None, sow_weights=False)
Applies multi-head dot-product self-attention to the input data.
Projects the inputs into multi-headed query, key, and value vectors (all three derived from the same inputs_q array), applies dot-product attention, and projects the results to an output vector.
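The project, attend, and project pipeline described above can be sketched in plain jax.numpy. This is a simplified illustration, not Flax's implementation: the function name self_attention and its weight layout are hypothetical, and it omits the bias terms (use_bias), dropout, masking, and query/key normalization (normalize_qk) that the layer supports.

```python
import jax
import jax.numpy as jnp


def self_attention(x, wq, wk, wv, wo):
    """Hypothetical, simplified multi-head self-attention (no bias, dropout, or mask).

    x:          [batch..., length, features]
    wq, wk, wv: [features, num_heads, head_dim]
    wo:         [num_heads, head_dim, out_features]
    """
    # Project the same inputs into multi-headed queries, keys, and values:
    # each has shape [batch..., length, num_heads, head_dim].
    q = jnp.einsum('...lf,fhd->...lhd', x, wq)
    k = jnp.einsum('...lf,fhd->...lhd', x, wk)
    v = jnp.einsum('...lf,fhd->...lhd', x, wv)
    # Scale queries by 1/sqrt(head_dim), as in standard dot-product attention.
    q = q / jnp.sqrt(q.shape[-1])
    # Attention logits and weights: [batch..., num_heads, q_length, kv_length].
    logits = jnp.einsum('...qhd,...khd->...hqk', q, k)
    weights = jax.nn.softmax(logits, axis=-1)
    # Weighted sum of values: [batch..., q_length, num_heads, head_dim].
    attended = jnp.einsum('...hqk,...khd->...qhd', weights, v)
    # Merge the heads and project to the output features.
    return jnp.einsum('...lhd,hdo->...lo', attended, wo)


# Usage with random weights; the sizes mirror the example above.
features, num_heads, head_dim = 5, 8, 2
kx, kq, kk, kv, ko = jax.random.split(jax.random.key(0), 5)
x = jax.random.normal(kx, (4, 3, 2, features))
wq = jax.random.normal(kq, (features, num_heads, head_dim))
wk = jax.random.normal(kk, (features, num_heads, head_dim))
wv = jax.random.normal(kv, (features, num_heads, head_dim))
wo = jax.random.normal(ko, (num_heads, head_dim, features))
y = self_attention(x, wq, wk, wv, wo)
print(y.shape)  # (4, 3, 2, 5)
```

Using einsum with a leading ellipsis lets the same code handle any number of batch dimensions, matching the [batch_sizes..., length, features] convention used on this page.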
- Parameters
inputs_q – input queries of shape [batch_sizes..., length, features].
mask – attention mask of shape [batch_sizes..., num_heads, query_length, key/value_length]. Attention weights are masked out where the corresponding mask value is False.
deterministic – if False, the attention weights are randomly masked using dropout; if True, no dropout is applied and the attention weights are deterministic. This only has an effect when dropout_rate > 0.
- Returns
output of shape [batch_sizes..., length, features], where features is out_features if it was set and the input feature dimension otherwise.
Methods