flax.linen.Conv

class flax.linen.Conv(features, kernel_size, strides=1, padding='SAME', input_dilation=1, kernel_dilation=1, feature_group_count=1, use_bias=True, dtype=None, param_dtype=<class 'jax.numpy.float32'>, precision=None, kernel_init=<function variance_scaling.<locals>.init>, bias_init=<function zeros>, parent=<flax.linen.module._Sentinel object>, name=None)

Convolution Module wrapping lax.conv_general_dilated.

Parameters
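  • features (int) – number of convolution filters, i.e. the size of the output's last (feature) dimension.

  • kernel_size (Sequence[int]) – shape of the convolutional kernel, one entry per spatial dimension, e.g. (3, 3) for a 2D convolution.

  • strides (Union[None, int, Sequence[int]]) – an integer or a sequence of n integers giving the inter-window strides (default: 1).

  • padding (Union[str, int, Sequence[Union[int, Tuple[int, int]]]]) – a padding string such as 'SAME' or 'VALID', or a sequence of n (low, high) integer pairs giving the padding to apply before and after each spatial dimension. A single int applies the same padding on both sides of every spatial dimension, and an int inside the sequence pads both sides of that dimension equally.

  • input_dilation (Union[None, int, Sequence[int]]) – an integer or a sequence of n integers giving the dilation factor applied to each spatial dimension of the inputs (default: 1). Convolution with input dilation d is equivalent to transposed convolution with stride d.

  • kernel_dilation (Union[None, int, Sequence[int]]) – an integer or a sequence of n integers giving the dilation factor applied to each spatial dimension of the kernel (default: 1). Convolution with kernel dilation is also known as atrous convolution.

  • feature_group_count (int) – number of groups the input features are split into (default: 1). Both the number of input features and features must be divisible by it; setting it to the number of input features gives a depthwise convolution.

  • use_bias (bool) – whether to add a learned bias to the output (default: True).

  • dtype (Optional[Any]) – the dtype of the computation (default: inferred from the inputs and params).

  • param_dtype (Any) – the dtype passed to the parameter initializers (default: float32).

  • precision (Union[None, str, jax._src.lax.lax.Precision, Tuple[str, str], Tuple[jax._src.lax.lax.Precision, jax._src.lax.lax.Precision]]) – numerical precision of the computation; see jax.lax.Precision for details.

  • kernel_init (Callable[[Any, Tuple[int, ...], Any], Any]) – initializer for the convolutional kernel (default: a variance-scaling initializer).

  • bias_init (Callable[[Any, Tuple[int, ...], Any], Any]) – initializer for the bias (default: zeros).

  • parent (Optional[Union[Type[flax.linen.module.Module], Type[flax.core.scope.Scope], Type[flax.linen.module._Sentinel]]]) – the parent Module or Scope; normally managed by Flax rather than passed explicitly.

  • name (str) – optional name of the Module.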

Return type

None

__call__(inputs)

Applies a (potentially unshared) convolution to the inputs.

Parameters

inputs (Any) – input data with dimensions (batch, spatial_dims…, features). This is the channels-last convention, i.e. NHWC for a 2D convolution and NDHWC for a 3D convolution. Note: this differs from the default input convention of lax.conv_general_dilated (e.g. NCHW), which puts the spatial dimensions last.

Returns

The convolved data.

Return type

Any

Methods