flax.linen.Embed

class flax.linen.Embed(num_embeddings, features, dtype=None, param_dtype=<class 'jax.numpy.float32'>, embedding_init=<function variance_scaling.<locals>.init>, parent=<flax.linen.module._Sentinel object>, name=None)

Embedding Module.

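A parameterized function from integers in [0, num_embeddings) to vectors with features dimensions. The module holds a learned table of shape (num_embeddings, features) and maps each index to the corresponding row.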

Parameters
  • num_embeddings (int)
  • features (int)
  • dtype (Optional[Any])
  • param_dtype (Any)
  • embedding_init (Callable[[Any, Tuple[int, ...], Any], Any])
  • parent (Optional[Union[Type[flax.linen.module.Module], Type[flax.core.scope.Scope], Type[flax.linen.module._Sentinel]]])
  • name (str)


num_embeddings

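Number of embeddings, i.e. the number of rows in the lookup table. Valid input indices are integers in [0, num_embeddings).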

Type

int

features

Number of feature dimensions of each embedding vector.

Type

int

dtype

The dtype of the returned embedding vectors (default: None, meaning the same dtype as the embedding table).

Type

Optional[Any]

param_dtype

The dtype passed to parameter initializers, i.e. the dtype in which the embedding table is stored (default: float32).

Type

Any

embedding_init

Initializer for the embedding table of shape (num_embeddings, features) (default: a variance-scaling initializer).

Type

Callable[[Any, Tuple[int, ...], Any], Any]

__call__(inputs)

Embeds the inputs by looking up each integer index in the embedding table; the resulting vectors occupy a new last dimension.

Parameters

inputs (Any) – integer array of indices in [0, num_embeddings). All of its dimensions are treated as batch dimensions.

Returns

The embedded input data, with shape inputs.shape + (features,): the input shape followed by an additional features dimension.

Return type

Any

Methods

attend(query)

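Attend over the embedding using a query array. The last dimension of query must equal features; the result is the inner product of each query vector with every embedding, so its final dimension has size num_embeddings. This is commonly used to share weights between the input embedding and the output logit projection.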

embedding_init(shape[, dtype])

setup()

Initializes a Module lazily (similar to a lazy __init__).