flax.linen.DenseGeneral

class flax.linen.DenseGeneral(features, axis=-1, batch_dims=(), use_bias=True, dtype=None, param_dtype=<class 'jax.numpy.float32'>, kernel_init=<function variance_scaling.<locals>.init>, bias_init=<function zeros>, precision=None, parent=<flax.linen.module._Sentinel object>, name=None)

A linear transformation with flexible axes.

Parameters
  • features (Union[int, Sequence[int]])

  • axis (Union[int, Sequence[int]])

  • batch_dims (Sequence[int])

  • use_bias (bool)

  • dtype (Optional[Any])

  • param_dtype (Any)

  • kernel_init (Callable[[Any, Tuple[int, ...], Any], Any])

  • bias_init (Callable[[Any, Tuple[int, ...], Any], Any])

  • precision (Union[None, str, jax._src.lax.lax.Precision, Tuple[str, str], Tuple[jax._src.lax.lax.Precision, jax._src.lax.lax.Precision]])

  • parent (Optional[Union[Type[flax.linen.module.Module], Type[flax.core.scope.Scope], Type[flax.linen.module._Sentinel]]])

  • name (str)

Return type

None

features

int or tuple with number of output features.

Type

Union[int, Sequence[int]]

axis

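int or tuple with the input axes to apply the transformation on. These axes are contracted with the kernel and replaced in the output by the feature axes. For instance, (-2, -1) will apply the transformation to the last two axes.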

Type

Union[int, Sequence[int]]

batch_dims

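tuple with batch axes. Each index along these axes gets its own slice of the kernel and bias instead of sharing a single set of weights.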

Type

Sequence[int]

use_bias

whether to add a bias to the output (default: True).

Type

bool

dtype

the dtype of the computation (default: infer from input and params).

Type

Optional[Any]

param_dtype

the dtype passed to parameter initializers (default: float32).

Type

Any

kernel_init

initializer function for the weight matrix.

Type

Callable[[Any, Tuple[int, ...], Any], Any]

bias_init

initializer function for the bias.

Type

Callable[[Any, Tuple[int, ...], Any], Any]

precision

numerical precision of the computation; see jax.lax.Precision for details.

Type

Union[None, str, jax._src.lax.lax.Precision, Tuple[str, str], Tuple[jax._src.lax.lax.Precision, jax._src.lax.lax.Precision]]

__call__(inputs)

Applies a linear transformation to the inputs along multiple dimensions.

Parameters

inputs (Any) – The nd-array to be transformed.

Returns

The transformed input.

Return type

Any

Methods

bias_init(shape[, dtype])

An initializer that returns a constant array full of zeros.

kernel_init(shape[, dtype])