flax.linen.Embed
- class flax.linen.Embed(num_embeddings, features, dtype=None, param_dtype=<class 'jax.numpy.float32'>, embedding_init=<function variance_scaling.<locals>.init>, parent=<flax.linen.module._Sentinel object>, name=None)
Embedding Module.
A parameterized function that maps integers in [0, num_embeddings) to vectors with features dimensions.
- num_embeddings
number of embeddings.
- Type
int
- features
number of feature dimensions for each embedding.
- Type
int
- dtype
the dtype of the embedding vectors returned by the module (default: the dtype of the embedding parameter).
- Type
Optional[Any]
- param_dtype
the dtype passed to parameter initializers (default: float32).
- Type
Any
- embedding_init
initializer for the embedding table; it is called with an RNG key, the shape (num_embeddings, features), and param_dtype.
- Type
Callable[[Any, Tuple[int, ...], Any], Any]
- __call__(inputs)
Embeds the inputs along the last dimension.
- Parameters
inputs – integer input data with values in [0, num_embeddings); all dimensions are considered batch dimensions.
- Returns
The embedded input data. The output has the shape of inputs with an additional trailing dimension of size features appended.
Methods
- attend(query)
Attend over the embedding using a query array.
- setup()
Initializes a Module lazily (similar to a lazy __init__).