flax.linen.Dense
- class flax.linen.Dense(features, use_bias=True, dtype=None, param_dtype=<class 'jax.numpy.float32'>, precision=None, kernel_init=<function variance_scaling.<locals>.init>, bias_init=<function zeros>, parent=<flax.linen.module._Sentinel object>, name=None)
A linear transformation applied over the last dimension of the input.
- Parameters
features (int)
use_bias (bool)
dtype (Optional[Any])
param_dtype (Any)
precision (Union[None, str, jax._src.lax.lax.Precision, Tuple[str, str], Tuple[jax._src.lax.lax.Precision, jax._src.lax.lax.Precision]])
kernel_init (Callable[[Any, Tuple[int, ...], Any], Any])
bias_init (Callable[[Any, Tuple[int, ...], Any], Any])
parent (Optional[Union[Type[flax.linen.module.Module], Type[flax.core.scope.Scope], Type[flax.linen.module._Sentinel]]])
name (str)
- Return type
None
- features
the number of output features.
- Type
int
- use_bias
whether to add a bias to the output (default: True).
- Type
bool
- dtype
the dtype of the computation (default: infer from input and params).
- Type
Optional[Any]
- param_dtype
the dtype passed to parameter initializers (default: float32).
- Type
Any
- precision
numerical precision of the computation; see jax.lax.Precision for details.
- Type
Union[None, str, jax._src.lax.lax.Precision, Tuple[str, str], Tuple[jax._src.lax.lax.Precision, jax._src.lax.lax.Precision]]
- kernel_init
initializer function for the weight matrix.
- Type
Callable[[Any, Tuple[int, …], Any], Any]
- bias_init
initializer function for the bias.
- Type
Callable[[Any, Tuple[int, …], Any], Any]
- __call__(inputs)
Applies a linear transformation to the inputs along the last dimension.
- Parameters
inputs (Any) – The nd-array to be transformed.
- Returns
The transformed input.
- Return type
Any
Methods
bias_init(shape[, dtype]) – An initializer that returns a constant array full of zeros.
kernel_init(shape[, dtype])