Layers#

Linear Modules#

class flax.linen.Dense(features, use_bias=True, dtype=None, param_dtype=<class 'jax.numpy.float32'>, precision=None, kernel_init=<function variance_scaling.<locals>.init>, bias_init=<function zeros>, dot_general=<function dot_general>, dot_general_cls=None, parent=<flax.linen.module._Sentinel object>, name=None)[source]#

A linear transformation applied over the last dimension of the input.

features#

the number of output features.

Type

int

use_bias#

whether to add a bias to the output (default: True).

Type

bool

dtype#

the dtype of the computation (default: infer from input and params).

Type

Optional[Any]

param_dtype#

the dtype passed to parameter initializers (default: float32).

Type

Any

precision#

numerical precision of the computation; see jax.lax.Precision for details.

Type

Union[None, str, jax._src.lax.lax.Precision, Tuple[str, str], Tuple[jax._src.lax.lax.Precision, jax._src.lax.lax.Precision]]

kernel_init#

initializer function for the weight matrix.

Type

Callable[[Any, Tuple[int, …], Any], Any]

bias_init#

initializer function for the bias.

Type

Callable[[Any, Tuple[int, …], Any], Any]

__call__(inputs)[source]#

Applies a linear transformation to the inputs along the last dimension.

Parameters

inputs – The nd-array to be transformed.

Returns

The transformed input.

Methods
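
For orientation, a minimal usage sketch; the shapes and RNG seeds below are arbitrary illustrative choices:

import jax, jax.numpy as jnp
import flax.linen as nn

layer = nn.Dense(features=4)
x = jnp.ones((1, 8))                          # (batch, input features)
variables = layer.init(jax.random.key(0), x)  # {'params': {'kernel': (8, 4), 'bias': (4,)}}
y = layer.apply(variables, x)                 # shape (1, 4)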

class flax.linen.DenseGeneral(features, axis=-1, batch_dims=(), use_bias=True, dtype=None, param_dtype=<class 'jax.numpy.float32'>, kernel_init=<function variance_scaling.<locals>.init>, bias_init=<function zeros>, precision=None, dot_general=<function dot_general>, dot_general_cls=None, parent=<flax.linen.module._Sentinel object>, name=None)[source]#

A linear transformation with flexible axes.

features#

int or tuple with number of output features.

Type

Union[int, Sequence[int]]

axis#

int or tuple with axes to apply the transformation on. For instance, (-2, -1) will apply the transformation to the last two axes.

Type

Union[int, Sequence[int]]

batch_dims#

tuple with batch axes.

Type

Sequence[int]

use_bias#

whether to add a bias to the output (default: True).

Type

bool

dtype#

the dtype of the computation (default: infer from input and params).

Type

Optional[Any]

param_dtype#

the dtype passed to parameter initializers (default: float32).

Type

Any

kernel_init#

initializer function for the weight matrix.

Type

Callable[[Any, Tuple[int, …], Any], Any]

bias_init#

initializer function for the bias.

Type

Callable[[Any, Tuple[int, …], Any], Any]

precision#

numerical precision of the computation; see jax.lax.Precision for details.

Type

Union[None, str, jax._src.lax.lax.Precision, Tuple[str, str], Tuple[jax._src.lax.lax.Precision, jax._src.lax.lax.Precision]]

__call__(inputs)[source]#

Applies a linear transformation to the inputs along multiple dimensions.

Parameters

inputs – The nd-array to be transformed.

Returns

The transformed input.

Methods
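
A minimal sketch contracting the last two input axes into a pair of output feature axes (the shapes are arbitrary illustrative choices):

import jax, jax.numpy as jnp
import flax.linen as nn

layer = nn.DenseGeneral(features=(3, 4), axis=(-2, -1))
x = jnp.ones((16, 5, 6))                      # the last two axes are contracted
variables = layer.init(jax.random.key(0), x)  # kernel has shape (5, 6, 3, 4)
y = layer.apply(variables, x)                 # shape (16, 3, 4)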

class flax.linen.Conv(features, kernel_size, strides=1, padding='SAME', input_dilation=1, kernel_dilation=1, feature_group_count=1, use_bias=True, mask=None, dtype=None, param_dtype=<class 'jax.numpy.float32'>, precision=None, kernel_init=<function variance_scaling.<locals>.init>, bias_init=<function zeros>, conv_general_dilated=<function conv_general_dilated>, conv_general_dilated_cls=None, parent=<flax.linen.module._Sentinel object>, name=None)[source]#

Convolution Module wrapping lax.conv_general_dilated.

features#

number of convolution filters.

Type

int

kernel_size#

shape of the convolutional kernel. For 1D convolution, the kernel size can be passed as an integer. For all other cases, it must be a sequence of integers.

Type

Sequence[int]

strides#

an integer or a sequence of n integers, representing the inter-window strides (default: 1).

Type

Union[None, int, Sequence[int]]

padding#

either the string ‘SAME’, the string ‘VALID’, the string ‘CIRCULAR’ (periodic boundary conditions), or a sequence of n (low, high) integer pairs that give the padding to apply before and after each spatial dimension. A single int is interpreted as applying the same padding in all dims, and assigning a single int in a sequence causes the same padding to be used on both sides. ‘CAUSAL’ padding for a 1D convolution will left-pad the convolution axis, resulting in a same-sized output.

Type

Union[str, int, Sequence[Union[int, Tuple[int, int]]]]

input_dilation#

an integer or a sequence of n integers, giving the dilation factor to apply in each spatial dimension of inputs (default: 1). Convolution with input dilation d is equivalent to transposed convolution with stride d.

Type

Union[None, int, Sequence[int]]

kernel_dilation#

an integer or a sequence of n integers, giving the dilation factor to apply in each spatial dimension of the convolution kernel (default: 1). Convolution with kernel dilation is also known as ‘atrous convolution’.

Type

Union[None, int, Sequence[int]]

feature_group_count#

integer, default 1. If specified, divides the input features into groups.

Type

int

use_bias#

whether to add a bias to the output (default: True).

Type

bool

mask#

Optional mask for the weights during masked convolution. The mask must be the same shape as the convolution weight matrix.

Type

Optional[Any]

dtype#

the dtype of the computation (default: infer from input and params).

Type

Optional[Any]

param_dtype#

the dtype passed to parameter initializers (default: float32).

Type

Any

precision#

numerical precision of the computation; see jax.lax.Precision for details.

Type

Union[None, str, jax._src.lax.lax.Precision, Tuple[str, str], Tuple[jax._src.lax.lax.Precision, jax._src.lax.lax.Precision]]

kernel_init#

initializer for the convolutional kernel.

Type

Callable[[Any, Tuple[int, …], Any], Any]

bias_init#

initializer for the bias.

Type

Callable[[Any, Tuple[int, …], Any], Any]

__call__(inputs)#

Applies a (potentially unshared) convolution to the inputs.

Parameters

inputs – input data with dimensions (*batch_dims, spatial_dims…, features). This is the channels-last convention, i.e. NHWC for a 2D convolution and NDHWC for a 3D convolution. Note: this is different from the input convention used by lax.conv_general_dilated, which puts the spatial dimensions last. Note: if the input has more than one batch dimension, all batch dimensions are flattened into a single dimension for the convolution and restored before returning. In some cases directly vmap’ing the layer may yield better performance than this default flattening approach. If the input lacks a batch dimension, it will be added for the convolution and removed on return, an allowance made to enable writing single-example code.

Returns

The convolved data.

Methods
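
A minimal channels-last sketch (arbitrary shapes):

import jax, jax.numpy as jnp
import flax.linen as nn

layer = nn.Conv(features=16, kernel_size=(3, 3))   # 'SAME' padding by default
x = jnp.ones((1, 28, 28, 3))                       # NHWC input
variables = layer.init(jax.random.key(0), x)
y = layer.apply(variables, x)                      # shape (1, 28, 28, 16)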

class flax.linen.ConvTranspose(features, kernel_size, strides=None, padding='SAME', kernel_dilation=None, use_bias=True, mask=None, dtype=None, param_dtype=<class 'jax.numpy.float32'>, precision=None, kernel_init=<function variance_scaling.<locals>.init>, bias_init=<function zeros>, transpose_kernel=False, parent=<flax.linen.module._Sentinel object>, name=None)[source]#

Convolution Module wrapping lax.conv_transpose.

features#

number of convolution filters.

Type

int

kernel_size#

shape of the convolutional kernel. For 1D convolution, the kernel size can be passed as an integer. For all other cases, it must be a sequence of integers.

Type

Union[int, Sequence[int]]

strides#

a sequence of n integers, representing the inter-window strides.

Type

Optional[Sequence[int]]

padding#

either the string ‘SAME’, the string ‘VALID’, the string ‘CIRCULAR’ (periodic boundary conditions), or a sequence of n (low, high) integer pairs that give the padding to apply before and after each spatial dimension. A single int is interpreted as applying the same padding in all dims, and assigning a single int in a sequence causes the same padding to be used on both sides.

Type

Union[str, int, Sequence[Union[int, Tuple[int, int]]]]

kernel_dilation#

None, or a sequence of n integers, giving the dilation factor to apply in each spatial dimension of the convolution kernel. Convolution with kernel dilation is also known as ‘atrous convolution’.

Type

Optional[Sequence[int]]

use_bias#

whether to add a bias to the output (default: True).

Type

bool

mask#

Optional mask for the weights during masked convolution. The mask must be the same shape as the convolution weight matrix.

Type

Optional[Any]

dtype#

the dtype of the computation (default: infer from input and params).

Type

Any

param_dtype#

the dtype passed to parameter initializers (default: float32).

Type

Any

precision#

numerical precision of the computation; see jax.lax.Precision for details.

Type

Union[None, str, jax._src.lax.lax.Precision, Tuple[str, str], Tuple[jax._src.lax.lax.Precision, jax._src.lax.lax.Precision]]

kernel_init#

initializer for the convolutional kernel.

Type

Callable[[Any, Tuple[int, …], Any], Any]

bias_init#

initializer for the bias.

Type

Callable[[Any, Tuple[int, …], Any], Any]

transpose_kernel#

if True flips spatial axes and swaps the input/output channel axes of the kernel.

Type

bool

__call__(inputs)[source]#

Applies a transposed convolution to the inputs.

Behaviour mirrors that of jax.lax.conv_transpose.

Parameters

inputs – input data with dimensions (*batch_dims, spatial_dims…, features). This is the channels-last convention, i.e. NHWC for a 2D convolution and NDHWC for a 3D convolution. Note: this is different from the input convention used by lax.conv_general_dilated, which puts the spatial dimensions last. Note: if the input has more than one batch dimension, all batch dimensions are flattened into a single dimension for the convolution and restored before returning. In some cases directly vmap’ing the layer may yield better performance than this default flattening approach. If the input lacks a batch dimension, it will be added for the convolution and removed on return, an allowance made to enable writing single-example code.

Returns

The convolved data.

Methods
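
A minimal upsampling sketch (arbitrary shapes); with the default ‘SAME’ padding the spatial dimensions are scaled by the strides:

import jax, jax.numpy as jnp
import flax.linen as nn

layer = nn.ConvTranspose(features=8, kernel_size=(3, 3), strides=(2, 2))
x = jnp.ones((1, 14, 14, 16))                      # NHWC input
variables = layer.init(jax.random.key(0), x)
y = layer.apply(variables, x)                      # shape (1, 28, 28, 8)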

class flax.linen.ConvLocal(features, kernel_size, strides=1, padding='SAME', input_dilation=1, kernel_dilation=1, feature_group_count=1, use_bias=True, mask=None, dtype=None, param_dtype=<class 'jax.numpy.float32'>, precision=None, kernel_init=<function variance_scaling.<locals>.init>, bias_init=<function zeros>, conv_general_dilated=<function conv_general_dilated>, conv_general_dilated_cls=None, parent=<flax.linen.module._Sentinel object>, name=None)[source]#

Local convolution Module wrapping lax.conv_general_dilated_local.

features#

number of convolution filters.

Type

int

kernel_size#

shape of the convolutional kernel. For 1D convolution, the kernel size can be passed as an integer. For all other cases, it must be a sequence of integers.

Type

Sequence[int]

strides#

an integer or a sequence of n integers, representing the inter-window strides (default: 1).

Type

Union[None, int, Sequence[int]]

padding#

either the string ‘SAME’, the string ‘VALID’, the string ‘CIRCULAR’ (periodic boundary conditions), or a sequence of n (low, high) integer pairs that give the padding to apply before and after each spatial dimension. A single int is interpreted as applying the same padding in all dims, and assigning a single int in a sequence causes the same padding to be used on both sides. ‘CAUSAL’ padding for a 1D convolution will left-pad the convolution axis, resulting in a same-sized output.

Type

Union[str, int, Sequence[Union[int, Tuple[int, int]]]]

input_dilation#

an integer or a sequence of n integers, giving the dilation factor to apply in each spatial dimension of inputs (default: 1). Convolution with input dilation d is equivalent to transposed convolution with stride d.

Type

Union[None, int, Sequence[int]]

kernel_dilation#

an integer or a sequence of n integers, giving the dilation factor to apply in each spatial dimension of the convolution kernel (default: 1). Convolution with kernel dilation is also known as ‘atrous convolution’.

Type

Union[None, int, Sequence[int]]

feature_group_count#

integer, default 1. If specified, divides the input features into groups.

Type

int

use_bias#

whether to add a bias to the output (default: True).

Type

bool

mask#

Optional mask for the weights during masked convolution. The mask must be the same shape as the convolution weight matrix.

Type

Optional[Any]

dtype#

the dtype of the computation (default: infer from input and params).

Type

Optional[Any]

param_dtype#

the dtype passed to parameter initializers (default: float32).

Type

Any

precision#

numerical precision of the computation; see jax.lax.Precision for details.

Type

Union[None, str, jax._src.lax.lax.Precision, Tuple[str, str], Tuple[jax._src.lax.lax.Precision, jax._src.lax.lax.Precision]]

kernel_init#

initializer for the convolutional kernel.

Type

Callable[[Any, Tuple[int, …], Any], Any]

bias_init#

initializer for the bias.

Type

Callable[[Any, Tuple[int, …], Any], Any]

__call__(inputs)#

Applies a (potentially unshared) convolution to the inputs.

Parameters

inputs – input data with dimensions (*batch_dims, spatial_dims…, features). This is the channels-last convention, i.e. NHWC for a 2D convolution and NDHWC for a 3D convolution. Note: this is different from the input convention used by lax.conv_general_dilated, which puts the spatial dimensions last. Note: if the input has more than one batch dimension, all batch dimensions are flattened into a single dimension for the convolution and restored before returning. In some cases directly vmap’ing the layer may yield better performance than this default flattening approach. If the input lacks a batch dimension, it will be added for the convolution and removed on return, an allowance made to enable writing single-example code.

Returns

The convolved data.

Methods

class flax.linen.Embed(num_embeddings, features, dtype=None, param_dtype=<class 'jax.numpy.float32'>, embedding_init=<function variance_scaling.<locals>.init>, parent=<flax.linen.module._Sentinel object>, name=None)[source]#

Embedding Module.

A parameterized function from integers [0, n) to d-dimensional vectors.

num_embeddings#

number of embeddings.

Type

int

features#

number of feature dimensions for each embedding.

Type

int

dtype#

the dtype of the output embedding vectors (default: same as the embedding parameter).

Type

Optional[Any]

param_dtype#

the dtype passed to parameter initializers (default: float32).

Type

Any

embedding_init#

embedding initializer.

Type

Callable[[Any, Tuple[int, …], Any], Any]

__call__(inputs)[source]#

Embeds the inputs along the last dimension.

Parameters

inputs – input data, all dimensions are considered batch dimensions.

Returns

Output which is embedded input data. The output shape follows the input, with an additional features dimension appended.

Methods

attend(query)

Attend over the embedding using a query array.

setup()

Initializes a Module lazily (similar to a lazy __init__).
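
A minimal sketch (arbitrary sizes); attend can be invoked through apply with the method argument, e.g. to tie input and output embeddings:

import jax, jax.numpy as jnp
import flax.linen as nn

embed = nn.Embed(num_embeddings=1000, features=32)
tokens = jnp.array([[1, 2, 3], [4, 5, 6]])          # integer ids; any batch shape
variables = embed.init(jax.random.key(0), tokens)
vectors = embed.apply(variables, tokens)            # shape (2, 3, 32)
logits = embed.apply(variables, vectors, method=embed.attend)  # shape (2, 3, 1000)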

Pooling#

flax.linen.max_pool(inputs, window_shape, strides=None, padding='VALID')[source]#

Pools the input by taking the maximum of a window slice.

Parameters
  • inputs – input data with dimensions (batch, window dims…, features).

  • window_shape – a shape tuple defining the window to reduce over.

  • strides – a sequence of n integers, representing the inter-window strides (default: (1, …, 1)).

  • padding – either the string ‘SAME’, the string ‘VALID’, or a sequence of n (low, high) integer pairs that give the padding to apply before and after each spatial dimension (default: ‘VALID’).

Returns

The maximum for each window slice.

flax.linen.avg_pool(inputs, window_shape, strides=None, padding='VALID', count_include_pad=True)[source]#

Pools the input by taking the average over a window.

Parameters
  • inputs – input data with dimensions (batch, window dims…, features).

  • window_shape – a shape tuple defining the window to reduce over.

  • strides – a sequence of n integers, representing the inter-window strides (default: (1, …, 1)).

  • padding – either the string ‘SAME’, the string ‘VALID’, or a sequence of n (low, high) integer pairs that give the padding to apply before and after each spatial dimension (default: ‘VALID’).

  • count_include_pad – a boolean whether to include padded tokens in the average calculation (default: True).

Returns

The average for each window slice.
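
max_pool and avg_pool are plain functions rather than Modules, so they can be called directly on arrays or inside a module’s __call__; a minimal sketch with arbitrary shapes:

import jax.numpy as jnp
import flax.linen as nn

x = jnp.ones((1, 28, 28, 3))                                 # (batch, height, width, features)
y_max = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2))  # shape (1, 14, 14, 3)
y_avg = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))  # shape (1, 14, 14, 3)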

flax.linen.pool(inputs, init, reduce_fn, window_shape, strides, padding)[source]#

Helper function to define pooling functions.

Pooling functions are implemented using the ReduceWindow XLA op. NOTE: Be aware that pooling is not generally differentiable. That means providing a reduce_fn that is differentiable does not imply that pool is differentiable.

Parameters
  • inputs – input data with dimensions (batch, window dims…, features).

  • init – the initial value for the reduction

  • reduce_fn – a reduce function of the form (T, T) -> T.

  • window_shape – a shape tuple defining the window to reduce over.

  • strides – a sequence of n integers, representing the inter-window strides (default: (1, …, 1)).

  • padding – either the string ‘SAME’, the string ‘VALID’, or a sequence of n (low, high) integer pairs that give the padding to apply before and after each spatial dimension.

Returns

The output of the reduction for each window slice.
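
As an illustration of how pool composes with a reduction, max pooling can be expressed by reducing with jax.lax.max from an initial value of -inf; treat this as a sketch of the idea rather than the exact internal implementation:

import jax, jax.numpy as jnp
import flax.linen as nn

x = jnp.ones((1, 28, 28, 3))
y = nn.pool(x, -jnp.inf, jax.lax.max, window_shape=(2, 2),
            strides=(2, 2), padding='VALID')   # matches max_pool(x, (2, 2), strides=(2, 2))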

Normalization#

class flax.linen.BatchNorm(use_running_average=None, axis=-1, momentum=0.99, epsilon=1e-05, dtype=None, param_dtype=<class 'jax.numpy.float32'>, use_bias=True, use_scale=True, bias_init=<function zeros>, scale_init=<function ones>, axis_name=None, axis_index_groups=None, use_fast_variance=True, parent=<flax.linen.module._Sentinel object>, name=None)[source]#

BatchNorm Module.

Usage Note: If we define a model with BatchNorm, for example:

BN = nn.BatchNorm(use_running_average=False, momentum=0.9, epsilon=1e-5,
                  dtype=jnp.float32)

The initialized variables dict will contain in addition to a ‘params’ collection a separate ‘batch_stats’ collection that will contain all the running statistics for all the BatchNorm layers in a model:

vars_initialized = BN.init(key, x)  # {'params': ..., 'batch_stats': ...}

We then update the batch_stats during training by specifying that the batch_stats collection is mutable in the apply method for our module:

vars_in = {'params': params, 'batch_stats': old_batch_stats}
y, mutated_vars = BN.apply(vars_in, x, mutable=['batch_stats'])
new_batch_stats = mutated_vars['batch_stats']

During eval we would define BN with use_running_average=True and use the batch_stats collection from training to set the statistics. In this case we are not mutating the batch statistics collection, and needn’t mark it mutable:

vars_in = {'params': params, 'batch_stats': training_batch_stats}
y = BN.apply(vars_in, x)
use_running_average#

if True, the statistics stored in batch_stats will be used instead of computing the batch statistics on the input.

Type

Optional[bool]

axis#

the feature or non-batch axis of the input.

Type

int

momentum#

decay rate for the exponential moving average of the batch statistics.

Type

float

epsilon#

a small float added to variance to avoid dividing by zero.

Type

float

dtype#

the dtype of the result (default: infer from input and params).

Type

Optional[Any]

param_dtype#

the dtype passed to parameter initializers (default: float32).

Type

Any

use_bias#

if True, bias (beta) is added.

Type

bool

use_scale#

if True, multiply by scale (gamma). When the next layer is linear (this also applies to e.g. nn.relu), this can be disabled since the scaling will be done by the next layer.

Type

bool

bias_init#

initializer for bias, by default, zero.

Type

Callable[[Any, Tuple[int, …], Any], Any]

scale_init#

initializer for scale, by default, one.

Type

Callable[[Any, Tuple[int, …], Any], Any]

axis_name#

the axis name used to combine batch statistics from multiple devices. See jax.pmap for a description of axis names (default: None).

Type

Optional[str]

axis_index_groups#

groups of axis indices within that named axis representing subsets of devices to reduce over (default: None). For example, [[0, 1], [2, 3]] would independently batch-normalize over the examples on the first two and last two devices. See jax.lax.psum for more details.

Type

Any

use_fast_variance#

If true, use a faster, but less numerically stable, calculation for the variance.

Type

bool

__call__(x, use_running_average=None)[source]#

Normalizes the input using batch statistics.

NOTE: During initialization (when self.is_initializing() is True) the running average of the batch statistics will not be updated. Therefore, the inputs fed during initialization don’t need to match that of the actual input distribution and the reduction axis (set with axis_name) does not have to exist.

Parameters
  • x – the input to be normalized.

  • use_running_average – if true, the statistics stored in batch_stats will be used instead of computing the batch statistics on the input.

Returns

Normalized inputs (the same shape as inputs).

Methods

class flax.linen.LayerNorm(epsilon=1e-06, dtype=None, param_dtype=<class 'jax.numpy.float32'>, use_bias=True, use_scale=True, bias_init=<function zeros>, scale_init=<function ones>, reduction_axes=-1, feature_axes=-1, axis_name=None, axis_index_groups=None, use_fast_variance=True, parent=<flax.linen.module._Sentinel object>, name=None)[source]#

Layer normalization (https://arxiv.org/abs/1607.06450).

LayerNorm normalizes the activations of the layer for each given example in a batch independently, rather than across a batch as Batch Normalization does, i.e. it applies a transformation that keeps the mean activation within each example close to 0 and the activation standard deviation close to 1.

epsilon#

A small float added to variance to avoid dividing by zero.

Type

float

dtype#

the dtype of the result (default: infer from input and params).

Type

Optional[Any]

param_dtype#

the dtype passed to parameter initializers (default: float32).

Type

Any

use_bias#

If True, bias (beta) is added.

Type

bool

use_scale#

If True, multiply by scale (gamma). When the next layer is linear (this also applies to e.g. nn.relu), this can be disabled since the scaling will be done by the next layer.

Type

bool

bias_init#

Initializer for bias, by default, zero.

Type

Callable[[Any, Tuple[int, …], Any], Any]

scale_init#

Initializer for scale, by default, one.

Type

Callable[[Any, Tuple[int, …], Any], Any]

reduction_axes#

Axes for computing normalization statistics.

Type

Union[int, Sequence[int]]

feature_axes#

Feature axes for learned bias and scaling.

Type

Union[int, Sequence[int]]

axis_name#

the axis name used to combine batch statistics from multiple devices. See jax.pmap for a description of axis names (default: None). This is only needed if the model is subdivided across devices, i.e. the array being normalized is sharded across devices within a pmap.

Type

Optional[str]

axis_index_groups#

groups of axis indices within that named axis representing subsets of devices to reduce over (default: None). For example, [[0, 1], [2, 3]] would independently batch-normalize over the examples on the first two and last two devices. See jax.lax.psum for more details.

Type

Any

use_fast_variance#

If true, use a faster, but less numerically stable, calculation for the variance.

Type

bool

__call__(x)[source]#

Applies layer normalization on the input.

Parameters

x – the inputs

Returns

Normalized inputs (the same shape as inputs).

Methods
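
A minimal sketch (arbitrary shapes); statistics are computed over the last axis by default and the output shape matches the input:

import jax, jax.numpy as jnp
import flax.linen as nn

layer = nn.LayerNorm()
x = jax.random.normal(jax.random.key(0), (4, 10, 32))
variables = layer.init(jax.random.key(1), x)   # 'scale' and 'bias' of shape (32,)
y = layer.apply(variables, x)                  # same shape as x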

class flax.linen.GroupNorm(num_groups=32, group_size=None, epsilon=1e-06, dtype=None, param_dtype=<class 'jax.numpy.float32'>, use_bias=True, use_scale=True, bias_init=<function zeros>, scale_init=<function ones>, axis_name=None, axis_index_groups=None, use_fast_variance=True, parent=<flax.linen.module._Sentinel object>, name=None)[source]#

Group normalization (arxiv.org/abs/1803.08494).

This op is similar to batch normalization, but statistics are shared across equally-sized groups of channels and not shared across batch dimension. Thus, group normalization does not depend on the batch composition and does not require maintaining internal state for storing statistics. The user should either specify the total number of channel groups or the number of channels per group.

num_groups#

the total number of channel groups. The default value of 32 is proposed by the original group normalization paper.

Type

Optional[int]

group_size#

the number of channels in a group.

Type

Optional[int]

epsilon#

A small float added to variance to avoid dividing by zero.

Type

float

dtype#

the dtype of the result (default: infer from input and params).

Type

Optional[Any]

param_dtype#

the dtype passed to parameter initializers (default: float32).

Type

Any

use_bias#

If True, bias (beta) is added.

Type

bool

use_scale#

If True, multiply by scale (gamma). When the next layer is linear (this also applies to e.g. nn.relu), this can be disabled since the scaling will be done by the next layer.

Type

bool

bias_init#

Initializer for bias, by default, zero.

Type

Callable[[Any, Tuple[int, …], Any], Any]

scale_init#

Initializer for scale, by default, one.

Type

Callable[[Any, Tuple[int, …], Any], Any]

axis_name#

the axis name used to combine batch statistics from multiple devices. See jax.pmap for a description of axis names (default: None). This is only needed if the model is subdivided across devices, i.e. the array being normalized is sharded across devices within a pmap.

Type

Optional[str]

axis_index_groups#

groups of axis indices within that named axis representing subsets of devices to reduce over (default: None). For example, [[0, 1], [2, 3]] would independently batch-normalize over the examples on the first two and last two devices. See jax.lax.psum for more details.

Type

Any

use_fast_variance#

If true, use a faster, but less numerically stable, calculation for the variance.

Type

bool

__call__(x)[source]#

Applies group normalization to the input (arxiv.org/abs/1803.08494).

Parameters

x – the input of shape N…C, where N is a batch dimension, C is the channels dimension, and … represents an arbitrary number of extra dimensions over which statistics are accumulated.

Returns

Normalized inputs (the same shape as inputs).

Methods
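
A minimal sketch (arbitrary shapes); the number of channels must be divisible by num_groups:

import jax, jax.numpy as jnp
import flax.linen as nn

layer = nn.GroupNorm(num_groups=4)
x = jax.random.normal(jax.random.key(0), (2, 8, 8, 32))   # 32 channels, 4 groups of 8
variables = layer.init(jax.random.key(1), x)
y = layer.apply(variables, x)                             # same shape as x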

Combinators#

class flax.linen.Sequential(layers, parent=<flax.linen.module._Sentinel object>, name=None)[source]#

Applies a linear chain of Modules.

Meant to be used only for the simple case of fusing together callables where the input of a particular module/op is the output of the previous one.

Modules will be applied in the order that they are passed in the constructor.

The __call__ method of Sequential accepts any input and forwards it to the first module it contains. It chains the output sequentially to the input of the next module and returns the output of the final module.

Example usage:

class Foo(nn.Module):

  @nn.compact
  def __call__(self, x):
    return nn.Sequential([nn.Dense(4),
                          nn.relu,
                          nn.Dense(2),
                          nn.log_softmax])(x)

This combinator also supports layers that return multiple outputs, provided they are returned as a tuple or a dictionary. If the output of a layer is a tuple, it will be expanded as *args in the next layer; if it is a dict, it will be expanded as **kwargs.

Example usage:

class CrossAttentionBlock(nn.Module):
  num_heads: int = 2
  qkv_features: int = 16

  @nn.compact
  def __call__(self, query, key_value):
    output = nn.MultiHeadDotProductAttention(
      num_heads=self.num_heads, qkv_features=self.qkv_features)(query,
                                                              key_value)
    output = nn.Dense(self.qkv_features)(output)
    return dict(query=output, key_value=key_value)  # also works for tuples

class CrossAttentionNetwork(nn.Module):
  num_layers: int

  @nn.compact
  def __call__(self, query, key_value):
    return nn.Sequential([CrossAttentionBlock() for _ in
                          range(self.num_layers)])(query, key_value)
__call__(*args, **kwargs)[source]#

Call self as a function.

Methods

Stochastic#

class flax.linen.Dropout(rate, broadcast_dims=(), deterministic=None, rng_collection='dropout', parent=<flax.linen.module._Sentinel object>, name=None)[source]#

Create a dropout layer.

Note: When using Module.apply(), make sure to include an RNG seed named ‘dropout’. For example:

model.apply({'params': params}, inputs=inputs, train=True,
            rngs={'dropout': dropout_rng})
rate#

the dropout probability (not the keep rate!).

Type

float

broadcast_dims#

dimensions that will share the same dropout mask

Type

Sequence[int]

deterministic#

if false the inputs are scaled by 1 / (1 - rate) and masked, whereas if true, no mask is applied and the inputs are returned as is.

Type

Optional[bool]

rng_collection#

the rng collection name to use when requesting an rng key.

Type

str

__call__(inputs, deterministic=None, rng=None)[source]#

Applies a random dropout mask to the input.

Parameters
  • inputs – the inputs that should be randomly masked.

  • deterministic – if false the inputs are scaled by 1 / (1 - rate) and masked, whereas if true, no mask is applied and the inputs are returned as is.

  • rng – an optional PRNGKey used as the random key, if not specified, one will be generated using make_rng with the rng_collection name.

Returns

The masked inputs reweighted to preserve mean.

Methods
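
A minimal sketch; Dropout has no parameters, so an empty variables dict can be passed to apply, and the ‘dropout’ RNG is only needed when deterministic=False:

import jax, jax.numpy as jnp
import flax.linen as nn

dropout = nn.Dropout(rate=0.5)
x = jnp.ones((2, 8))
y_eval = dropout.apply({}, x, deterministic=True)            # identity at eval time
y_train = dropout.apply({}, x, deterministic=False,
                        rngs={'dropout': jax.random.key(0)})  # masked and rescaled by 1 / (1 - rate)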

Attention#

class flax.linen.SelfAttention(num_heads, dtype=None, param_dtype=<class 'jax.numpy.float32'>, qkv_features=None, out_features=None, broadcast_dropout=True, dropout_rate=0.0, deterministic=None, precision=None, kernel_init=<function variance_scaling.<locals>.init>, bias_init=<function zeros>, use_bias=True, attention_fn=<function dot_product_attention>, decode=False, normalize_qk=False, qkv_dot_general=<function dot_general>, out_dot_general=<function dot_general>, qkv_dot_general_cls=None, out_dot_general_cls=None, parent=<flax.linen.module._Sentinel object>, name=None)[source]#

Self-attention special case of multi-head dot-product attention.

__call__(inputs_q, mask=None, deterministic=None)[source]#

Applies multi-head dot product self-attention on the input data.

Projects the inputs into multi-headed query, key, and value vectors, applies dot-product attention, and projects the results to an output vector.

Parameters
  • inputs_q – input queries of shape [batch_sizes…, length, features].

  • mask – attention mask of shape [batch_sizes…, num_heads, query_length, key/value_length]. Attention weights are masked out if their corresponding mask value is False.

  • deterministic – if false, the attention weight is masked randomly using dropout, whereas if true, the attention weights are deterministic.

Returns

output of shape [batch_sizes…, length, features].

Methods
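
A minimal sketch (arbitrary shapes); with the default dropout_rate=0.0, no RNG or deterministic flag is required:

import jax, jax.numpy as jnp
import flax.linen as nn

attn = nn.SelfAttention(num_heads=4, qkv_features=32)
x = jnp.ones((2, 10, 32))                       # (batch, length, features)
variables = attn.init(jax.random.key(0), x)
y = attn.apply(variables, x)                    # shape (2, 10, 32)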

class flax.linen.MultiHeadDotProductAttention(num_heads, dtype=None, param_dtype=<class 'jax.numpy.float32'>, qkv_features=None, out_features=None, broadcast_dropout=True, dropout_rate=0.0, deterministic=None, precision=None, kernel_init=<function variance_scaling.<locals>.init>, bias_init=<function zeros>, use_bias=True, attention_fn=<function dot_product_attention>, decode=False, normalize_qk=False, qkv_dot_general=<function dot_general>, out_dot_general=<function dot_general>, qkv_dot_general_cls=None, out_dot_general_cls=None, parent=<flax.linen.module._Sentinel object>, name=None)[source]#

Multi-head dot-product attention.

num_heads#

number of attention heads. Features (i.e. inputs_q.shape[-1]) should be divisible by the number of heads.

Type

int

dtype#

the dtype of the computation (default: infer from inputs and params)

Type

Optional[Any]

param_dtype#

the dtype passed to parameter initializers (default: float32)

Type

Any

qkv_features#

dimension of the key, query, and value.

Type

Optional[int]

out_features#

dimension of the last projection

Type

Optional[int]

broadcast_dropout#

bool: use a broadcasted dropout along batch dims.

Type

bool

dropout_rate#

dropout rate

Type

float

deterministic#

if false, the attention weight is masked randomly using dropout, whereas if true, the attention weights are deterministic.

Type

Optional[bool]

precision#

numerical precision of the computation; see jax.lax.Precision for details.

Type

Union[None, str, jax._src.lax.lax.Precision, Tuple[str, str], Tuple[jax._src.lax.lax.Precision, jax._src.lax.lax.Precision]]

kernel_init#

initializer for the kernel of the Dense layers.

Type

Callable[[Any, Tuple[int, …], Any], Any]

bias_init#

initializer for the bias of the Dense layers.

Type

Callable[[Any, Tuple[int, …], Any], Any]

use_bias#

bool: whether pointwise QKVO dense transforms use bias.

Type

bool

attention_fn#

dot_product_attention or a compatible function. Accepts query, key, value, and returns output of shape [bs, dim1, dim2, …, dimN, num_heads, value_channels].

Type

Callable[[…], Any]

decode#

whether to prepare and use an autoregressive cache.

Type

bool

normalize_qk#

should QK normalization be applied (arxiv.org/abs/2302.05442).

Type

bool

__call__(inputs_q, inputs_kv, mask=None, deterministic=None)[source]#

Applies multi-head dot product attention on the input data.

Projects the inputs into multi-headed query, key, and value vectors, applies dot-product attention, and projects the results to an output vector.

Parameters
  • inputs_q – input queries of shape [batch_sizes…, length, features].

  • inputs_kv – key/values of shape [batch_sizes…, length, features].

  • mask – attention mask of shape [batch_sizes…, num_heads, query_length, key/value_length]. Attention weights are masked out if their corresponding mask value is False.

  • deterministic – if false, the attention weight is masked randomly using dropout, whereas if true, the attention weights are deterministic.

Returns

output of shape [batch_sizes…, length, features].

Methods
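
A minimal cross-attention sketch (arbitrary shapes); queries and key/values may have different lengths, and the output follows the query length:

import jax, jax.numpy as jnp
import flax.linen as nn

attn = nn.MultiHeadDotProductAttention(num_heads=4, qkv_features=32)
q = jnp.ones((2, 10, 32))                       # (batch, q_length, features)
kv = jnp.ones((2, 20, 32))                      # (batch, kv_length, features)
variables = attn.init(jax.random.key(0), q, kv)
y = attn.apply(variables, q, kv)                # shape (2, 10, 32)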

flax.linen.dot_product_attention_weights(query, key, bias=None, mask=None, broadcast_dropout=True, dropout_rng=None, dropout_rate=0.0, deterministic=False, dtype=None, precision=None)[source]#

Computes dot-product attention weights given query and key.

Used by dot_product_attention(), which is what you’ll most likely use. But if you want access to the attention weights for introspection, then you can directly call this function and call einsum yourself.

Parameters
  • query – queries for calculating attention with shape of [batch…, q_length, num_heads, qk_depth_per_head].

  • key – keys for calculating attention with shape of [batch…, kv_length, num_heads, qk_depth_per_head].

  • bias – bias for the attention weights. This should be broadcastable to the shape [batch…, num_heads, q_length, kv_length]. This can be used for incorporating causal masks, padding masks, proximity bias, etc.

  • mask – mask for the attention weights. This should be broadcastable to the shape [batch…, num_heads, q_length, kv_length]. This can be used for incorporating causal masks. Attention weights are masked out if their corresponding mask value is False.

  • broadcast_dropout – bool: use a broadcasted dropout along batch dims.

  • dropout_rng – JAX PRNGKey: to be used for dropout

  • dropout_rate – dropout rate

  • deterministic – bool, deterministic or not (to apply dropout)

  • dtype – the dtype of the computation (default: infer from inputs and params)

  • precision – numerical precision of the computation; see jax.lax.Precision for details.

Returns

Output of shape [batch…, num_heads, q_length, kv_length].

flax.linen.dot_product_attention(query, key, value, bias=None, mask=None, broadcast_dropout=True, dropout_rng=None, dropout_rate=0.0, deterministic=False, dtype=None, precision=None)[source]#

Computes dot-product attention given query, key, and value.

This is the core function for applying attention based on https://arxiv.org/abs/1706.03762. It calculates the attention weights given query and key and combines the values using the attention weights.

Note: query, key, value needn’t have any batch dimensions.

Parameters
  • query – queries for calculating attention with shape of [batch…, q_length, num_heads, qk_depth_per_head].

  • key – keys for calculating attention with shape of [batch…, kv_length, num_heads, qk_depth_per_head].

  • value – values to be used in attention with shape of [batch…, kv_length, num_heads, v_depth_per_head].

  • bias – bias for the attention weights. This should be broadcastable to the shape [batch…, num_heads, q_length, kv_length]. This can be used for incorporating causal masks, padding masks, proximity bias, etc.

  • mask – mask for the attention weights. This should be broadcastable to the shape [batch…, num_heads, q_length, kv_length]. This can be used for incorporating causal masks. Attention weights are masked out if their corresponding mask value is False.

  • broadcast_dropout – bool: use a broadcasted dropout along batch dims.

  • dropout_rng – JAX PRNGKey: to be used for dropout

  • dropout_rate – dropout rate

  • deterministic – bool, deterministic or not (to apply dropout)

  • dtype – the dtype of the computation (default: infer from inputs)

  • precision – numerical precision of the computation; see jax.lax.Precision for details.

Returns

Output of shape [batch…, q_length, num_heads, v_depth_per_head].
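
A minimal sketch of the expected shapes (arbitrary sizes):

import jax.numpy as jnp
import flax.linen as nn

q = jnp.ones((2, 10, 4, 16))     # (batch, q_length, num_heads, qk_depth_per_head)
k = jnp.ones((2, 20, 4, 16))     # (batch, kv_length, num_heads, qk_depth_per_head)
v = jnp.ones((2, 20, 4, 32))     # (batch, kv_length, num_heads, v_depth_per_head)
out = nn.dot_product_attention(q, k, v)   # shape (2, 10, 4, 32)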

flax.linen.make_attention_mask(query_input, key_input, pairwise_fn=<PjitFunction of <function jax.numpy.multiply>>, extra_batch_dims=0, dtype=<class 'jax.numpy.float32'>)[source]#

Mask-making helper for attention weights.

In case of 1d inputs (i.e., [batch…, len_q], [batch…, len_kv]), the attention weights will be [batch…, heads, len_q, len_kv] and this function will produce [batch…, 1, len_q, len_kv].

Parameters
  • query_input – a batched, flat input of query_length size

  • key_input – a batched, flat input of key_length size

  • pairwise_fn – broadcasting elementwise comparison function

  • extra_batch_dims – number of extra batch dims to add singleton axes for, none by default

  • dtype – mask return dtype

Returns

A [batch…, 1, len_q, len_kv] shaped mask for 1d attention.

flax.linen.make_causal_mask(x, extra_batch_dims=0, dtype=<class 'jax.numpy.float32'>)[source]#

Make a causal mask for self-attention.

In case of 1d inputs (i.e., [batch…, len]), the self-attention weights will be [batch…, heads, len, len] and this function will produce a causal mask of shape [batch…, 1, len, len].

Parameters
  • x – input array of shape [batch…, len]

  • extra_batch_dims – number of batch dims to add singleton axes for, none by default

  • dtype – mask return dtype

Returns

A [batch…, 1, len, len] shaped causal mask for 1d attention.
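
A minimal sketch combining a causal mask with a padding mask; the token array and the convention that id 0 marks padding are illustrative assumptions:

import jax.numpy as jnp
import flax.linen as nn

tokens = jnp.array([[5, 7, 9, 0], [3, 0, 0, 0]])               # (batch, len); 0 marks padding here
causal_mask = nn.make_causal_mask(tokens)                      # shape (2, 1, 4, 4)
padding_mask = nn.make_attention_mask(tokens > 0, tokens > 0)  # shape (2, 1, 4, 4)
mask = causal_mask * padding_mask                              # elementwise AND of the two masks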

Recurrent#

class flax.linen.RNNCellBase(parent=<flax.linen.module._Sentinel object>, name=None)[source]#

RNN cell base class.

__call__(**kwargs)#

Call self as a function.

Methods

initialize_carry(rng, input_shape)

Initialize the RNN cell carry.

class flax.linen.LSTMCell(*args, **kwds)[source]#

LSTM cell.

The mathematical definition of the cell is as follows

\[\begin{split}\begin{array}{ll} i = \sigma(W_{ii} x + W_{hi} h + b_{hi}) \\ f = \sigma(W_{if} x + W_{hf} h + b_{hf}) \\ g = \tanh(W_{ig} x + W_{hg} h + b_{hg}) \\ o = \sigma(W_{io} x + W_{ho} h + b_{ho}) \\ c' = f * c + i * g \\ h' = o * \tanh(c') \\ \end{array}\end{split}\]

where x is the input, h is the output of the previous time step, and c is the memory.

features#

number of output features.

Type

int

gate_fn#

activation function used for gates (default: sigmoid)

Type

Callable[[…], Any]

activation_fn#

activation function used for output and memory update (default: tanh).

Type

Callable[[…], Any]

kernel_init#

initializer function for the kernels that transform the input (default: lecun_normal).

Type

jax.nn.initializers.Initializer

recurrent_kernel_init#

initializer function for the kernels that transform the hidden state (default: initializers.orthogonal()).

Type

jax.nn.initializers.Initializer

bias_init#

initializer for the bias parameters (default: initializers.zeros_init())

Type

jax.nn.initializers.Initializer

dtype#

the dtype of the computation (default: infer from inputs and params).

Type

Optional[Any]

param_dtype#

the dtype passed to parameter initializers (default: float32).

Type

Any

__call__(carry, inputs)[source]#

A long short-term memory (LSTM) cell.

Parameters
  • carry – the hidden state of the LSTM cell, initialized using LSTMCell.initialize_carry.

  • inputs – an ndarray with the input for the current time step. All dimensions except the final are considered batch dimensions.

Returns

A tuple with the new carry and the output.

Methods

initialize_carry(**kwargs)

Initialize the RNN cell carry.
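
A single-step sketch, assuming the newer cell API where features is a constructor argument and initialize_carry is an instance method (as used in the RNN examples below); shapes are arbitrary:

import jax, jax.numpy as jnp
import flax.linen as nn

cell = nn.LSTMCell(features=64)
x = jnp.ones((8, 32))                                      # one time step: (batch, input features)
carry = cell.initialize_carry(jax.random.key(0), x.shape)  # (c, h), each of shape (8, 64)
variables = cell.init(jax.random.key(1), carry, x)
new_carry, y = cell.apply(variables, carry, x)             # y.shape == (8, 64)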

class flax.linen.OptimizedLSTMCell(*args, **kwds)[source]#

More efficient LSTM Cell that concatenates state components before matmul.

The parameters are compatible with LSTMCell. Note that this cell is often faster than LSTMCell as long as the hidden size is roughly <= 2048 units.

The mathematical definition of the cell is the same as LSTMCell and as follows

\[\begin{split}\begin{array}{ll} i = \sigma(W_{ii} x + W_{hi} h + b_{hi}) \\ f = \sigma(W_{if} x + W_{hf} h + b_{hf}) \\ g = \tanh(W_{ig} x + W_{hg} h + b_{hg}) \\ o = \sigma(W_{io} x + W_{ho} h + b_{ho}) \\ c' = f * c + i * g \\ h' = o * \tanh(c') \\ \end{array}\end{split}\]

where x is the input, h is the output of the previous time step, and c is the memory.

gate_fn#

activation function used for gates (default: sigmoid).

Type

Callable[[…], Any]

activation_fn#

activation function used for output and memory update (default: tanh).

Type

Callable[[…], Any]

kernel_init#

initializer function for the kernels that transform the input (default: lecun_normal).

Type

jax.nn.initializers.Initializer

recurrent_kernel_init#

initializer function for the kernels that transform the hidden state (default: initializers.orthogonal()).

Type

jax.nn.initializers.Initializer

bias_init#

initializer for the bias parameters (default: initializers.zeros_init()).

Type

jax.nn.initializers.Initializer

dtype#

the dtype of the computation (default: infer from inputs and params).

Type

Optional[Any]

param_dtype#

the dtype passed to parameter initializers (default: float32).

Type

Any

__call__(carry, inputs)[source]#

An optimized long short-term memory (LSTM) cell.

Parameters
  • carry – the hidden state of the LSTM cell, initialized using LSTMCell.initialize_carry.

  • inputs – an ndarray with the input for the current time step. All dimensions except the final are considered batch dimensions.

Returns

A tuple with the new carry and the output.

Methods

initialize_carry(**kwargs)

Initialize the RNN cell carry.

class flax.linen.GRUCell(*args, **kwds)[source]#

GRU cell.

The mathematical definition of the cell is as follows

\[\begin{split}\begin{array}{ll} r = \sigma(W_{ir} x + W_{hr} h + b_{hr}) \\ z = \sigma(W_{iz} x + W_{hz} h + b_{hz}) \\ n = \tanh(W_{in} x + b_{in} + r * (W_{hn} h + b_{hn})) \\ h' = (1 - z) * n + z * h \\ \end{array}\end{split}\]

where x is the input and h is the output of the previous time step.

gate_fn#

activation function used for gates (default: sigmoid)

Type

Callable[[…], Any]

activation_fn#

activation function used for output and memory update (default: tanh).

Type

Callable[[…], Any]

kernel_init#

initializer function for the kernels that transform the input (default: lecun_normal).

Type

jax.nn.initializers.Initializer

recurrent_kernel_init#

initializer function for the kernels that transform the hidden state (default: initializers.orthogonal()).

Type

jax.nn.initializers.Initializer

bias_init#

initializer for the bias parameters (default: initializers.zeros_init())

Type

jax.nn.initializers.Initializer

dtype#

the dtype of the computation (default: None).

Type

Optional[Any]

param_dtype#

the dtype passed to parameter initializers (default: float32).

Type

Any

__call__(carry, inputs)[source]#

Gated recurrent unit (GRU) cell.

Parameters
  • carry – the hidden state of the GRU cell, initialized using GRUCell.initialize_carry.

  • inputs – an ndarray with the input for the current time step. All dimensions except the final are considered batch dimensions.

Returns

A tuple with the new carry and the output.

Methods

initialize_carry(**kwargs)

Initialize the RNN cell carry.
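
The same single-step pattern shown for LSTMCell applies, except that the GRU carry is just the hidden state (a sketch under the same constructor-argument assumption):

import jax, jax.numpy as jnp
import flax.linen as nn

cell = nn.GRUCell(features=64)
x = jnp.ones((8, 32))                                      # one time step: (batch, input features)
carry = cell.initialize_carry(jax.random.key(0), x.shape)  # hidden state of shape (8, 64)
variables = cell.init(jax.random.key(1), carry, x)
new_carry, y = cell.apply(variables, carry, x)             # y is the new hidden state, shape (8, 64)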

class flax.linen.RNN(cell, cell_size=<flax.linen.recurrent._Never object>, time_major=False, return_carry=False, reverse=False, keep_order=False, unroll=1, variable_axes=FrozenDict({}), variable_broadcast='params', variable_carry=False, split_rngs=FrozenDict({     params: False, }), parent=<flax.linen.module._Sentinel object>, name=None)[source]#

The RNN module takes any RNNCellBase instance and applies it over a sequence using flax.linen.scan().

Example:

>>> import jax.numpy as jnp
>>> import jax
>>> import flax.linen as nn
...
>>> x = jnp.ones((10, 50, 32)) # (batch, time, features)
>>> lstm = nn.RNN(nn.LSTMCell(64))
>>> variables = lstm.init(jax.random.key(0), x)
>>> y = lstm.apply(variables, x)
>>> y.shape # (batch, time, cell_size)
(10, 50, 64)

As shown above, RNN uses the cell_size argument to set the size argument for the cell’s initialize_carry method; in practice this is typically the number of hidden units you want for the cell. However, this may vary depending on the cell you are using: for example, ConvLSTMCell requires a size argument of the form (kernel_height, kernel_width, features):

>>> x = jnp.ones((10, 50, 32, 32, 3)) # (batch, time, height, width, features)
>>> conv_lstm = nn.RNN(nn.ConvLSTMCell(64, kernel_size=(3, 3)))
>>> y, variables = conv_lstm.init_with_output(jax.random.key(0), x)
>>> y.shape # (batch, time, height, width, features)
(10, 50, 32, 32, 64)

By default RNN expects the time dimension after the batch dimension ((*batch, time, *features)); if you set time_major=True, RNN will instead expect the time dimension to be at the beginning ((time, *batch, *features)):

>>> x = jnp.ones((50, 10, 32)) # (time, batch, features)
>>> lstm = nn.RNN(nn.LSTMCell(64), time_major=True)
>>> variables = lstm.init(jax.random.key(0), x)
>>> y = lstm.apply(variables, x)
>>> y.shape # (time, batch, cell_size)
(50, 10, 64)

The output is typically an array of shape (*batch, time, *cell_size); however, if you set return_carry=True it will instead return a tuple of the final carry and the output:

>>> x = jnp.ones((10, 50, 32)) # (batch, time, features)
>>> lstm = nn.RNN(nn.LSTMCell(64), return_carry=True)
>>> variables = lstm.init(jax.random.key(0), x)
>>> carry, y = lstm.apply(variables, x)
>>> jax.tree_map(jnp.shape, carry) # ((batch, cell_size), (batch, cell_size))
((10, 64), (10, 64))
>>> y.shape # (batch, time, cell_size)
(10, 50, 64)

To support variable-length sequences, you can pass seq_lengths, an integer array of shape (*batch) where each element is the length of the corresponding sequence in the batch. For example:

>>> seq_lengths = jnp.array([3, 2, 5])

The output elements corresponding to padding elements are NOT zeroed out. If return_carry is set to True the carry will be the state of the last valid element of each sequence.

RNN also accepts some of the arguments of flax.linen.scan(); by default they are set to work with cells like LSTMCell and GRUCell, but they can be overridden as needed. Overriding the default values passed to scan looks like this:

>>> lstm = nn.RNN(
...   nn.LSTMCell(64),
...   unroll=1, variable_axes={}, variable_broadcast='params',
...   variable_carry=False, split_rngs={'params': False})
cell#

an instance of RNNCellBase.

Type

flax.linen.recurrent.RNNCellBase

time_major#

if time_major=False (default) it will expect inputs with shape (*batch, time, *features), else it will expect inputs with shape (time, *batch, *features).

Type

bool

return_carry#

if return_carry=False (default) only the output sequence is returned, else it will return a tuple of the final carry and the output sequence.

Type

bool

reverse#

if reverse=False (default) the sequence is processed from left to right and returned in the original order, else it will be processed from right to left, and returned in reverse order. If seq_lengths is passed, padding will always remain at the end of the sequence.

Type

bool

keep_order#

if keep_order=True, when reverse=True the output will be reversed back to the original order after processing; this is useful for aligning sequences in bidirectional RNNs. If keep_order=False (default), the output will remain in the order specified by reverse.

Type

bool

unroll#

how many scan iterations to unroll within a single iteration of a loop, defaults to 1. This argument will be passed to nn.scan.

Type

int

variable_axes#

a dictionary mapping each collection to either an integer i (meaning we scan over dimension i) or None (replicate rather than scan). This argument is forwarded to nn.scan.

Type

Mapping[Union[bool, str, Collection[str], DenyList], Union[int, flax.core.lift.In[int], flax.core.lift.Out[int]]]

variable_broadcast#

Specifies the broadcasted variable collections. A broadcasted variable should not depend on any computation that cannot be lifted out of the loop. This is typically used to define shared parameters inside the fn. This argument is forwarded to nn.scan.

Type

Union[bool, str, Collection[str], DenyList]

variable_carry#

Specifies the variable collections that are carried through the loop. Mutations to these variables are carried to the next iteration and will be preserved when the scan finishes. This argument is forwarded to nn.scan.

Type

Union[bool, str, Collection[str], DenyList]

split_rngs#

a mapping from PRNGSequenceFilter to bool specifying whether a collection’s PRNG key should be split such that its values are different at each step, or replicated such that its values remain the same at each step. This argument is forwarded to nn.scan.

Type

Mapping[Union[bool, str, Collection[str], DenyList], bool]

__call__(inputs, *, initial_carry=None, init_key=None, seq_lengths=None, return_carry=None, time_major=None, reverse=None, keep_order=None)[source]#

Applies the RNN to the inputs.

__call__ allows you to optionally override some attributes like return_carry and time_major defined in the constructor.

Parameters
  • inputs – the input sequence.

  • initial_carry – the initial carry, if not provided it will be initialized using the cell’s RNNCellBase.initialize_carry() method.

  • init_key – a PRNG key used to initialize the carry, if not provided jax.random.key(0) will be used. Most cells will ignore this argument.

  • seq_lengths – an optional integer array of shape (*batch) indicating the length of each sequence, elements whose index in the time dimension is greater than the corresponding length will be considered padding and will be ignored.

  • return_carry – if return_carry=False (default) only the output sequence is returned, else it will return a tuple of the final carry and the output sequence.

  • time_major – if time_major=False (default) it will expect inputs with shape (*batch, time, *features), else it will expect inputs with shape (time, *batch, *features).

  • reverse – overrides the reverse attribute, if reverse=False (default) the sequence is processed from left to right and returned in the original order, else it will be processed from right to left, and returned in reverse order. If seq_lengths is passed, padding will always remain at the end of the sequence.

  • keep_order – overrides the keep_order attribute; if keep_order=True, when reverse=True the output will be reversed back to the original order after processing, which is useful for aligning sequences in bidirectional RNNs. If keep_order=False (default), the output will remain in the order specified by reverse.

Returns

if return_carry=False (default) only the output sequence is returned, else it will return a tuple of the final carry and the output sequence.

Methods

class flax.linen.Bidirectional(forward_rnn, backward_rnn, merge_fn=<function _concatenate>, time_major=False, return_carry=False, parent=<flax.linen.module._Sentinel object>, name=None)[source]#

Processes the input in both directions and merges the results.

__call__(inputs, *, initial_carry=None, init_key=None, seq_lengths=None, return_carry=None, time_major=None, reverse=None, keep_order=None)[source]#

Call self as a function.

Methods
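
A minimal sketch wrapping two RNN instances; with the default merge_fn the forward and backward outputs are concatenated along the feature axis (shapes are arbitrary):

import jax, jax.numpy as jnp
import flax.linen as nn

layer = nn.Bidirectional(nn.RNN(nn.LSTMCell(32)), nn.RNN(nn.LSTMCell(32)))
x = jnp.ones((4, 10, 8))                                # (batch, time, features)
variables = layer.init(jax.random.key(0), x)
y = layer.apply(variables, x)                           # shape (4, 10, 64)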

Summary#

Dense(features[, use_bias, dtype, ...]) – A linear transformation applied over the last dimension of the input.
DenseGeneral(features[, axis, batch_dims, ...]) – A linear transformation with flexible axes.
Conv(features, kernel_size[, strides, ...]) – Convolution Module wrapping lax.conv_general_dilated.
ConvTranspose(features, kernel_size[, ...]) – Convolution Module wrapping lax.conv_transpose.
ConvLocal(features, kernel_size[, strides, ...]) – Local convolution Module wrapping lax.conv_general_dilated_local.
Embed(num_embeddings, features[, dtype, ...]) – Embedding Module.
BatchNorm([use_running_average, axis, ...]) – BatchNorm Module.
LayerNorm([epsilon, dtype, param_dtype, ...]) – Layer normalization (https://arxiv.org/abs/1607.06450).
GroupNorm([num_groups, group_size, epsilon, ...]) – Group normalization (arxiv.org/abs/1803.08494).
RMSNorm([epsilon, dtype, param_dtype, ...]) – RMS Layer normalization (https://arxiv.org/abs/1910.07467).
Sequential(layers[, parent, name]) – Applies a linear chain of Modules.
Dropout(rate[, broadcast_dims, ...]) – Create a dropout layer.
SelfAttention(num_heads[, dtype, ...]) – Self-attention special case of multi-head dot-product attention.
MultiHeadDotProductAttention(num_heads[, ...]) – Multi-head dot-product attention.
RNNCellBase([parent, name]) – RNN cell base class.
LSTMCell(*args, **kwds) – LSTM cell.
OptimizedLSTMCell(*args, **kwds) – More efficient LSTM Cell that concatenates state components before matmul.
GRUCell(*args, **kwds) – GRU cell.
RNN(cell[, cell_size, time_major, ...]) – The RNN module takes any RNNCellBase instance and applies it over a sequence using flax.linen.scan().
Bidirectional(forward_rnn, backward_rnn[, ...]) – Processes the input in both directions and merges the results.
max_pool(inputs, window_shape[, strides, ...]) – Pools the input by taking the maximum of a window slice.
avg_pool(inputs, window_shape[, strides, ...]) – Pools the input by taking the average over a window.
pool(inputs, init, reduce_fn, window_shape, ...) – Helper function to define pooling functions.
dot_product_attention_weights(query, key[, ...]) – Computes dot-product attention weights given query and key.
dot_product_attention(query, key, value[, ...]) – Computes dot-product attention given query, key, and value.
make_attention_mask(query_input, key_input) – Mask-making helper for attention weights.
make_causal_mask(x[, extra_batch_dims, dtype]) – Make a causal mask for self-attention.