Linear#

NNX linear layer classes.

class flax.nnx.Conv(*args, **kwargs)[source]#

Convolution Module wrapping lax.conv_general_dilated.

Example usage:

>>> from flax import nnx
>>> import jax.numpy as jnp

>>> rngs = nnx.Rngs(0)
>>> x = jnp.ones((1, 8, 3))

>>> # valid padding
>>> layer = nnx.Conv(in_features=3, out_features=4, kernel_size=(3,),
...                  padding='VALID', rngs=rngs)
>>> layer.kernel.value.shape
(3, 3, 4)
>>> layer.bias.value.shape
(4,)
>>> out = layer(x)
>>> out.shape
(1, 6, 4)

>>> # circular padding with stride 2
>>> layer = nnx.Conv(in_features=3, out_features=4, kernel_size=(3, 3),
...                  strides=2, padding='CIRCULAR', rngs=rngs)
>>> layer.kernel.value.shape
(3, 3, 3, 4)
>>> layer.bias.value.shape
(4,)
>>> out = layer(x)
>>> out.shape
(1, 4, 4)

>>> # apply lower triangle mask
>>> mask = jnp.tril(jnp.ones((3, 3, 4)))
>>> layer = nnx.Conv(in_features=3, out_features=4, kernel_size=(3,),
...                  mask=mask, padding='VALID', rngs=rngs)
>>> out = layer(x)
in_features#

int or tuple with number of input features.

out_features#

int or tuple with number of output features.

kernel_size#

shape of the convolutional kernel. For 1D convolution, the kernel size can be passed as an integer, which will be interpreted as a tuple of the single integer. For all other cases, it must be a sequence of integers.

strides#

an integer or a sequence of n integers, representing the inter-window strides (default: 1).

padding#

either the string 'SAME', the string 'VALID', the string 'CIRCULAR' (periodic boundary conditions), or a sequence of n (low, high) integer pairs that give the padding to apply before and after each spatial dimension. A single int is interpreted as applying the same padding in all dims, and passing a single int in a sequence causes the same padding to be used on both sides of that dimension. 'CAUSAL' padding for a 1D convolution will left-pad the convolution axis, resulting in same-sized output.
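
As a sketch (not part of the original docstring), explicit (low, high) pairs and 'CAUSAL' padding can be compared on the 1D input from the example above, reusing rngs and x and assuming the default stride of 1:

>>> # explicit (low, high) padding of one element on each side of the spatial axis
>>> layer = nnx.Conv(in_features=3, out_features=4, kernel_size=(3,),
...                  padding=[(1, 1)], rngs=rngs)
>>> layer(x).shape
(1, 8, 4)
>>> # 'CAUSAL' left-pads the spatial axis so the output length matches the input
>>> layer = nnx.Conv(in_features=3, out_features=4, kernel_size=(3,),
...                  padding='CAUSAL', rngs=rngs)
>>> layer(x).shape
(1, 8, 4)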

input_dilation#

an integer or a sequence of n integers, giving the dilation factor to apply in each spatial dimension of inputs (default: 1). Convolution with input dilation d is equivalent to transposed convolution with stride d.
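
A brief shape sketch (reusing rngs and x from the example above): with input_dilation=2 the 8-sample input is dilated to an effective length of 15 before the 'VALID' convolution is applied:

>>> layer = nnx.Conv(in_features=3, out_features=4, kernel_size=(3,),
...                  padding='VALID', input_dilation=2, rngs=rngs)
>>> layer(x).shape
(1, 13, 4)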

kernel_dilation#

an integer or a sequence of n integers, giving the dilation factor to apply in each spatial dimension of the convolution kernel (default: 1). Convolution with kernel dilation is also known as ‘atrous convolution’.

feature_group_count#

integer, default 1. If specified, divides the input features into groups.
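
For example, a grouped (depthwise-style) convolution can be sketched as follows; the sizes here are illustrative, and the kernel's input-channel axis holds in_features // feature_group_count channels:

>>> layer = nnx.Conv(in_features=4, out_features=8, kernel_size=(3,),
...                  feature_group_count=4, padding='SAME', rngs=rngs)
>>> layer.kernel.value.shape
(3, 1, 8)
>>> layer(jnp.ones((1, 8, 4))).shape
(1, 8, 8)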

use_bias#

whether to add a bias to the output (default: True).

mask#

Optional mask for the weights during masked convolution. The mask must be the same shape as the convolution weight matrix.

dtype#

the dtype of the computation (default: infer from input and params).

param_dtype#

the dtype passed to parameter initializers (default: float32).

precision#

numerical precision of the computation; see jax.lax.Precision for details.

kernel_init#

initializer for the convolutional kernel.

bias_init#

initializer for the bias.

rngs#

rng key.

__call__(inputs)[source]#

Applies a (potentially unshared) convolution to the inputs.

Parameters

inputs – input data with dimensions (*batch_dims, spatial_dims..., features). This is the channels-last convention, i.e. NHWC for a 2D convolution and NDHWC for a 3D convolution. Note: this is different from the input convention used by lax.conv_general_dilated, which puts the spatial dimensions last. Note: if the input has more than 1 batch dimension, all batch dimensions are flattened into a single dimension for the convolution and restored before returning. In some cases directly vmap'ing the layer may yield better performance than this default flattening approach. If the input lacks a batch dimension, it will be added for the convolution and removed on return, an allowance made to enable writing single-example code.

Returns

The convolved data.
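
As a sketch of the batch-dimension handling described above (reusing rngs from the first example), the same 'VALID' 1D layer accepts inputs with no batch dimension or with several:

>>> layer = nnx.Conv(in_features=3, out_features=4, kernel_size=(3,),
...                  padding='VALID', rngs=rngs)
>>> layer(jnp.ones((8, 3))).shape        # no batch dimension
(6, 4)
>>> layer(jnp.ones((2, 5, 8, 3))).shape  # two batch dimensions
(2, 5, 6, 4)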

class flax.nnx.ConvTranspose(*args, **kwargs)[source]#

Convolution Module wrapping lax.conv_transpose.

Example usage:

>>> from flax import nnx
>>> import jax.numpy as jnp

>>> rngs = nnx.Rngs(0)
>>> x = jnp.ones((1, 8, 3))

>>> # valid padding
>>> layer = nnx.ConvTranspose(in_features=3, out_features=4, kernel_size=(3,),
...                           padding='VALID', rngs=rngs)
>>> layer.kernel.value.shape
(3, 3, 4)
>>> layer.bias.value.shape
(4,)
>>> out = layer(x)
>>> out.shape
(1, 10, 4)

>>> # circular padding with stride 2
>>> layer = nnx.ConvTranspose(in_features=3, out_features=4, kernel_size=(6, 6),
...                           strides=(2, 2), padding='CIRCULAR',
...                           transpose_kernel=True, rngs=rngs)
>>> layer.kernel.value.shape
(6, 6, 4, 3)
>>> layer.bias.value.shape
(4,)
>>> out = layer(jnp.ones((1, 15, 15, 3)))
>>> out.shape
(1, 30, 30, 4)

>>> # apply lower triangle mask
>>> mask = jnp.tril(jnp.ones((3, 3, 4)))
>>> layer = nnx.ConvTranspose(in_features=3, out_features=4, kernel_size=(3,),
...                           mask=mask, padding='VALID', rngs=rngs)
>>> out = layer(x)
in_features#

int or tuple with number of input features.

out_features#

int or tuple with number of output features.

kernel_size#

shape of the convolutional kernel. For 1D convolution, the kernel size can be passed as an integer, which will be interpreted as a tuple of the single integer. For all other cases, it must be a sequence of integers.

strides#

an integer or a sequence of n integers, representing the inter-window strides (default: 1).

padding#

either the string 'SAME', the string 'VALID', the string 'CIRCULAR' (periodic boundary conditions), or a sequence of n (low, high) integer pairs that give the padding to apply before and after each spatial dimension. A single int is interpreted as applying the same padding in all dims, and passing a single int in a sequence causes the same padding to be used on both sides of that dimension. 'CAUSAL' padding for a 1D convolution will left-pad the convolution axis, resulting in same-sized output.

kernel_dilation#

an integer or a sequence of n integers, giving the dilation factor to apply in each spatial dimension of the convolution kernel (default: 1). Convolution with kernel dilation is also known as ‘atrous convolution’.

use_bias#

whether to add a bias to the output (default: True).

mask#

Optional mask for the weights during masked convolution. The mask must be the same shape as the convolution weight matrix.

dtype#

the dtype of the computation (default: infer from input and params).

param_dtype#

the dtype passed to parameter initializers (default: float32).

precision#

numerical precision of the computation; see jax.lax.Precision for details.

kernel_init#

initializer for the convolutional kernel.

bias_init#

initializer for the bias.

transpose_kernel#

if True flips spatial axes and swaps the input/output channel axes of the kernel.

rngs#

rng key.

__call__(inputs)[source]#

Applies a transposed convolution to the inputs.

Behaviour mirrors that of jax.lax.conv_transpose.

Parameters

inputs – input data with dimensions (*batch_dims, spatial_dims..., features). This is the channels-last convention, i.e. NHWC for a 2D convolution and NDHWC for a 3D convolution. Note: this is different from the input convention used by lax.conv_general_dilated, which puts the spatial dimensions last. Note: if the input has more than 1 batch dimension, all batch dimensions are flattened into a single dimension for the convolution and restored before returning. In some cases directly vmap'ing the layer may yield better performance than this default flattening approach. If the input lacks a batch dimension, it will be added for the convolution and removed on return, an allowance made to enable writing single-example code.

Returns

The convolved data.

class flax.nnx.Embed(*args, **kwargs)[source]#

Embedding Module.

Example usage:

>>> from flax import nnx
>>> import jax.numpy as jnp

>>> layer = nnx.Embed(num_embeddings=5, features=3, rngs=nnx.Rngs(0))
>>> nnx.state(layer)
State({
  'embedding': VariableState(
    type=Param,
    value=Array([[-0.90411377, -0.3648777 , -1.1083648 ],
           [ 0.01070483,  0.27923733,  1.7487359 ],
           [ 0.59161806,  0.8660184 ,  1.2838588 ],
           [-0.748139  , -0.15856352,  0.06061118],
           [-0.4769059 , -0.6607095 ,  0.46697947]], dtype=float32)
  )
})
>>> # get the first three and last three embeddings
>>> indices_input = jnp.array([[0, 1, 2], [-1, -2, -3]])
>>> layer(indices_input)
Array([[[-0.90411377, -0.3648777 , -1.1083648 ],
        [ 0.01070483,  0.27923733,  1.7487359 ],
        [ 0.59161806,  0.8660184 ,  1.2838588 ]],

       [[-0.4769059 , -0.6607095 ,  0.46697947],
        [-0.748139  , -0.15856352,  0.06061118],
        [ 0.59161806,  0.8660184 ,  1.2838588 ]]], dtype=float32)

A parameterized function from integers [0, num_embeddings) to features-dimensional vectors. This Module will create an embedding matrix with shape (num_embeddings, features). When calling this layer, the input values will be used to 0-index into the embedding matrix. Indexing on a value greater than or equal to num_embeddings will result in nan values. When num_embeddings equals 1, it will broadcast the embedding matrix to the input shape with the features dimension appended.
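
A shape-only sketch of the num_embeddings == 1 broadcast case mentioned above (the layer and inputs here are illustrative, not from the original docstring):

>>> layer = nnx.Embed(num_embeddings=1, features=3, rngs=nnx.Rngs(0))
>>> layer(jnp.zeros((2, 4), dtype=jnp.int32)).shape
(2, 4, 3)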

num_embeddings#

number of embeddings / vocab size.

features#

number of feature dimensions for each embedding.

dtype#

the dtype of the embedding vectors (default: same as embedding).

param_dtype#

the dtype passed to parameter initializers (default: float32).

embedding_init#

embedding initializer.

rngs#

rng key.

__call__(inputs)[source]#

Embeds the inputs along the last dimension.

Parameters

inputs – input data, all dimensions are considered batch dimensions. Values in the input array must be integers.

Returns

Output which is embedded input data. The output shape follows the input, with an additional features dimension appended.

attend(query)[source]#

Attend over the embedding using a query array.

Parameters

query – array with last dimension equal to the feature depth features of the embedding.

Returns

An array with final dim num_embeddings corresponding to the batched inner-product of the array of query vectors against each embedding. Commonly used for weight-sharing between embeddings and logit transform in NLP models.
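
For instance, weight tying between the embedding table and the output logits can be sketched like this (assuming the layer from the example above and feature-depth-3 activations):

>>> layer = nnx.Embed(num_embeddings=5, features=3, rngs=nnx.Rngs(0))
>>> hidden = layer(jnp.array([[0, 1, 2]]))   # (1, 3, 3) activations
>>> logits = layer.attend(hidden)            # inner products against each embedding
>>> logits.shape
(1, 3, 5)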

Methods

attend(query)

Attend over the embedding using a query array.

class flax.nnx.Linear(*args, **kwargs)[source]#

A linear transformation applied over the last dimension of the input.

Example usage:

>>> from flax import nnx
>>> import jax, jax.numpy as jnp

>>> layer = nnx.Linear(in_features=3, out_features=4, rngs=nnx.Rngs(0))
>>> jax.tree.map(jnp.shape, nnx.state(layer))
State({
  'bias': VariableState(
    type=Param,
    value=(4,)
  ),
  'kernel': VariableState(
    type=Param,
    value=(3, 4)
  )
})
in_features#

the number of input features.

out_features#

the number of output features.

use_bias#

whether to add a bias to the output (default: True).

dtype#

the dtype of the computation (default: infer from input and params).

param_dtype#

the dtype passed to parameter initializers (default: float32).

precision#

numerical precision of the computation; see jax.lax.Precision for details.

kernel_init#

initializer function for the weight matrix.

bias_init#

initializer function for the bias.

dot_general#

dot product function.

rngs#

rng key.

__call__(inputs)[source]#

Applies a linear transformation to the inputs along the last dimension.

Parameters

inputs – The nd-array to be transformed.

Returns

The transformed input.
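
A minimal usage sketch for the layer constructed in the example above (in_features=3, out_features=4):

>>> y = layer(jnp.ones((1, 3)))
>>> y.shape
(1, 4)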

class flax.nnx.LinearGeneral(*args, **kwargs)[source]#

A linear transformation with flexible axes.

Example usage:

>>> from flax import nnx
>>> import jax, jax.numpy as jnp
...
>>> # equivalent to `nnx.Linear(2, 4)`
>>> layer = nnx.LinearGeneral(2, 4, rngs=nnx.Rngs(0))
>>> layer.kernel.value.shape
(2, 4)
>>> # output features (4, 5)
>>> layer = nnx.LinearGeneral(2, (4, 5), rngs=nnx.Rngs(0))
>>> layer.kernel.value.shape
(2, 4, 5)
>>> layer.bias.value.shape
(4, 5)
>>> # apply transformation on the second and last axes
>>> layer = nnx.LinearGeneral((2, 3), (4, 5), axis=(1, -1), rngs=nnx.Rngs(0))
>>> layer.kernel.value.shape
(2, 3, 4, 5)
>>> layer.bias.value.shape
(4, 5)
>>> y = layer(jnp.ones((16, 2, 3)))
>>> y.shape
(16, 4, 5)
in_features#

int or tuple with number of input features.

out_features#

int or tuple with number of output features.

axis#

int or tuple with axes to apply the transformation on. For instance, (-2, -1) will apply the transformation to the last two axes.

batch_axis#

mapping of batch axis indices to axis size.

use_bias#

whether to add a bias to the output (default: True).

dtype#

the dtype of the computation (default: infer from input and params).

param_dtype#

the dtype passed to parameter initializers (default: float32).

kernel_init#

initializer function for the weight matrix.

bias_init#

initializer function for the bias.

precision#

numerical precision of the computation; see jax.lax.Precision for details.

rngs#

rng key.

__call__(inputs)[source]#

Applies a linear transformation to the inputs along multiple dimensions.

Parameters

inputs – The nd-array to be transformed.

Returns

The transformed input.

class flax.nnx.Einsum(*args, **kwargs)[source]#

An einsum transformation with learnable kernel and bias.

Example usage:

>>> from flax import nnx
>>> import jax.numpy as jnp
...
>>> layer = nnx.Einsum('nta,hab->nthb', (8, 2, 4), (8, 4), rngs=nnx.Rngs(0))
>>> layer.kernel.value.shape
(8, 2, 4)
>>> layer.bias.value.shape
(8, 4)
>>> y = layer(jnp.ones((16, 11, 2)))
>>> y.shape
(16, 11, 8, 4)
einsum_str#

a string to denote the einsum equation. The equation must have exactly two operands, the lhs being the input passed in, and the rhs being the learnable kernel. Exactly one of the einsum_str constructor argument and the einsum_str __call__ argument must be not None, while the other must be None.
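
A sketch of the call-time form described above, assuming the constructor accepts None for einsum_str so the equation can instead be supplied in __call__ (shapes match the class example):

>>> layer = nnx.Einsum(None, (8, 2, 4), (8, 4), rngs=nnx.Rngs(0))
>>> y = layer(jnp.ones((16, 11, 2)), einsum_str='nta,hab->nthb')
>>> y.shape
(16, 11, 8, 4)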

kernel_shape#

the shape of the kernel.

bias_shape#

the shape of the bias. If this is None, a bias won’t be used.

dtype#

the dtype of the computation (default: infer from input and params).

param_dtype#

the dtype passed to parameter initializers (default: float32).

precision#

numerical precision of the computation; see jax.lax.Precision for details.

kernel_init#

initializer function for the weight matrix.

bias_init#

initializer function for the bias.

rngs#

rng key.

__call__(inputs, einsum_str=None)[source]#

Applies a linear transformation to the inputs along the last dimension.

Parameters
  • inputs – The nd-array to be transformed.

  • einsum_str – a string to denote the einsum equation. The equation must have exactly two operands, the lhs being the input passed in, and the rhs being the learnable kernel. Exactly one of the einsum_str constructor argument and this __call__ argument must be not None, while the other must be None.

Returns

The transformed input.
