flax.linen.Embed

class flax.linen.Embed(num_embeddings, features, dtype=None, param_dtype=<class 'jax.numpy.float32'>, embedding_init=<function variance_scaling.<locals>.init>, parent=<flax.linen.module._Sentinel object>, name=None)

Embedding Module.

A parameterized function from integers [0, num_embeddings) to features-dimensional vectors. This Module creates an embedding matrix of shape (num_embeddings, features). When the layer is called, the input values are used as 0-based row indices into the embedding matrix; negative values count back from the end of the matrix, as the example below shows (-1 selects the last row). Indexing with a value greater than or equal to num_embeddings results in NaN values. When num_embeddings equals 1, the single embedding vector is broadcast to the input shape, with a features dimension appended.
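The indexing rules above can be modeled with a plain row lookup in jax.numpy. This sketch is an assumption made for illustration, not necessarily how Flax implements Embed: it uses jnp.take with mode="fill", which wraps negative indices and fills out-of-range rows of a floating-point array with NaN. The helper name lookup and the toy table are hypothetical.

```python
import jax.numpy as jnp

# Hypothetical toy table: num_embeddings=4 rows, features=2 columns.
table = jnp.arange(8.0).reshape(4, 2)

def lookup(table, indices):
    # Row lookup along axis 0. With mode="fill", negative indices are wrapped
    # from the end and out-of-range rows of a float table come back as NaN.
    return jnp.take(table, indices, axis=0, mode="fill")

idx = jnp.array([0, 3, -1, 4])  # 4 == num_embeddings, so it is out of range
out = lookup(table, idx)
# out[0] is row 0, out[1] and out[2] are both row 3, out[3] is all NaN.
```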

Example usage:

>>> import flax.linen as nn
>>> import jax, jax.numpy as jnp

>>> layer = nn.Embed(num_embeddings=5, features=3)
>>> indices_input = jnp.array([[0, 1, 2], [-1, -2, -3]])
>>> variables = layer.init(jax.random.key(0), indices_input)
>>> variables
{'params': {'embedding': Array([[-0.28884724,  0.19018005, -0.414205  ],
       [-0.11768015, -0.54618824, -0.3789283 ],
       [ 0.30428642,  0.49511626,  0.01706631],
       [-0.0982546 , -0.43055868,  0.20654906],
       [-0.688412  , -0.46882293,  0.26723292]], dtype=float32)}}
>>> # get the first three embeddings and the last three, in reverse order
>>> layer.apply(variables, indices_input)
Array([[[-0.28884724,  0.19018005, -0.414205  ],
        [-0.11768015, -0.54618824, -0.3789283 ],
        [ 0.30428642,  0.49511626,  0.01706631]],

       [[-0.688412  , -0.46882293,  0.26723292],
        [-0.0982546 , -0.43055868,  0.20654906],
        [ 0.30428642,  0.49511626,  0.01706631]]], dtype=float32)
num_embeddings

number of embeddings / vocab size.

Type

int

features

number of feature dimensions for each embedding.

Type

int

dtype

the dtype of the computation and of the returned embedding vectors (default: same as the embedding parameter).

Type

Optional[Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, Any]]

param_dtype

the dtype passed to parameter initializers (default: float32).

Type

Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, Any]

embedding_init

embedding initializer.

Type

Union[jax.nn.initializers.Initializer, Callable[[…], Any]]

__call__(inputs)

Embeds the inputs, placing each embedding vector along a new last dimension.

Parameters

inputs – input data; all dimensions are treated as batch dimensions. Values in the input array must be integers.

Returns

The embedded input data. The output shape follows the input shape, with an additional features dimension appended, i.e. inputs.shape + (features,).

attend(query)

Attend over the embedding using a query array.

Parameters

query – array whose last dimension equals the feature depth (features) of the embedding.

Returns

An array whose final dimension is num_embeddings, holding the batched inner product of each query vector with every embedding. Commonly used to share weights between the input embeddings and the output logit transform in NLP models.
