flax.linen.Dense

class flax.linen.Dense(features, use_bias=True, dtype=None, param_dtype=jnp.float32, precision=None, kernel_init=lecun_normal(), bias_init=zeros, dot_general=None, dot_general_cls=None, parent=<flax.linen.module._Sentinel object>, name=None)

A linear transformation applied over the last dimension of the input.

Example usage:

>>> import flax.linen as nn
>>> import jax, jax.numpy as jnp

>>> layer = nn.Dense(features=4)
>>> params = layer.init(jax.random.key(0), jnp.ones((1, 3)))
>>> jax.tree_util.tree_map(jnp.shape, params)
{'params': {'bias': (4,), 'kernel': (3, 4)}}
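The parameter shapes printed above follow from the input. The kernel maps the input's last dimension to `features`, and the bias has one entry per output feature. The NumPy sketch below reproduces that shape rule. `dense_param_shapes` is a hypothetical helper, not part of Flax, and it assumes the kernel layout `(in_features, features)` seen in the output above.

```python
import numpy as np

def dense_param_shapes(input_shape, features, use_bias=True):
    """Hypothetical helper: parameter shapes a Dense layer would create.

    Assumes the kernel is stored as (in_features, features), where
    in_features is the size of the input's last dimension.
    """
    in_features = input_shape[-1]
    shapes = {"kernel": (in_features, features)}
    if use_bias:
        shapes["bias"] = (features,)
    return shapes

# Mirrors the doctest above: an input of shape (1, 3) and features=4.
print(dense_param_shapes(np.ones((1, 3)).shape, features=4))
```

Leading (batch) dimensions do not enter the rule. An input of shape `(8, 10, 3)` would yield the same kernel and bias shapes.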
features

The number of output features.

Type

int

use_bias

Whether to add a bias to the output (default: True).

Type

bool
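For the computation itself, `use_bias` only controls whether a bias vector is added after the matrix product. Below is a minimal NumPy sketch, assuming the `(in_features, features)` kernel layout from the example at the top of this page. `dense_forward` is a hypothetical name, and passing `bias=None` stands in for `use_bias=False`.

```python
import numpy as np

def dense_forward(x, kernel, bias=None):
    """Hypothetical sketch of Dense's math: x @ kernel (+ bias).

    Passing bias=None plays the role of use_bias=False.
    """
    y = x @ kernel  # contract the last axis of x with the first axis of kernel
    if bias is not None:
        y = y + bias  # the bias broadcasts over all leading dimensions
    return y

rng = np.random.default_rng(0)
x = rng.normal(size=(2, 3))
kernel = rng.normal(size=(3, 4))
bias = np.arange(4.0)

with_bias = dense_forward(x, kernel, bias)
without_bias = dense_forward(x, kernel)
print(with_bias - without_bias)  # each row is the bias vector, up to rounding
```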

dtype

The dtype of the computation (default: inferred from the input and params).

Type

Optional[Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, Any]]
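When `dtype` is left as `None`, the computation dtype is inferred by type promotion over the input and the parameters. The sketch below approximates that with NumPy's `np.result_type`. This is an assumption made for illustration. JAX has its own promotion table, which can differ from NumPy's for some combinations, such as integer/float mixes or `bfloat16`. The two agree for the float16/float32 case shown. `compute_dtype` is a hypothetical helper.

```python
import numpy as np

def compute_dtype(inputs, kernel, dtype=None):
    """Hypothetical helper: the dtype a Dense computation would run in.

    An explicit dtype wins; otherwise promote over inputs and params
    (approximated here with NumPy's promotion rules, not JAX's).
    """
    if dtype is not None:
        return np.dtype(dtype)
    return np.result_type(inputs.dtype, kernel.dtype)

x16 = np.ones((1, 3), dtype=np.float16)
kernel32 = np.ones((3, 4), dtype=np.float32)

print(compute_dtype(x16, kernel32))                    # inferred: float32
print(compute_dtype(x16, kernel32, dtype=np.float16))  # explicit: float16
```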

param_dtype

The dtype passed to parameter initializers (default: float32).

Type

Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, Any]
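`param_dtype` and `dtype` play different roles. Parameters are created and stored in `param_dtype`, and they are cast to the computation dtype only when the layer is applied. The NumPy sketch below illustrates that split, assuming a float32 kernel used in a float16 computation. The variable names are illustrative, not Flax internals.

```python
import numpy as np

param_dtype = np.float32  # storage dtype for parameters (the documented default)
dtype = np.float16        # computation dtype, e.g. for mixed precision

# Parameters are initialized and kept in param_dtype...
kernel = np.full((3, 4), 0.5, dtype=param_dtype)
x = np.ones((2, 3), dtype=np.float32)

# ...and cast to the computation dtype only for the forward pass.
y = x.astype(dtype) @ kernel.astype(dtype)

print(kernel.dtype, y.dtype)  # float32 float16
```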

precision

Numerical precision of the computation; see jax.lax.Precision for details.

Type

Union[None, str, jax._src.lax.lax.Precision, Tuple[str, str], Tuple[jax._src.lax.lax.Precision, jax._src.lax.lax.Precision]]

kernel_init

Initializer function for the weight matrix (kernel).

Type

Union[jax.nn.initializers.Initializer, Callable[[…], Any]]
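The default `kernel_init` in the signature comes from `variance_scaling`, which scales the initial weights by the kernel's fan-in. The sketch below shows the idea with a plain normal distribution and standard deviation `sqrt(scale / fan_in)`. This is a simplification: Flax's default uses a truncated normal with a corrected scale. `fan_in_normal` is a hypothetical name. It follows the `init(key, shape, dtype)` calling convention of JAX initializers, but a NumPy `Generator` stands in for the JAX PRNG key.

```python
import numpy as np

def fan_in_normal(scale=1.0):
    """Hypothetical fan-in variance-scaling initializer (plain normal)."""
    def init(rng, shape, dtype=np.float32):
        fan_in = shape[0]  # input dimension of an (in_features, features) kernel
        std = np.sqrt(scale / fan_in)
        return rng.normal(0.0, std, size=shape).astype(dtype)
    return init

init = fan_in_normal()
kernel = init(np.random.default_rng(0), (1000, 512))
print(kernel.shape, kernel.dtype, kernel.std())  # std close to sqrt(1/1000), about 0.0316
```

Scaling by fan-in keeps the variance of each output unit roughly independent of how many inputs feed into it.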

bias_init

Initializer function for the bias.

Type

Union[jax.nn.initializers.Initializer, Callable[[…], Any]]

__call__(inputs)

Applies a linear transformation to the inputs along the last dimension.

Parameters

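inputs – The n-dimensional array to be transformed. Only its last dimension is transformed; any leading dimensions are carried through unchanged.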

Returns

The transformed input, with the same leading dimensions as inputs and a last dimension of size features.
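Because the transformation acts only on the last dimension, every leading dimension behaves like a batch dimension, and the output shape is `inputs.shape[:-1] + (features,)`. The NumPy sketch below checks this against an explicit loop over positions. It assumes the `(in_features, features)` kernel layout from the example at the top of this page.

```python
import numpy as np

rng = np.random.default_rng(0)
inputs = rng.normal(size=(2, 5, 3))  # (batch, length, in_features)
kernel = rng.normal(size=(3, 4))     # (in_features, features)
bias = rng.normal(size=(4,))         # (features,)

# Contract only the last axis; all leading axes are carried through.
out = np.einsum("...i,io->...o", inputs, kernel) + bias

# Reference: apply the same affine map to each position separately.
ref = np.empty((2, 5, 4))
for b in range(2):
    for t in range(5):
        ref[b, t] = inputs[b, t] @ kernel + bias

print(out.shape)  # (2, 5, 4)
```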
