flax.linen.make_attention_mask#
- flax.linen.make_attention_mask(query_input, key_input, pairwise_fn=<PjitFunction of <function jax.numpy.multiply>>, extra_batch_dims=0, dtype=<class 'jax.numpy.float32'>)[source]#
Mask-making helper for attention weights.
In the case of 1d inputs (i.e., [batch..., len_q], [batch..., len_kv]), the attention weights will be [batch..., heads, len_q, len_kv] and this function will produce [batch..., 1, len_q, len_kv].
- Parameters
- query_input – a batched, flat input of query_length size
- key_input – a batched, flat input of key_length size
- pairwise_fn – broadcasting elementwise comparison function
- extra_batch_dims – number of extra batch dims to add singleton axes for; none by default
- dtype – mask return dtype
- Returns
A [batch..., 1, len_q, len_kv] shaped mask for 1d attention.
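As a minimal usage sketch (the token array and the convention that ID 0 marks padding are assumptions for this example), the snippet below builds a padding mask from the same sequence for queries and keys. With the default pairwise_fn (jax.numpy.multiply), a query/key position pair is attendable only when both tokens are valid:

```python
import jax.numpy as jnp
from flax import linen as nn

# Hypothetical batch of token IDs where 0 marks padding (assumption for
# this example; any boolean validity array works the same way).
tokens = jnp.array([[5, 7, 2, 0, 0],
                    [3, 1, 0, 0, 0]])  # shape [batch=2, len=5]

valid = tokens > 0  # True at non-padding positions

# With the default pairwise_fn (jnp.multiply), the mask is the outer
# product of query validity and key validity, plus a singleton heads axis:
# mask[b, 0, q, k] = valid[b, q] * valid[b, k]
mask = nn.make_attention_mask(valid, valid)
print(mask.shape)  # (2, 1, 5, 5)
```

Passing a different pairwise_fn changes the comparison; for instance, flax.linen.make_causal_mask builds a causal mask by applying jnp.greater_equal to broadcast position indices.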