flax.linen.Embed

class flax.linen.Embed(num_embeddings, features, dtype=None, param_dtype=<class 'jax.numpy.float32'>, embedding_init=<function variance_scaling.<locals>.init>, parent=<flax.linen.module._Sentinel object>, name=None)

Embedding Module.

A parameterized function mapping integers in [0, num_embeddings) to vectors of size features.

num_embeddings

number of embeddings.

Type

int

features

number of feature dimensions for each embedding.

Type

int

dtype

the dtype of the computation and of the returned embedding vectors (default: the dtype of the embedding parameter, i.e. param_dtype).

Type

Optional[Any]

param_dtype

the dtype passed to parameter initializers (default: float32).

Type

Any

embedding_init

initializer for the embedding table, called with a PRNG key, the shape (num_embeddings, features), and param_dtype.

Type

Callable[[Any, Tuple[int, ...], Any], Any]

__call__(inputs)

Embeds the inputs, appending the embedding dimension as a new last axis.

Parameters

inputs – integer array of indices in [0, num_embeddings); all of its dimensions are treated as batch dimensions.

Returns

The embedded input data: an array with the same shape as inputs plus one trailing dimension of size features.

Methods

attend(query)

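Attend over the embedding using a query array: computes the inner product of each query vector (whose last dimension must equal features) with every embedding, giving a final dimension of size num_embeddings. This is commonly used to share weights between an input embedding and an output logit projection.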

setup()

Initializes a Module lazily (similar to a lazy __init__).