flax.linen.make_causal_mask

flax.linen.make_causal_mask(x, extra_batch_dims=0, dtype=<class 'jax.numpy.float32'>)

Make a causal mask for self-attention.

For 1d inputs (i.e., of shape [batch…, len]), the self-attention weights have shape [batch…, heads, len, len], and this function produces a causal mask of shape [batch…, 1, len, len].
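As a rough illustration of these shapes, the sketch below builds an equivalent lower-triangular mask with plain jax.numpy. It is not the library's implementation, and the helper name causal_mask_sketch is hypothetical.

```python
import jax.numpy as jnp


def causal_mask_sketch(x, dtype=jnp.float32):
    # Hypothetical helper for illustration only; not part of Flax.
    length = x.shape[-1]
    # tri[i, j] is 1 where key position j <= query position i, else 0.
    tri = jnp.tril(jnp.ones((length, length), dtype=dtype))
    # Broadcast over the batch dims and add a singleton heads axis.
    return jnp.broadcast_to(tri, x.shape[:-1] + (1, length, length))


x = jnp.zeros((2, 4))          # [batch, len]
mask = causal_mask_sketch(x)   # [batch, 1, len, len]
```

The singleton heads axis lets the mask broadcast against attention weights of shape [batch…, heads, len, len].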

Parameters
  • x – input array of shape [batch…, len]

  • extra_batch_dims – number of extra leading singleton batch axes to add to the mask (default: 0)

  • dtype – dtype of the returned mask (default: float32)

Returns

A [batch…, 1, len, len] shaped causal mask for 1d attention.