flax.linen.make_causal_mask

flax.linen.make_causal_mask(x, extra_batch_dims=0, dtype=jax.numpy.float32)

Make a causal mask for self-attention.

For 1d inputs (i.e., of shape [batch..., len]), the self-attention weights have shape [batch..., heads, len, len], and this function produces a causal mask of shape [batch..., 1, len, len]. The singleton axis lets the mask broadcast across attention heads.

Parameters
  • x – input array of shape [batch..., len]

  • extra_batch_dims – number of extra batch dimensions to add singleton axes for (default: 0, i.e., none)

  • dtype – dtype of the returned mask

Returns

A [batch..., 1, len, len] shaped causal mask for 1d attention.