flax.linen.make_attention_mask

flax.linen.make_attention_mask(query_input, key_input, pairwise_fn=jnp.multiply, extra_batch_dims=0, dtype=jnp.float32)

Mask-making helper for attention weights.

For 1d inputs (i.e., [batch…, len_q] and [batch…, len_kv]), the attention weights have shape [batch…, heads, len_q, len_kv], and this function produces a mask of shape [batch…, 1, len_q, len_kv]. The singleton axis broadcasts across the heads dimension.
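The shape mapping above can be illustrated with a minimal sketch in plain jax.numpy, assuming the default pairwise_fn=jnp.multiply. The helper name sketch_attention_mask is hypothetical and it ignores extra_batch_dims; it shows the documented shapes rather than reproducing the Flax source.

```python
import jax.numpy as jnp

# Hypothetical helper: illustrates the documented shapes, not Flax's code.
def sketch_attention_mask(query_input, key_input,
                          pairwise_fn=jnp.multiply, dtype=jnp.float32):
    # [batch..., len_q, 1] op [batch..., 1, len_kv] -> [batch..., len_q, len_kv]
    mask = pairwise_fn(query_input[..., :, None], key_input[..., None, :])
    # Insert a singleton heads axis: [batch..., 1, len_q, len_kv]
    mask = mask[..., None, :, :]
    return mask.astype(dtype)

q = jnp.array([[1, 1, 0]])     # batch=1, len_q=3 (last query is padding)
k = jnp.array([[1, 1, 1, 0]])  # batch=1, len_kv=4 (last key is padding)
m = sketch_attention_mask(q, k)
print(m.shape)  # (1, 1, 3, 4)
```

Each entry m[b, 0, i, j] is the product q[b, i] * k[b, j], so a row or column is zeroed wherever either side is padding.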

Parameters
  • query_input (Any) – a batched, flat input of shape [batch…, len_q]

  • key_input (Any) – a batched, flat input of shape [batch…, len_kv]

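  • pairwise_fn (Callable[[...], Any]) – broadcasting elementwise comparison function applied to every (query, key) pair; jnp.multiply by default, which acts as a logical AND on boolean or 0/1 inputs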

  • extra_batch_dims (int) – number of extra leading singleton batch axes to add to the mask; 0 (none) by default

  • dtype (Any) – dtype of the returned mask; jnp.float32 by default

Returns

A mask of shape [batch…, 1, len_q, len_kv] for 1d attention, with extra_batch_dims additional leading singleton axes when requested.