flax.linen.Embed
- class flax.linen.Embed(num_embeddings, features, dtype=None, param_dtype=<class 'jax.numpy.float32'>, embedding_init=<function variance_scaling.<locals>.init>, parent=<flax.linen.module._Sentinel object>, name=None)
Embedding Module.
A parameterized function from integers in [0, n) to d-dimensional vectors, where n is num_embeddings and d is features.
- Parameters
num_embeddings (int)
features (int)
dtype (Optional[Any])
param_dtype (Any)
embedding_init (Callable[[Any, Tuple[int, ...], Any], Any])
parent (Optional[Union[Type[flax.linen.module.Module], Type[flax.core.scope.Scope], Type[flax.linen.module._Sentinel]]])
name (str)
- Return type
None
- num_embeddings¶
number of embeddings, i.e. the number of rows in the embedding table; valid inputs are integers in [0, num_embeddings).
- Type
int
- features¶
number of feature dimensions for each embedding.
- Type
int
- dtype¶
the dtype of the embedding vectors returned by __call__ (default: None, which keeps the dtype of the embedding table).
- Type
Optional[Any]
- param_dtype¶
the dtype passed to parameter initializers (default: float32).
- Type
Any
- embedding_init¶
embedding initializer.
- Type
Callable[[Any, Tuple[int, …], Any], Any]
- __call__(inputs)
Embeds the inputs along the last dimension.
- Parameters
inputs (Any) – integer input data; all dimensions are treated as batch dimensions.
- Returns
The embedded input data. The output shape is the input shape with an additional features dimension appended, i.e. inputs.shape + (features,).
- Return type
Any
Methods
attend(query) – Attend over the embedding using a query array.
embedding_init(shape[, dtype])
setup() – Initializes a Module lazily (similar to a lazy __init__).