flax.linen.Embed
- class flax.linen.Embed(num_embeddings, features, dtype=None, param_dtype=<class 'jax.numpy.float32'>, embedding_init=<function variance_scaling.<locals>.init>, parent=<flax.linen.module._Sentinel object>, name=None)
Embedding Module.
A parameterized function from integers in [0, `num_embeddings`) to `features`-dimensional vectors. This `Module` creates an `embedding` matrix with shape `(num_embeddings, features)`. When the layer is called, the input values are used as 0-based row indices into the `embedding` matrix. Indexing with a value greater than or equal to `num_embeddings` results in `nan` values. When `num_embeddings` equals 1, the `embedding` matrix is broadcast to the input shape with a `features` dimension appended.

Example usage:
```python
>>> import flax.linen as nn
>>> import jax, jax.numpy as jnp
>>> layer = nn.Embed(num_embeddings=5, features=3)
>>> indices_input = jnp.array([[0, 1, 2], [-1, -2, -3]])
>>> variables = layer.init(jax.random.key(0), indices_input)
>>> variables
{'params': {'embedding': Array([[-0.28884724,  0.19018005, -0.414205  ],
       [-0.11768015, -0.54618824, -0.3789283 ],
       [ 0.30428642,  0.49511626,  0.01706631],
       [-0.0982546 , -0.43055868,  0.20654906],
       [-0.688412  , -0.46882293,  0.26723292]], dtype=float32)}}
>>> # get rows 0, 1, 2, then rows 4, 3, 2 via negative indices
>>> layer.apply(variables, indices_input)
Array([[[-0.28884724,  0.19018005, -0.414205  ],
        [-0.11768015, -0.54618824, -0.3789283 ],
        [ 0.30428642,  0.49511626,  0.01706631]],

       [[-0.688412  , -0.46882293,  0.26723292],
        [-0.0982546 , -0.43055868,  0.20654906],
        [ 0.30428642,  0.49511626,  0.01706631]]], dtype=float32)
```
- num_embeddings
Number of embeddings (vocabulary size).
- Type
int
- features
Number of feature dimensions for each embedding.
- Type
int
- dtype
The dtype of the returned embedding vectors (default: same as the `embedding` parameter).
- Type
Optional[Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, Any]]
- param_dtype
The dtype passed to parameter initializers (default: float32).
- Type
Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, Any]
- embedding_init
Initializer for the `embedding` matrix (default: a variance-scaling initializer).
- Type
Union[jax.nn.initializers.Initializer, Callable[[…], Any]]
- __call__(inputs)
Embeds the inputs along the last dimension.
- Parameters
inputs – input data; all dimensions are treated as batch dimensions. Values in the input array must be integers.
- Returns
The embedded input data. The output shape follows the input, with an additional `features` dimension appended.
- attend(query)
Attend over the embedding using a query array.
- Parameters
query – array whose last dimension equals the feature depth (`features`) of the embedding.
- Returns
An array with final dimension `num_embeddings`, holding the batched inner product of the query vectors with each embedding. Commonly used for weight sharing between the input embedding and the output logit projection in NLP models.