Linear#
NNX linear layer classes.
- class flax.nnx.Conv(*args, **kwargs)[source]#
Convolution Module wrapping lax.conv_general_dilated.
Example usage:
>>> from flax import nnx
>>> import jax.numpy as jnp

>>> rngs = nnx.Rngs(0)
>>> x = jnp.ones((1, 8, 3))

>>> # valid padding
>>> layer = nnx.Conv(in_features=3, out_features=4, kernel_size=(3,),
...                  padding='VALID', rngs=rngs)
>>> layer.kernel.value.shape
(3, 3, 4)
>>> layer.bias.value.shape
(4,)
>>> out = layer(x)
>>> out.shape
(1, 6, 4)

>>> # circular padding with stride 2
>>> layer = nnx.Conv(in_features=3, out_features=4, kernel_size=(3, 3),
...                  strides=2, padding='CIRCULAR', rngs=rngs)
>>> layer.kernel.value.shape
(3, 3, 3, 4)
>>> layer.bias.value.shape
(4,)
>>> out = layer(x)
>>> out.shape
(1, 4, 4)

>>> # apply lower triangle mask
>>> mask = jnp.tril(jnp.ones((3, 3, 4)))
>>> layer = nnx.Conv(in_features=3, out_features=4, kernel_size=(3,),
...                  mask=mask, padding='VALID', rngs=rngs)
>>> out = layer(x)
- in_features#
int or tuple with number of input features.
- out_features#
int or tuple with number of output features.
- kernel_size#
shape of the convolutional kernel. For 1D convolution, the kernel size can be passed as an integer, which will be interpreted as a tuple of the single integer. For all other cases, it must be a sequence of integers.
- strides#
an integer or a sequence of n integers, representing the inter-window strides (default: 1).
- padding#
either the string 'SAME', the string 'VALID', the string 'CIRCULAR' (periodic boundary conditions), or a sequence of n (low, high) integer pairs that give the padding to apply before and after each spatial dimension. A single int is interpreted as applying the same padding in all dims, and passing a single int in a sequence causes the same padding to be used on both sides. 'CAUSAL' padding for a 1D convolution will left-pad the convolution axis, resulting in a same-sized output.
- input_dilation#
an integer or a sequence of n integers, giving the dilation factor to apply in each spatial dimension of inputs (default: 1). Convolution with input dilation d is equivalent to transposed convolution with stride d.
- kernel_dilation#
an integer or a sequence of n integers, giving the dilation factor to apply in each spatial dimension of the convolution kernel (default: 1). Convolution with kernel dilation is also known as 'atrous convolution'.
- feature_group_count#
integer, default 1. If specified, divides the input features into groups.
- use_bias#
whether to add a bias to the output (default: True).
- mask#
Optional mask for the weights during masked convolution. The mask must be the same shape as the convolution weight matrix.
- dtype#
the dtype of the computation (default: infer from input and params).
- param_dtype#
the dtype passed to parameter initializers (default: float32).
- precision#
numerical precision of the computation; see jax.lax.Precision for details.
- kernel_init#
initializer for the convolutional kernel.
- bias_init#
initializer for the bias.
- rngs#
rng key.
- __call__(inputs)[source]#
Applies a (potentially unshared) convolution to the inputs.
- Parameters
inputs – input data with dimensions (*batch_dims, spatial_dims..., features). This is the channels-last convention, i.e. NHWC for a 2D convolution and NDHWC for a 3D convolution. Note: this is different from the input convention used by lax.conv_general_dilated, which puts the spatial dimensions last. Note: if the input has more than 1 batch dimension, all batch dimensions are flattened into a single dimension for the convolution and restored before returning. In some cases directly vmap'ing the layer may yield better performance than this default flattening approach. If the input lacks a batch dimension, it will be added for the convolution and removed on return, an allowance made to enable writing single-example code.
- Returns
The convolved data.
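The batch-handling notes above can be illustrated directly (a minimal sketch, assuming 'SAME' padding so the spatial size is preserved; the shapes follow from the documented flattening and single-example behaviour):
>>> from flax import nnx
>>> import jax.numpy as jnp

>>> layer = nnx.Conv(in_features=3, out_features=4, kernel_size=(3,),
...                  padding='SAME', rngs=nnx.Rngs(0))
>>> # two leading batch dimensions are flattened for the convolution,
>>> # then restored on return
>>> layer(jnp.ones((2, 5, 8, 3))).shape
(2, 5, 8, 4)
>>> # a single example without a batch dimension also works
>>> layer(jnp.ones((8, 3))).shape
(8, 4)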
Methods
- class flax.nnx.ConvTranspose(*args, **kwargs)[source]#
Convolution Module wrapping lax.conv_transpose.
Example usage:
>>> from flax import nnx
>>> import jax.numpy as jnp

>>> rngs = nnx.Rngs(0)
>>> x = jnp.ones((1, 8, 3))

>>> # valid padding
>>> layer = nnx.ConvTranspose(in_features=3, out_features=4, kernel_size=(3,),
...                           padding='VALID', rngs=rngs)
>>> layer.kernel.value.shape
(3, 3, 4)
>>> layer.bias.value.shape
(4,)
>>> out = layer(x)
>>> out.shape
(1, 10, 4)

>>> # circular padding with stride 2
>>> layer = nnx.ConvTranspose(in_features=3, out_features=4, kernel_size=(6, 6),
...                           strides=(2, 2), padding='CIRCULAR',
...                           transpose_kernel=True, rngs=rngs)
>>> layer.kernel.value.shape
(6, 6, 4, 3)
>>> layer.bias.value.shape
(4,)
>>> out = layer(jnp.ones((1, 15, 15, 3)))
>>> out.shape
(1, 30, 30, 4)

>>> # apply lower triangle mask
>>> mask = jnp.tril(jnp.ones((3, 3, 4)))
>>> layer = nnx.ConvTranspose(in_features=3, out_features=4, kernel_size=(3,),
...                           mask=mask, padding='VALID', rngs=rngs)
>>> out = layer(x)
- in_features#
int or tuple with number of input features.
- out_features#
int or tuple with number of output features.
- kernel_size#
shape of the convolutional kernel. For 1D convolution, the kernel size can be passed as an integer, which will be interpreted as a tuple of the single integer. For all other cases, it must be a sequence of integers.
- strides#
an integer or a sequence of n integers, representing the inter-window strides (default: 1).
- padding#
either the string 'SAME', the string 'VALID', the string 'CIRCULAR' (periodic boundary conditions), or a sequence of n (low, high) integer pairs that give the padding to apply before and after each spatial dimension. A single int is interpreted as applying the same padding in all dims, and passing a single int in a sequence causes the same padding to be used on both sides. 'CAUSAL' padding for a 1D convolution will left-pad the convolution axis, resulting in a same-sized output.
- kernel_dilation#
an integer or a sequence of n integers, giving the dilation factor to apply in each spatial dimension of the convolution kernel (default: 1). Convolution with kernel dilation is also known as 'atrous convolution'.
- use_bias#
whether to add a bias to the output (default: True).
- mask#
Optional mask for the weights during masked convolution. The mask must be the same shape as the convolution weight matrix.
- dtype#
the dtype of the computation (default: infer from input and params).
- param_dtype#
the dtype passed to parameter initializers (default: float32).
- precision#
numerical precision of the computation; see jax.lax.Precision for details.
- kernel_init#
initializer for the convolutional kernel.
- bias_init#
initializer for the bias.
- transpose_kernel#
if True, flips spatial axes and swaps the input/output channel axes of the kernel.
- rngs#
rng key.
- __call__(inputs)[source]#
Applies a transposed convolution to the inputs.
Behaviour mirrors that of jax.lax.conv_transpose.
- Parameters
inputs – input data with dimensions (*batch_dims, spatial_dims..., features). This is the channels-last convention, i.e. NHWC for a 2D convolution and NDHWC for a 3D convolution. Note: this is different from the input convention used by lax.conv_general_dilated, which puts the spatial dimensions last. Note: if the input has more than 1 batch dimension, all batch dimensions are flattened into a single dimension for the convolution and restored before returning. In some cases directly vmap'ing the layer may yield better performance than this default flattening approach. If the input lacks a batch dimension, it will be added for the convolution and removed on return, an allowance made to enable writing single-example code.
- Returns
The convolved data.
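As a quick shape check of the transposed convolution (a minimal sketch; with 'SAME' padding a stride of 2 scales the spatial dimension by 2, a common upsampling pattern):
>>> from flax import nnx
>>> import jax.numpy as jnp

>>> # stride-2 transposed convolution used for 2x upsampling
>>> layer = nnx.ConvTranspose(in_features=3, out_features=4, kernel_size=(3,),
...                           strides=(2,), padding='SAME', rngs=nnx.Rngs(0))
>>> layer(jnp.ones((1, 8, 3))).shape
(1, 16, 4)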
Methods
- class flax.nnx.Embed(*args, **kwargs)[source]#
Embedding Module.
Example usage:
>>> from flax import nnx
>>> import jax.numpy as jnp

>>> layer = nnx.Embed(num_embeddings=5, features=3, rngs=nnx.Rngs(0))
>>> nnx.state(layer)
State({
  'embedding': VariableState(
    type=Param,
    value=Array([[-0.90411377, -0.3648777 , -1.1083648 ],
           [ 0.01070483,  0.27923733,  1.7487359 ],
           [ 0.59161806,  0.8660184 ,  1.2838588 ],
           [-0.748139  , -0.15856352,  0.06061118],
           [-0.4769059 , -0.6607095 ,  0.46697947]], dtype=float32)
  )
})
>>> # get the first three and last three embeddings
>>> indices_input = jnp.array([[0, 1, 2], [-1, -2, -3]])
>>> layer(indices_input)
Array([[[-0.90411377, -0.3648777 , -1.1083648 ],
        [ 0.01070483,  0.27923733,  1.7487359 ],
        [ 0.59161806,  0.8660184 ,  1.2838588 ]],
       [[-0.4769059 , -0.6607095 ,  0.46697947],
        [-0.748139  , -0.15856352,  0.06061118],
        [ 0.59161806,  0.8660184 ,  1.2838588 ]]], dtype=float32)
A parameterized function from integers [0, num_embeddings) to features-dimensional vectors. This Module will create an embedding matrix with shape (num_embeddings, features). When calling this layer, the input values will be used to 0-index into the embedding matrix. Indexing on a value greater than or equal to num_embeddings will result in nan values. When num_embeddings equals 1, it will broadcast the embedding matrix to the input shape with the features dimension appended.
- num_embeddings#
number of embeddings / vocab size.
- features#
number of feature dimensions for each embedding.
- dtype#
the dtype of the embedding vectors (default: same as embedding).
- param_dtype#
the dtype passed to parameter initializers (default: float32).
- embedding_init#
embedding initializer.
- rngs#
rng key.
- __call__(inputs)[source]#
Embeds the inputs along the last dimension.
- Parameters
inputs – input data, all dimensions are considered batch dimensions. Values in the input array must be integers.
- Returns
Output which is embedded input data. The output shape follows the input, with an additional features dimension appended.
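The out-of-range behaviour described above can be verified directly (a minimal sketch; per the note, an index greater than or equal to num_embeddings yields nan values):
>>> from flax import nnx
>>> import jax.numpy as jnp

>>> layer = nnx.Embed(num_embeddings=5, features=3, rngs=nnx.Rngs(0))
>>> out = layer(jnp.array([5]))  # index 5 is out of range for num_embeddings=5
>>> bool(jnp.isnan(out).all())
True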
- attend(query)[source]#
Attend over the embedding using a query array.
- Parameters
query – array with last dimension equal to the feature depth features of the embedding.
- Returns
An array with final dim num_embeddings corresponding to the batched inner-product of the array of query vectors against each embedding. Commonly used for weight-sharing between embeddings and logit transform in NLP models.
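For example, attend supports the weight-sharing pattern mentioned above, producing one logit per embedding (a minimal sketch; the shapes follow from the inner product against each of the num_embeddings vectors):
>>> from flax import nnx
>>> import jax.numpy as jnp

>>> layer = nnx.Embed(num_embeddings=5, features=3, rngs=nnx.Rngs(0))
>>> query = jnp.ones((2, 3))  # batch of two query vectors, feature depth 3
>>> layer.attend(query).shape
(2, 5)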
Methods
attend(query) – Attend over the embedding using a query array.
- class flax.nnx.Linear(*args, **kwargs)[source]#
A linear transformation applied over the last dimension of the input.
Example usage:
>>> from flax import nnx
>>> import jax, jax.numpy as jnp

>>> layer = nnx.Linear(in_features=3, out_features=4, rngs=nnx.Rngs(0))
>>> jax.tree.map(jnp.shape, nnx.state(layer))
State({
  'bias': VariableState(
    type=Param,
    value=(4,)
  ),
  'kernel': VariableState(
    type=Param,
    value=(3, 4)
  )
})
- in_features#
the number of input features.
- out_features#
the number of output features.
- use_bias#
whether to add a bias to the output (default: True).
- dtype#
the dtype of the computation (default: infer from input and params).
- param_dtype#
the dtype passed to parameter initializers (default: float32).
- precision#
numerical precision of the computation; see jax.lax.Precision for details.
- kernel_init#
initializer function for the weight matrix.
- bias_init#
initializer function for the bias.
- dot_general#
dot product function.
- rngs#
rng key.
- __call__(inputs)[source]#
Applies a linear transformation to the inputs along the last dimension.
- Parameters
inputs – The nd-array to be transformed.
- Returns
The transformed input.
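A call example (a minimal sketch; any leading dimensions are treated as batch dimensions and only the last dimension is transformed):
>>> from flax import nnx
>>> import jax.numpy as jnp

>>> layer = nnx.Linear(in_features=3, out_features=4, rngs=nnx.Rngs(0))
>>> # the last dimension (in_features=3) is mapped to out_features=4
>>> layer(jnp.ones((16, 3))).shape
(16, 4)
>>> layer(jnp.ones((2, 16, 3))).shape
(2, 16, 4)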
Methods
- class flax.nnx.LinearGeneral(*args, **kwargs)[source]#
A linear transformation with flexible axes.
Example usage:
>>> from flax import nnx
>>> import jax, jax.numpy as jnp

>>> # equivalent to `nnx.Linear(2, 4)`
>>> layer = nnx.LinearGeneral(2, 4, rngs=nnx.Rngs(0))
>>> layer.kernel.value.shape
(2, 4)
>>> # output features (4, 5)
>>> layer = nnx.LinearGeneral(2, (4, 5), rngs=nnx.Rngs(0))
>>> layer.kernel.value.shape
(2, 4, 5)
>>> layer.bias.value.shape
(4, 5)
>>> # apply transformation on the second and last axes
>>> layer = nnx.LinearGeneral((2, 3), (4, 5), axis=(1, -1), rngs=nnx.Rngs(0))
>>> layer.kernel.value.shape
(2, 3, 4, 5)
>>> layer.bias.value.shape
(4, 5)
>>> y = layer(jnp.ones((16, 2, 3)))
>>> y.shape
(16, 4, 5)
- in_features#
int or tuple with number of input features.
- out_features#
int or tuple with number of output features.
- axis#
int or tuple with axes to apply the transformation on. For instance, (-2, -1) will apply the transformation to the last two axes.
- batch_axis#
mapping of batch axis indices to axis size.
- use_bias#
whether to add a bias to the output (default: True).
- dtype#
the dtype of the computation (default: infer from input and params).
- param_dtype#
the dtype passed to parameter initializers (default: float32).
- kernel_init#
initializer function for the weight matrix.
- bias_init#
initializer function for the bias.
- precision#
numerical precision of the computation; see jax.lax.Precision for details.
- rngs#
rng key.
- __call__(inputs)[source]#
Applies a linear transformation to the inputs along multiple dimensions.
- Parameters
inputs – The nd-array to be transformed.
- Returns
The transformed input.
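The multi-axis contraction can be read as an einsum (a minimal sketch using the axis=(1, -1) layer from the example above; the einsum equation here is an illustration of the contraction, not the layer's internal implementation):
>>> from flax import nnx
>>> import jax.numpy as jnp

>>> layer = nnx.LinearGeneral((2, 3), (4, 5), axis=(1, -1), rngs=nnx.Rngs(0))
>>> x = jnp.ones((16, 2, 3))
>>> # contract input axes 1 and 2 against the kernel's leading axes
>>> y_ref = jnp.einsum('nij,ijkl->nkl', x, layer.kernel.value) + layer.bias.value
>>> y_ref.shape
(16, 4, 5)
>>> bool(jnp.allclose(layer(x), y_ref))
True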
Methods
- class flax.nnx.Einsum(*args, **kwargs)[source]#
An einsum transformation with learnable kernel and bias.
Example usage:
>>> from flax import nnx
>>> import jax.numpy as jnp

>>> layer = nnx.Einsum('nta,hab->nthb', (8, 2, 4), (8, 4), rngs=nnx.Rngs(0))
>>> layer.kernel.value.shape
(8, 2, 4)
>>> layer.bias.value.shape
(8, 4)
>>> y = layer(jnp.ones((16, 11, 2)))
>>> y.shape
(16, 11, 8, 4)
- einsum_str#
a string to denote the einsum equation. The equation must have exactly two operands: the lhs is the input passed in, and the rhs is the learnable kernel. The equation must be given either here or as the call argument einsum_str, but not both; exactly one of the two must be not None, while the other must be None.
- kernel_shape#
the shape of the kernel.
- bias_shape#
the shape of the bias. If this is None, a bias won’t be used.
- dtype#
the dtype of the computation (default: infer from input and params).
- param_dtype#
the dtype passed to parameter initializers (default: float32).
- precision#
numerical precision of the computation; see jax.lax.Precision for details.
- kernel_init#
initializer function for the weight matrix.
- bias_init#
initializer function for the bias.
- rngs#
rng key.
- __call__(inputs, einsum_str=None)[source]#
Applies a linear transformation to the inputs according to the einsum equation.
- Parameters
inputs – The nd-array to be transformed.
einsum_str – a string to denote the einsum equation. The equation must have exactly two operands: the lhs is the input passed in, and the rhs is the learnable kernel. The equation must be given either here or as the constructor argument, but not both; exactly one of the two must be not None, while the other must be None.
- Returns
The transformed input.
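Per the constraint above, the equation may instead be supplied at call time (a minimal sketch, assuming the constructor accepts None for einsum_str when the equation is provided in the call):
>>> from flax import nnx
>>> import jax.numpy as jnp

>>> # no equation at construction; kernel and bias shapes are fixed up front
>>> layer = nnx.Einsum(None, (8, 2, 4), (8, 4), rngs=nnx.Rngs(0))
>>> layer(jnp.ones((16, 11, 2)), einsum_str='nta,hab->nthb').shape
(16, 11, 8, 4)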
Methods