flax.linen.ConvTranspose

class flax.linen.ConvTranspose(features, kernel_size, strides=None, padding='SAME', kernel_dilation=None, use_bias=True, dtype=None, param_dtype=<class 'jax.numpy.float32'>, precision=None, kernel_init=<function variance_scaling.<locals>.init>, bias_init=<function zeros>, parent=<flax.linen.module._Sentinel object>, name=None)

Transposed convolution Module wrapping jax.lax.conv_transpose.

Parameters
  • features (int) –

  • kernel_size (Union[int, Tuple[int, ...]]) –

  • strides (Optional[Tuple[int, ...]]) –

  • padding (Union[str, int, Sequence[Union[int, Tuple[int, int]]]]) –

  • kernel_dilation (Optional[Sequence[int]]) –

  • use_bias (bool) –

  • dtype (Any) –

  • param_dtype (Any) –

  • precision (Union[None, str, jax._src.lax.lax.Precision, Tuple[str, str], Tuple[jax._src.lax.lax.Precision, jax._src.lax.lax.Precision]]) –

  • kernel_init (Callable[[Any, Tuple[int, ...], Any], Any]) –

  • bias_init (Callable[[Any, Tuple[int, ...], Any], Any]) –

  • parent (Optional[Union[Type[flax.linen.module.Module], Type[flax.core.scope.Scope], Type[flax.linen.module._Sentinel]]]) –

  • name (str) –

Return type

None

features

number of convolution filters.

Type

int

kernel_size

shape of the convolutional kernel. For 1D convolution, the kernel size can be passed as an integer. For all other cases, it must be a sequence of integers.

Type

Union[int, Tuple[int, …]]

strides

a sequence of n integers, one per spatial dimension, representing the inter-window strides. If None, a stride of 1 is used in every spatial dimension.

Type

Optional[Tuple[int, …]]

padding

either the string ‘SAME’, the string ‘VALID’, the string ‘CIRCULAR’ (periodic boundary conditions), or a sequence of n (low, high) integer pairs that give the padding to apply before and after each spatial dimension. A single int is interpreted as applying the same padding in all dimensions, and passing a single int inside the sequence applies the same padding to both sides of that dimension.
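As a rough guide to output shapes, the helper below is a hypothetical sketch (not part of Flax) of the per-dimension output size for 'SAME' and 'VALID'. It assumes jax.lax.conv_transpose follows the TensorFlow-style conventions: 'SAME' multiplies the input size n by the stride s, and 'VALID' gives n * s + max(k_eff - s, 0), where k_eff is the dilated kernel size. 'CIRCULAR' and explicit (low, high) pairs are deliberately not modeled.

```python
def conv_transpose_out_len(n, k, s, padding, dilation=1):
    """Hypothetical helper: output length of one spatial dimension.

    n: input length, k: kernel size, s: stride, dilation: kernel dilation.
    Only 'SAME' and 'VALID' are handled.
    """
    k_eff = (k - 1) * dilation + 1  # effective (dilated) kernel size
    if padding == 'SAME':
        return n * s
    if padding == 'VALID':
        return n * s + max(k_eff - s, 0)
    raise ValueError(f"unsupported padding: {padding!r}")


print(conv_transpose_out_len(4, 3, 2, 'SAME'))               # 8
print(conv_transpose_out_len(4, 3, 2, 'VALID'))              # 9
print(conv_transpose_out_len(4, 3, 1, 'VALID', dilation=2))  # 8
```

For 'VALID' with k_eff >= s this reduces to (n - 1) * s + k_eff, the familiar "each input position stamps a full kernel" size.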

Type

Union[str, int, Sequence[Union[int, Tuple[int, int]]]]

kernel_dilation

None, or a sequence of n integers, giving the dilation factor to apply in each spatial dimension of the convolution kernel. Convolution with kernel dilation is also known as ‘atrous convolution’.

Type

Optional[Sequence[int]]

use_bias

whether to add a bias to the output (default: True).

Type

bool

dtype

the dtype of the computation (default: infer from input and params).

Type

Any

param_dtype

the dtype passed to parameter initializers (default: float32).

Type

Any

precision

numerical precision of the computation; see jax.lax.Precision for details.

Type

Union[None, str, jax._src.lax.lax.Precision, Tuple[str, str], Tuple[jax._src.lax.lax.Precision, jax._src.lax.lax.Precision]]

kernel_init

initializer for the convolutional kernel.

Type

Callable[[Any, Tuple[int, …], Any], Any]

bias_init

initializer for the bias.

Type

Callable[[Any, Tuple[int, …], Any], Any]

__call__(inputs)

Applies a transposed convolution to the inputs.

Behaviour mirrors that of jax.lax.conv_transpose.

Parameters

inputs (Any) – input data with dimensions (batch, spatial_dims…, features). This is the channels-last convention, i.e. NHWC for a 2D convolution and NDHWC for a 3D convolution. Note: this is different from the default input convention used by lax.conv_general_dilated, which puts the spatial dimensions last.

Returns

The convolved data.

Return type

Any

Methods

bias_init(shape[, dtype])

An initializer that returns a constant array full of zeros.

kernel_init(shape[, dtype])