flax.linen.Conv#
- class flax.linen.Conv(features, kernel_size, strides=1, padding='SAME', input_dilation=1, kernel_dilation=1, feature_group_count=1, use_bias=True, mask=None, dtype=None, param_dtype=<class 'jax.numpy.float32'>, precision=None, kernel_init=<function variance_scaling.<locals>.init>, bias_init=<function zeros>, conv_general_dilated=<function conv_general_dilated>, conv_general_dilated_cls=None, parent=<flax.linen.module._Sentinel object>, name=None)#
Convolution Module wrapping lax.conv_general_dilated.
- features#
number of convolution filters.
- Type
int
- kernel_size#
shape of the convolutional kernel. For 1D convolution, the kernel size can be passed as an integer. For all other cases, it must be a sequence of integers.
- Type
Sequence[int]
- strides#
an integer or a sequence of n integers, representing the inter-window strides (default: 1).
- Type
Union[None, int, Sequence[int]]
- padding#
either the string ‘SAME’, the string ‘VALID’, the string ‘CIRCULAR’ (periodic boundary conditions), the string ‘CAUSAL’, or a sequence of n (low, high) integer pairs that give the padding to apply before and after each spatial dimension (default: ‘SAME’). A single int is interpreted as applying the same padding in all dimensions, and a single int inside a sequence causes the same padding to be used on both sides of that dimension. ‘CAUSAL’ padding for a 1D convolution left-pads the convolution axis, resulting in same-sized output.
- Type
Union[str, int, Sequence[Union[int, Tuple[int, int]]]]
- input_dilation#
an integer or a sequence of n integers, giving the dilation factor to apply in each spatial dimension of inputs (default: 1). Convolution with input dilation d is equivalent to transposed convolution with stride d.
- Type
Union[None, int, Sequence[int]]
- kernel_dilation#
an integer or a sequence of n integers, giving the dilation factor to apply in each spatial dimension of the convolution kernel (default: 1). Convolution with kernel dilation is also known as ‘atrous convolution’.
- Type
Union[None, int, Sequence[int]]
- feature_group_count#
integer (default: 1). If specified, divides the input features into groups.
- Type
int
- use_bias#
whether to add a bias to the output (default: True).
- Type
bool
- mask#
Optional mask for the weights during masked convolution. The mask must be the same shape as the convolution weight matrix.
- Type
Optional[Any]
- dtype#
the dtype of the computation (default: infer from input and params).
- Type
Optional[Any]
- param_dtype#
the dtype passed to parameter initializers (default: float32).
- Type
Any
- precision#
numerical precision of the computation; see jax.lax.Precision for details.
- Type
Union[None, str, jax._src.lax.lax.Precision, Tuple[str, str], Tuple[jax._src.lax.lax.Precision, jax._src.lax.lax.Precision]]
- kernel_init#
initializer for the convolutional kernel.
- Type
Callable[[Any, Tuple[int, …], Any], Any]
- bias_init#
initializer for the bias.
- Type
Callable[[Any, Tuple[int, …], Any], Any]
- __call__(inputs)#
Applies a (potentially unshared) convolution to the inputs.
- Parameters
inputs – input data with dimensions (*batch_dims, spatial_dims…, features). This is the channels-last convention, i.e. NHWC for a 2D convolution and NDHWC for a 3D convolution. Note: this is different from the input convention used by lax.conv_general_dilated, which puts the spatial dimensions last. Note: If the input has more than one batch dimension, all batch dimensions are flattened into a single dimension for the convolution and restored before returning. In some cases, directly vmap’ing the layer may yield better performance than this default flattening approach. If the input lacks a batch dimension, one is added for the convolution and removed on return, an allowance made to enable writing single-example code.
- Returns
The convolved data.
Methods