flax.linen.DenseGeneral
- class flax.linen.DenseGeneral(features, axis=-1, batch_dims=(), use_bias=True, dtype=None, param_dtype=<class 'jax.numpy.float32'>, kernel_init=<function variance_scaling.<locals>.init>, bias_init=<function zeros>, precision=None, dot_general=<function dot_general>, dot_general_cls=None, parent=<flax.linen.module._Sentinel object>, name=None)
A linear transformation with flexible axes.
- features
int or tuple with the number of output features.
- Type
Union[int, Sequence[int]]
- axis
int or tuple with axes to apply the transformation on. For instance, (-2, -1) will apply the transformation to the last two axes.
- Type
Union[int, Sequence[int]]
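- batch_dims
tuple of batch axes (default: ()). These axes are not contracted; instead, a separate kernel (and bias) slice is used for each index along them.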
- Type
Sequence[int]
- use_bias
whether to add a bias to the output (default: True).
- Type
bool
- dtype
the dtype of the computation (default: infer from input and params).
- Type
Optional[Any]
- param_dtype
the dtype passed to parameter initializers (default: float32).
- Type
Any
- kernel_init
initializer function for the kernel (weight) tensor.
- Type
Callable[[Any, Tuple[int, …], Any], Any]
- bias_init
initializer function for the bias.
- Type
Callable[[Any, Tuple[int, …], Any], Any]
- precision
numerical precision of the computation; see jax.lax.Precision for details.
- Type
Union[None, str, jax._src.lax.lax.Precision, Tuple[str, str], Tuple[jax._src.lax.lax.Precision, jax._src.lax.lax.Precision]]
- __call__(inputs)
Applies a linear transformation to the inputs along multiple dimensions.
- Parameters
inputs – The nd-array to be transformed.
- Returns
The transformed input.
Methods