Linear

NNX linear layer classes.

class flax.experimental.nnx.Conv(*args, **kwargs)

Convolution Module wrapping lax.conv_general_dilated[_local].

features

number of convolution filters.

kernel_size

shape of the convolutional kernel. For 1D convolution, the kernel size can be passed as an integer, which will be interpreted as a tuple of the single integer. For all other cases, it must be a sequence of integers.

strides

an integer or a sequence of n integers, representing the inter-window strides (default: 1).

padding

either the string ‘SAME’, the string ‘VALID’, the string ‘CIRCULAR’ (periodic boundary conditions), or a sequence of n (low, high) integer pairs that give the padding to apply before and after each spatial dimension. A single int is interpreted as applying the same padding in all dims, and passing a single int in a sequence causes the same padding to be used on both sides. ‘CAUSAL’ padding for a 1D convolution will left-pad the convolution axis, resulting in same-sized output.
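
The padding modes can be illustrated with a small NumPy sketch of a 1D, single-channel, stride-1, undilated convolution. The helper `conv1d_padded` is hypothetical and is not how Flax implements Conv (which delegates to lax.conv_general_dilated); it only shows how each padding string becomes an explicit (low, high) pair and what that does to the output length. Putting the extra SAME padding element on the high side for even kernel sizes is an assumption that follows XLA's convention.

```python
import numpy as np


def conv1d_padded(x, kernel, padding):
    """Hypothetical helper: 1D, single-channel, stride-1 cross-correlation.

    Only illustrates how each padding mode maps to an explicit (low, high)
    padding; it is not Flax's implementation.
    """
    k = len(kernel)
    same = ((k - 1) // 2, k // 2)  # extra element (even k) goes on the high side
    if padding == 'VALID':
        xp = x  # no padding: output is k - 1 shorter than the input
    elif padding == 'SAME':
        xp = np.pad(x, same)  # zero padding: output length equals input length
    elif padding == 'CIRCULAR':
        xp = np.pad(x, same, mode='wrap')  # periodic boundary conditions
    elif padding == 'CAUSAL':
        xp = np.pad(x, (k - 1, 0))  # left-pad only: y[i] depends on x[:i + 1]
    else:
        xp = np.pad(x, padding)  # explicit (low, high) integer pair
    n_out = len(xp) - k + 1
    return np.array([np.dot(xp[i:i + k], kernel) for i in range(n_out)])


x = np.arange(5.0)   # [0, 1, 2, 3, 4]
kernel = np.ones(3)  # sums each window of three
for mode in ['VALID', 'SAME', 'CIRCULAR', 'CAUSAL']:
    print(mode, conv1d_padded(x, kernel, mode))
```

Note that 'CAUSAL' here is the same as the explicit pair (k - 1, 0), which is why its output has the input's length.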

input_dilation

an integer or a sequence of n integers, giving the dilation factor to apply in each spatial dimension of inputs (default: 1). Convolution with input dilation d is equivalent to transposed convolution with stride d.
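
To see why input dilation relates to transposed convolution, the NumPy sketch below (hypothetical helpers, 1D, single channel) inserts d - 1 zeros between inputs, pads with k - 1 zeros on each side and correlates with the flipped kernel; the result matches a stride-d transposed convolution written in its scatter form. The full padding and the kernel flip are choices made in this sketch to get an exact match. In lax terms, input dilation corresponds to the lhs_dilation argument of lax.conv_general_dilated.

```python
import numpy as np


def dilate(v, d):
    """Insert d - 1 zeros between consecutive elements of a 1D array."""
    out = np.zeros((len(v) - 1) * d + 1, dtype=v.dtype)
    out[::d] = v
    return out


def correlate_valid(x, kernel):
    """1D, stride-1, 'VALID' cross-correlation (what lax convolutions compute)."""
    k = len(kernel)
    return np.array([np.dot(x[i:i + k], kernel) for i in range(len(x) - k + 1)])


def transposed_conv1d(x, kernel, stride):
    """Scatter definition of a 1D transposed convolution."""
    k = len(kernel)
    out = np.zeros((len(x) - 1) * stride + k)
    for i, xi in enumerate(x):
        out[i * stride:i * stride + k] += xi * kernel
    return out


x = np.array([1.0, 2.0, 3.0])
kernel = np.array([1.0, 10.0])
d = 2
k = len(kernel)

# Input dilation d, full (k - 1, k - 1) padding and a flipped kernel ...
via_dilation = correlate_valid(np.pad(dilate(x, d), (k - 1, k - 1)), kernel[::-1])
# ... reproduce a transposed convolution with stride d.
via_transpose = transposed_conv1d(x, kernel, stride=d)
# both give [1, 10, 2, 20, 3, 30]
```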

kernel_dilation

an integer or a sequence of n integers, giving the dilation factor to apply in each spatial dimension of the convolution kernel (default: 1). Convolution with kernel dilation is also known as ‘atrous convolution’.
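
Kernel dilation can be sketched the same way in NumPy (hypothetical helpers, 1D, single channel, 'VALID' padding): a kernel with k taps and dilation d reads inputs d apart, which equals an ordinary convolution with a kernel that has d - 1 zeros inserted between its taps, so its effective span is (k - 1) * d + 1. In lax terms this is the rhs_dilation argument of lax.conv_general_dilated.

```python
import numpy as np


def dilate(v, d):
    """Insert d - 1 zeros between consecutive kernel taps."""
    out = np.zeros((len(v) - 1) * d + 1, dtype=v.dtype)
    out[::d] = v
    return out


def correlate_valid(x, kernel):
    """1D, stride-1, 'VALID' cross-correlation."""
    k = len(kernel)
    return np.array([np.dot(x[i:i + k], kernel) for i in range(len(x) - k + 1)])


def atrous_conv1d(x, kernel, d):
    """Direct form: y[i] = sum_j x[i + j * d] * kernel[j]."""
    k = len(kernel)
    span = (k - 1) * d + 1  # effective receptive field of the dilated kernel
    return np.array([sum(x[i + j * d] * kernel[j] for j in range(k))
                     for i in range(len(x) - span + 1)])


x = np.arange(8.0)
kernel = np.array([1.0, -1.0])  # 2 taps with dilation 3 -> effective span 4
y_direct = atrous_conv1d(x, kernel, d=3)           # x[i] - x[i + 3]
y_zeros = correlate_valid(x, dilate(kernel, 3))    # same, via kernel [1, 0, 0, -1]
```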

feature_group_count

integer, default 1. If specified, divides the input features into feature_group_count groups.
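
A NumPy sketch of grouped convolution, assuming the lax convention that group g maps the g-th slice of input features to the g-th slice of output features, with a kernel of shape (kernel_size, in_features // feature_group_count, features). Both feature counts must be divisible by the number of groups. The helpers are hypothetical; the final check shows that a grouped convolution equals an ordinary one with a block-diagonal kernel, which here has twice as many stored entries.

```python
import numpy as np


def conv1d(x, kernel):
    """'VALID', stride-1 conv. x: (length, in), kernel: (k, in, out)."""
    k = kernel.shape[0]
    return np.stack([np.einsum('ki,kio->o', x[i:i + k], kernel)
                     for i in range(x.shape[0] - k + 1)])


def grouped_conv1d(x, kernel, groups):
    """kernel: (k, in // groups, out); group g maps input slice g to output slice g."""
    n_in, n_out = x.shape[1], kernel.shape[2]
    assert n_in % groups == 0 and n_out % groups == 0
    gi, go = n_in // groups, n_out // groups
    outs = [conv1d(x[:, g * gi:(g + 1) * gi], kernel[:, :, g * go:(g + 1) * go])
            for g in range(groups)]
    return np.concatenate(outs, axis=-1)


rng = np.random.default_rng(0)
x = rng.normal(size=(10, 4))         # 4 input features
kernel = rng.normal(size=(3, 2, 6))  # 2 groups -> 2 input features per group
y = grouped_conv1d(x, kernel, groups=2)

# The same layer as an ordinary convolution with a block-diagonal kernel.
dense = np.zeros((3, 4, 6))
dense[:, 0:2, 0:3] = kernel[:, :, 0:3]
dense[:, 2:4, 3:6] = kernel[:, :, 3:6]
```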

use_bias

whether to add a bias to the output (default: True).

mask

Optional mask for the weights during masked convolution. The mask must have the same shape as the convolution weight matrix.
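
Assuming, as in Flax Linen's Conv, that the mask is multiplied elementwise into the kernel before the convolution, masking amounts to the NumPy sketch below. The kernel layout (height, width, in_features, features) follows Flax's convention; the particular mask, which keeps the center tap and the taps above and to its left as in PixelCNN-style autoregressive models, is only an illustration.

```python
import numpy as np

# A 3x3 kernel in (height, width, in_features, features) layout.
kernel = np.arange(1.0, 10.0).reshape(3, 3, 1, 1)

# Illustrative mask: keep the center and the taps above / to the left of it.
mask = np.ones_like(kernel)
mask[1, 2:] = 0.0  # right of the center in the middle row
mask[2:] = 0.0     # the whole bottom row

masked_kernel = kernel * mask  # elementwise, so the shapes must match
```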

dtype

the dtype of the computation (default: infer from input and params).

param_dtype

the dtype passed to parameter initializers (default: float32).

precision

numerical precision of the computation; see jax.lax.Precision for details.

kernel_init

initializer for the convolutional kernel.

bias_init

initializer for the bias.

class flax.experimental.nnx.Embed(*args, **kwargs)

Embedding Module.

A parameterized function from integers [0, num_embeddings) to features-dimensional vectors. This Module will create an embedding matrix with shape (num_embeddings, features). When calling this layer, the input values will be used to 0-index into the embedding matrix. Indexing on a value greater than or equal to num_embeddings will result in nan values. When num_embeddings equals 1, it will broadcast the embedding matrix to the input shape with a features dimension appended.
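
The lookup itself is an integer gather along the first axis of the embedding matrix. A NumPy sketch, with a random matrix standing in for the initialized parameter (in JAX, jnp.take(embedding, tokens, axis=0) performs the same gather):

```python
import numpy as np

num_embeddings, features = 5, 4
rng = np.random.default_rng(0)
embedding = rng.normal(size=(num_embeddings, features))  # (num_embeddings, features)

tokens = np.array([[0, 4, 2],
                   [1, 1, 3]])  # integer ids, shape (batch=2, seq=3)
out = embedding[tokens]         # gather rows; a features axis is appended
```

Unlike the layer described above, plain NumPy indexing raises an IndexError for ids greater than or equal to num_embeddings instead of producing nan values.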

num_embeddings

number of embeddings / vocab size.

features

number of feature dimensions for each embedding.

dtype

the dtype of the embedding vectors (default: same as embedding).

param_dtype

the dtype passed to parameter initializers (default: float32).

embedding_init

embedding initializer.

attend(query)

Attend over the embedding using a query array.

Parameters

query – array whose last dimension equals the feature depth (features) of the embedding.

Returns

An array with final dimension num_embeddings corresponding to the batched inner product of the array of query vectors against each embedding. Commonly used for weight sharing between the embeddings and the logit transform in NLP models.
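
Attending is an inner product of each query vector with every row of the embedding matrix, i.e. a matrix product with the matrix's transpose. A NumPy sketch, with random arrays standing in for real parameters and hidden states:

```python
import numpy as np

num_embeddings, features = 10, 4
rng = np.random.default_rng(0)
embedding = rng.normal(size=(num_embeddings, features))
query = rng.normal(size=(2, 7, features))  # e.g. final hidden states (batch, seq, features)

logits = query @ embedding.T  # (2, 7, num_embeddings)
```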

class flax.experimental.nnx.Linear(*args, **kwargs)

A linear transformation applied over the last dimension of the input.
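
Assuming Flax's usual kernel layout of (in_features, features), the transformation is an affine map on the last axis, y = x @ kernel + bias, with all leading axes treated as batch dimensions. A NumPy sketch with random stand-in parameters:

```python
import numpy as np

in_features, features = 3, 5
rng = np.random.default_rng(0)
kernel = rng.normal(size=(in_features, features))  # (in_features, features)
bias = rng.normal(size=(features,))

x = rng.normal(size=(2, 7, in_features))  # leading axes are left untouched
y = x @ kernel + bias                      # acts on the last axis only
```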

features

the number of output features.

use_bias

whether to add a bias to the output (default: True).

dtype

the dtype of the computation (default: infer from input and params).

param_dtype

the dtype passed to parameter initializers (default: float32).

precision

numerical precision of the computation; see jax.lax.Precision for details.

kernel_init

initializer function for the weight matrix.

bias_init

initializer function for the bias.

class flax.experimental.nnx.LinearGeneral(*args, **kwargs)

A linear transformation with flexible axes.

Example usage (written against the Flax Linen API):

>>> import flax.linen as nn
>>> import jax, jax.numpy as jnp

>>> # equivalent to `nn.Dense(features=4)`
>>> layer = nn.LinearGeneral(features=4)
>>> # output features (4, 5)
>>> layer = nn.LinearGeneral(features=(4, 5))
>>> params = layer.init(jax.random.key(0), jnp.ones((1, 3)))
>>> jax.tree_util.tree_map(jnp.shape, params)
{'params': {'bias': (4, 5), 'kernel': (3, 4, 5)}}
>>> # apply transformation on the second and last axes
>>> layer = nn.LinearGeneral(features=(4, 5), axis=(1, -1))
>>> params = layer.init(jax.random.key(0), jnp.ones((1, 3, 6, 7)))
>>> jax.tree_util.tree_map(jnp.shape, params)
{'params': {'bias': (4, 5), 'kernel': (3, 7, 4, 5)}}

features

int or tuple with number of output features.

axis

int or tuple with axes to apply the transformation on. For instance, (-2, -1) will apply the transformation to the last two axes.
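
The last example above (features=(4, 5), axis=(1, -1)) can be reproduced as a plain tensor contraction. This NumPy sketch only mirrors the shape bookkeeping, with random arrays standing in for the initialized parameters: the axes listed in axis are summed against the leading kernel axes, the remaining input axes are kept, and the features axes are appended.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=(1, 3, 6, 7))
kernel = rng.normal(size=(3, 7, 4, 5))  # (sizes of contracted axes..., *features)
bias = rng.normal(size=(4, 5))          # shape of features

# axis=(1, -1): contract input axes 1 and 3 with the first two kernel axes.
y = np.tensordot(x, kernel, axes=((1, 3), (0, 1))) + bias

# The same contraction spelled out as an einsum.
y_einsum = np.einsum('bihj,ijfg->bhfg', x, kernel) + bias
```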

batch_dims

tuple with batch axes.

use_bias

whether to add a bias to the output (default: True).

dtype

the dtype of the computation (default: infer from input and params).

param_dtype

the dtype passed to parameter initializers (default: float32).

kernel_init

initializer function for the weight matrix.

bias_init

initializer function for the bias.

precision

numerical precision of the computation; see jax.lax.Precision for details.

class flax.experimental.nnx.Einsum(*args, **kwargs)

An einsum transformation with learnable kernel and bias.

Example usage:

>>> from flax.experimental import nnx
>>> import jax.numpy as jnp

>>> layer = nnx.Einsum('abc,cde->abde', (5, 6, 7), (6, 7), rngs=nnx.Rngs(0))
>>> assert layer.kernel.value.shape == (5, 6, 7)
>>> assert layer.bias.value.shape == (6, 7)
>>> out = layer(jnp.ones((3, 4, 5)))
>>> assert out.shape == (3, 4, 6, 7)

einsum_str

a string denoting the einsum equation. The equation must have exactly two operands: the lhs is the input passed in, and the rhs is the learnable kernel. Exactly one of einsum_str in the constructor arguments and einsum_str in the call arguments must be non-None; the other must be None.
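
What the layer computes for the example above can be sketched in NumPy as a plain einsum followed by a bias add, with random arrays standing in for the initialized kernel and bias. How the layer aligns bias_shape with the output axes in general is not covered by this sketch; here the bias axes are the trailing output axes ('de'), so ordinary broadcasting is enough.

```python
import numpy as np

einsum_str = 'abc,cde->abde'
rng = np.random.default_rng(0)
kernel = rng.normal(size=(5, 6, 7))  # kernel_shape: the rhs operand 'cde'
bias = rng.normal(size=(6, 7))       # bias_shape: covers the output's 'de' axes

x = np.ones((3, 4, 5))                       # the lhs operand 'abc'
y = np.einsum(einsum_str, x, kernel) + bias  # bias broadcasts over 'ab'
```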

kernel_shape

the shape of the kernel.

bias_shape

the shape of the bias. If this is None, a bias won’t be used.

dtype

the dtype of the computation (default: infer from input and params).

param_dtype

the dtype passed to parameter initializers (default: float32).

precision

numerical precision of the computation; see jax.lax.Precision for details.

kernel_init

initializer function for the weight matrix.

bias_init

initializer function for the bias.

rngs

random number generator state (an nnx.Rngs instance) used to initialize the parameters.