class flax.nn.Embed(inputs, num_embeddings, features, embedding_init=<function variance_scaling.<locals>.init>)

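Embeds the inputs along the last dimension: each integer in inputs selects one row of a learned table of shape (num_embeddings, features), and the selected vectors form a new trailing axis.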

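Parameters:

  • inputs – integer indices to embed; all dimensions are treated as batch dimensions.

  • num_embeddings – number of embeddings, i.e. the number of rows in the table; valid indices run from 0 to num_embeddings - 1.

  • features – number of feature dimensions for each embedding.

  • embedding_init – initializer for the embedding table (default: a variance-scaling initializer).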

Returns:

The embedded input data. Its shape follows the input, with an additional features dimension appended; for example, inputs of shape (2, 3) with features=4 produce an output of shape (2, 3, 4).
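The lookup described above can be sketched in plain NumPy, used here as a stand-in for jax.numpy. The helper embed_lookup, the integer-dtype check, and the unit-normal table are assumptions for illustration only; in the module the table is a learned parameter created by embedding_init, and this is not Flax's source.

```python
import numpy as np


def embed_lookup(inputs, table):
    """Hypothetical helper mirroring Embed's lookup on a given table.

    table has shape (num_embeddings, features); the result has shape
    inputs.shape + (features,).
    """
    inputs = np.asarray(inputs)
    # Assumption for this sketch: inputs must be integer indices.
    if not np.issubdtype(inputs.dtype, np.integer):
        raise ValueError("inputs must be integer indices")
    # Every axis of inputs is a batch axis; integer-array indexing selects
    # one row per element and appends that row's features axis.
    return table[inputs]


num_embeddings, features = 10, 4  # illustrative sizes
# Unit-normal table: an arbitrary stand-in for embedding_init.
table = np.random.default_rng(0).normal(size=(num_embeddings, features))

tokens = np.array([[1, 2, 3], [4, 5, 9]])  # shape (2, 3)
out = embed_lookup(tokens, table)
print(out.shape)  # (2, 3, 4)
```

Because each element of tokens selects one row, a (2, 3) batch of indices becomes a (2, 3, 4) array, matching the rule that the output shape follows the input with a features dimension appended.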

Methods:

apply(inputs, num_embeddings, features[, …])

Embeds the inputs along the last dimension.

call(inputs, num_embeddings, features[, …])

Evaluates the module with the given parameters.

create(inputs, num_embeddings, features[, …])

Creates a module instance by evaluating the model.

create_by_shape(input_specs, inputs, …[, …])

Creates a module instance using only shape and dtype information.

Retrieves a parameter within the module’s apply function.

init(inputs, num_embeddings, features[, …])

Initializes the module parameters.

init_by_shape(input_specs, inputs, …[, …])

Initializes the module parameters using only shape and dtype information.

param(name, shape, initializer)

Defines a parameter within the module’s apply function.

partial(num_embeddings, features[, …])

Partially applies a module with the given arguments.

shared(*[, name])

Partially applies a module and shares its parameters across every call.

state(name[, shape, initializer, collection])

Declares a state variable within the module’s apply function.
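Several of the methods listed here (init, apply, param) revolve around one pattern: during initialization a parameter such as Embed's table is created from its initializer, and afterwards the stored value is looked up by name instead. The following is a minimal NumPy sketch of that define-or-retrieve pattern, assuming a hypothetical ParamScope class, a hypothetical normal_init with an (rng, shape) signature, a hypothetical embed_apply, and the parameter name "embedding". None of these are Flax's API, and the sketch omits call, the shape-only variants, and state.

```python
import numpy as np


class ParamScope:
    """Hypothetical stand-in for a module's parameter scope (not Flax's API).

    With no existing params it is in "init" mode and creates each parameter
    from its initializer; otherwise it looks the parameter up by name.
    """

    def __init__(self, params=None, seed=0):
        self.initializing = params is None
        self.params = {} if params is None else dict(params)
        self.rng = np.random.default_rng(seed)

    def param(self, name, shape, initializer):
        if self.initializing and name not in self.params:
            # Assumed initializer signature for this sketch: (rng, shape).
            self.params[name] = initializer(self.rng, shape)
        value = self.params[name]
        if value.shape != tuple(shape):
            raise ValueError(f"{name!r} has shape {value.shape}, expected {tuple(shape)}")
        return value


def normal_init(rng, shape):
    # Arbitrary stand-in initializer, not Flax's default variance_scaling.
    return rng.normal(size=shape)


def embed_apply(scope, inputs, num_embeddings, features):
    # The table name "embedding" is an assumption for this sketch.
    table = scope.param("embedding", (num_embeddings, features), normal_init)
    return table[np.asarray(inputs)]


# "init": the table is created once from the RNG.
init_scope = ParamScope()
embed_apply(init_scope, [0, 1], num_embeddings=5, features=3)
params = init_scope.params

# "apply": the stored table is retrieved by name, not re-created, so
# repeated calls with the same params see identical embeddings.
y1 = embed_apply(ParamScope(params), [2, 4], num_embeddings=5, features=3)
y2 = embed_apply(ParamScope(params), [4], num_embeddings=5, features=3)
```

Reusing one params dictionary across calls gives every call the same table, which is the effect shared is described as providing; how Flax actually arranges this internally is not shown here.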