flax.linen.ConvTranspose#
- class flax.linen.ConvTranspose(features, kernel_size, strides=None, padding='SAME', kernel_dilation=None, use_bias=True, mask=None, dtype=None, param_dtype=<class 'jax.numpy.float32'>, precision=None, kernel_init=<function variance_scaling.<locals>.init>, bias_init=<function zeros>, transpose_kernel=False, parent=<flax.linen.module._Sentinel object>, name=None)[source]#
Convolution Module wrapping lax.conv_transpose.
Example usage:
>>> import flax.linen as nn
>>> import jax, jax.numpy as jnp

>>> # valid padding
>>> layer = nn.ConvTranspose(features=4, kernel_size=(3,), padding='VALID')
>>> out, variables = layer.init_with_output(jax.random.key(0), jnp.ones((1, 8, 3)))
>>> jax.tree_util.tree_map(jnp.shape, variables)
{'params': {'bias': (4,), 'kernel': (3, 3, 4)}}
>>> out.shape
(1, 10, 4)

>>> # circular padding with stride 2
>>> layer = nn.ConvTranspose(features=4, kernel_size=(6, 6), strides=(2, 2), padding='CIRCULAR', transpose_kernel=True)
>>> out, variables = layer.init_with_output(jax.random.key(0), jnp.ones((1, 15, 15, 3)))
>>> jax.tree_util.tree_map(jnp.shape, variables)
{'params': {'bias': (4,), 'kernel': (6, 6, 4, 3)}}
>>> out.shape
(1, 30, 30, 4)

>>> # apply lower triangle mask
>>> mask = jnp.tril(jnp.ones((3, 3, 4)))
>>> layer = nn.ConvTranspose(features=4, kernel_size=(3,), mask=mask, padding='VALID')
>>> variables = layer.init(jax.random.key(0), jnp.ones((1, 8, 3)))
- features#
number of convolution filters.
- Type
int
- kernel_size#
shape of the convolutional kernel. For 1D convolution, the kernel size can be passed as an integer, which will be interpreted as a tuple of the single integer. For all other cases, it must be a sequence of integers.
- Type
Union[int, Sequence[int]]
- strides#
an integer or a sequence of n integers, representing the inter-window strides.
- Type
Optional[Sequence[int]]
- padding#
either the string ‘SAME’, the string ‘VALID’, the string ‘CIRCULAR’ (periodic boundary conditions), or a sequence of n (low, high) integer pairs that give the padding to apply before and after each spatial dimension. A single int is interpreted as applying the same padding in all dims, and a single int within a sequence causes the same padding to be used on both sides of that dimension.
- Type
Union[str, int, Sequence[Union[int, Tuple[int, int]]]]
- kernel_dilation#
None, or an integer or a sequence of n integers, giving the dilation factor to apply in each spatial dimension of the convolution kernel. Convolution with kernel dilation is also known as ‘atrous convolution’.
- Type
Optional[Sequence[int]]
- use_bias#
whether to add a bias to the output (default: True).
- Type
bool
- mask#
Optional mask for the weights during masked convolution. The mask must be the same shape as the convolution weight matrix.
- Type
Optional[Union[jax.Array, Any]]
- dtype#
the dtype of the computation (default: infer from input and params).
- Type
Optional[Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, Any]]
- param_dtype#
the dtype passed to parameter initializers (default: float32).
- Type
Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, Any]
- precision#
numerical precision of the computation. See jax.lax.Precision for details.
- Type
Union[None, str, jax._src.lax.lax.Precision, Tuple[str, str], Tuple[jax._src.lax.lax.Precision, jax._src.lax.lax.Precision]]
- kernel_init#
initializer for the convolutional kernel.
- Type
Union[jax.nn.initializers.Initializer, Callable[[…], Any]]
- bias_init#
initializer for the bias.
- Type
Union[jax.nn.initializers.Initializer, Callable[[…], Any]]
- transpose_kernel#
if True flips spatial axes and swaps the input/output channel axes of the kernel.
- Type
bool
- __call__(inputs)[source]#
Applies a transposed convolution to the inputs.
Behaviour mirrors that of jax.lax.conv_transpose.
- Parameters
inputs – input data with dimensions (*batch_dims, spatial_dims…, features). This is the channels-last convention, i.e. NHWC for a 2d convolution and NDHWC for a 3D convolution. Note: this is different from the input convention used by lax.conv_general_dilated, which puts the spatial dimensions last. Note: If the input has more than 1 batch dimension, all batch dimensions are flattened into a single dimension for the convolution and restored before returning. In some cases directly vmap’ing the layer may yield better performance than this default flattening approach. If the input lacks a batch dimension, it will be added for the convolution and removed on return, an allowance made to enable writing single-example code.
- Returns
The convolved data.
Methods