flax.linen.ConvLocal#

class flax.linen.ConvLocal(features, kernel_size, strides=1, padding='SAME', input_dilation=1, kernel_dilation=1, feature_group_count=1, use_bias=True, mask=None, dtype=None, param_dtype=<class 'jax.numpy.float32'>, precision=None, kernel_init=<function variance_scaling.<locals>.init>, bias_init=<function zeros>, conv_general_dilated=None, conv_general_dilated_cls=None, parent=<flax.linen.module._Sentinel object>, name=None)#

Local convolution Module wrapping lax.conv_general_dilated_local.
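As a rough mental model, a local (unshared) convolution slides a window over the input like an ordinary convolution, but every output position owns a separate weight matrix. The NumPy sketch below illustrates this for the 1D, VALID-padding, stride-1 case. The function name `local_conv_1d_valid` is hypothetical, and the order in which the window is flattened is an assumption made for illustration; it may not match what lax.conv_general_dilated_local does internally.

```python
import numpy as np


def local_conv_1d_valid(x, kernel, bias):
    """Unshared 1D convolution with VALID padding and stride 1.

    x:      (length, in_features)
    kernel: (out_length, window * in_features, features)
    bias:   (out_length, features)
    """
    out_length, patch_size, features = kernel.shape
    window = patch_size // x.shape[-1]
    out = np.empty((out_length, features), dtype=x.dtype)
    for i in range(out_length):
        # Flatten the window that starts at position i (illustrative ordering).
        patch = x[i:i + window].reshape(-1)
        # Position i uses its own weights kernel[i] and bias[i].
        out[i] = patch @ kernel[i] + bias[i]
    return out


rng = np.random.default_rng(0)
x = rng.normal(size=(8, 3))          # 8 positions, 3 input features
kernel = rng.normal(size=(6, 9, 4))  # 6 output positions, 3 * 3 patch, 4 features
bias = np.zeros((6, 4))
y = local_conv_1d_valid(x, kernel, bias)
```

A shared-weight convolution would reuse one `(9, 4)` matrix for all six positions; here the leading axis of `kernel` indexes the output position instead.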

Example usage:

>>> import flax.linen as nn
>>> import jax, jax.numpy as jnp

>>> # valid padding
>>> layer = nn.ConvLocal(features=4, kernel_size=(3,), padding='VALID')
>>> out, variables = layer.init_with_output(jax.random.key(0), jnp.ones((1, 8, 3)))
>>> jax.tree_util.tree_map(jnp.shape, variables)
{'params': {'bias': (6, 4), 'kernel': (6, 9, 4)}}
>>> out.shape
(1, 6, 4)
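>>> # circular padding with stride 2; with a 2D kernel_size, the (1, 8, 3) input
>>> # is treated as a single unbatched example with spatial shape (1, 8)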
>>> layer = nn.ConvLocal(features=4, kernel_size=(3, 3), strides=2, padding='CIRCULAR')
>>> out, variables = layer.init_with_output(jax.random.key(0), jnp.ones((1, 8, 3)))
>>> jax.tree_util.tree_map(jnp.shape, variables)
{'params': {'bias': (1, 4, 4), 'kernel': (1, 4, 27, 4)}}
>>> out.shape
(1, 4, 4)
>>> # apply lower triangle mask
>>> mask = jnp.tril(jnp.ones((6, 9, 4)))
>>> layer = nn.ConvLocal(features=4, kernel_size=(3,), mask=mask, padding='VALID')
>>> variables = layer.init(jax.random.key(0), jnp.ones((1, 8, 3)))

features#

number of convolution filters.

Type

int

kernel_size#

shape of the convolutional kernel. An integer is interpreted as a one-element tuple containing that integer.

Type

Union[int, Sequence[int]]
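The parameter shapes printed in the example usage are consistent with a simple rule: the kernel has one weight matrix per output position, whose rows cover the flattened window across all input features. The helper below is a hypothetical sketch of that rule, inferred from the example outputs rather than taken from the Flax source.

```python
import math


def local_param_shapes(out_spatial, kernel_size, in_features, features):
    """Kernel and bias shapes for an unshared convolution (inferred rule)."""
    patch = math.prod(kernel_size) * in_features
    kernel_shape = tuple(out_spatial) + (patch, features)
    bias_shape = tuple(out_spatial) + (features,)
    return kernel_shape, bias_shape


# 1D example: 6 output positions, window 3, 3 input features, 4 filters.
valid_1d = local_param_shapes((6,), (3,), in_features=3, features=4)
# 2D example: output grid (1, 4), window 3 x 3, 3 input features, 4 filters.
circular_2d = local_param_shapes((1, 4), (3, 3), in_features=3, features=4)
```

Because the leading axes follow the output spatial shape, the number of parameters grows with the input size, unlike a shared-weight convolution.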

strides#

an integer or a sequence of n integers, representing the inter-window strides (default: 1).

Type

Union[None, int, Sequence[int]]

padding#

either the string 'SAME', the string 'VALID', the string 'CIRCULAR' (periodic boundary conditions), or a sequence of n (low, high) integer pairs that give the padding to apply before and after each spatial dimension. A single int is interpreted as applying the same padding in all dimensions, and a single int within a sequence applies the same padding on both sides of that dimension. 'CAUSAL' padding for a 1D convolution left-pads the convolution axis, resulting in same-sized output.

Type

Union[str, int, Sequence[Union[int, Tuple[int, int]]]]
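To see how padding and strides interact, the following sketch computes the output length along one spatial dimension using standard convolution arithmetic. The function `output_length` is hypothetical and is not part of Flax; it assumes the input is at least as long as the (dilated) kernel, and it only covers 'VALID', 'SAME', 'CIRCULAR', and an explicit (low, high) pair.

```python
def output_length(n, kernel, stride=1, padding="VALID", kernel_dilation=1):
    """Output size along one spatial dimension (standard arithmetic)."""
    # Span covered by the kernel once dilation is taken into account.
    k_eff = (kernel - 1) * kernel_dilation + 1
    if padding == "VALID":
        return (n - k_eff) // stride + 1
    if padding in ("SAME", "CIRCULAR"):
        # Both pad by k_eff - 1 in total, giving ceil(n / stride).
        return -(-n // stride)
    low, high = padding  # explicit (low, high) pair
    return (n + low + high - k_eff) // stride + 1


valid_len = output_length(8, 3)                                    # 6
circular_len = output_length(8, 3, stride=2, padding="CIRCULAR")   # 4
explicit_len = output_length(8, 3, padding=(1, 1))                 # 8
```

The first two results agree with the output shapes shown in the example usage.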

input_dilation#

an integer or a sequence of n integers, giving the dilation factor to apply in each spatial dimension of inputs (default: 1). Convolution with input dilation d is equivalent to transposed convolution with stride d.

Type

Union[None, int, Sequence[int]]
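Input dilation can be pictured as inserting zeros between neighbouring input elements before the convolution runs. The pure-Python helper below (`dilate` is a hypothetical name) sketches this for a 1D sequence.

```python
def dilate(values, d):
    """Insert d - 1 zeros between consecutive elements of a 1D sequence."""
    out = []
    for i, v in enumerate(values):
        if i > 0:
            out.extend([0] * (d - 1))
        out.append(v)
    return out


dilated = dilate([1, 2, 3], 2)  # [1, 0, 2, 0, 3]
```

An input of length n becomes length (n - 1) * d + 1, which is why this acts like the upsampling step of a transposed convolution with stride d.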

kernel_dilation#

an integer or a sequence of n integers, giving the dilation factor to apply in each spatial dimension of the convolution kernel (default: 1). Convolution with kernel dilation is also known as ‘atrous convolution’.

Type

Union[None, int, Sequence[int]]
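Kernel dilation spreads the kernel taps apart without adding parameters. The sketch below, with hypothetical helper names, lists the input offsets a 1D kernel reads and the span it covers.

```python
def tap_offsets(kernel_size, dilation):
    """Input offsets read by each kernel tap, relative to the window start."""
    return [i * dilation for i in range(kernel_size)]


def effective_kernel_size(kernel_size, dilation):
    """Span of input positions covered by the dilated kernel."""
    return (kernel_size - 1) * dilation + 1


offsets = tap_offsets(3, 2)         # [0, 2, 4]
span = effective_kernel_size(3, 2)  # 5
```

A kernel of size 3 with dilation 2 therefore still has three taps, but it covers five input positions.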

feature_group_count#

integer, default 1. If specified, divides the input features into groups.

Type

int

use_bias#

whether to add a bias to the output (default: True).

Type

bool

mask#

Optional mask for the weights during masked convolution. The mask must be the same shape as the convolution weight matrix.

Type

Optional[Union[jax.Array, Any]]
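A minimal NumPy sketch of what masking does to the weights, assuming the mask is applied as an elementwise multiplication with the kernel. The shape (6, 9, 4) is the kernel shape from the 1D example usage.

```python
import numpy as np

kernel = np.random.default_rng(0).normal(size=(6, 9, 4))
# For arrays with more than two dimensions, tril acts on the last two axes,
# so each of the 6 per-position (9, 4) matrices gets a lower-triangular mask.
mask = np.tril(np.ones((6, 9, 4)))
masked_kernel = kernel * mask
```

Entries where the mask is zero contribute nothing to the output, while the remaining weights are left unchanged.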

dtype#

the dtype of the computation (default: infer from input and params).

Type

Optional[Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, Any]]

param_dtype#

the dtype passed to parameter initializers (default: float32).

Type

Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, Any]

precision#

numerical precision of the computation; see jax.lax.Precision for details.

Type

Union[None, str, jax._src.lax.lax.Precision, Tuple[str, str], Tuple[jax._src.lax.lax.Precision, jax._src.lax.lax.Precision]]

kernel_init#

initializer for the convolutional kernel.

Type

Union[jax.nn.initializers.Initializer, Callable[[…], Any]]

bias_init#

initializer for the bias.

Type

Union[jax.nn.initializers.Initializer, Callable[[…], Any]]

__call__(inputs)#

Applies a (potentially unshared) convolution to the inputs.

Parameters

inputs – input data with dimensions (*batch_dims, spatial_dims…, features). This is the channels-last convention, i.e. NHWC for a 2D convolution and NDHWC for a 3D convolution. Note: this is different from the input convention used by lax.conv_general_dilated, which puts the spatial dimensions last. Note: If the input has more than one batch dimension, all batch dimensions are flattened into a single dimension for the convolution and restored before returning. In some cases, directly vmap’ing the layer may yield better performance than this default flattening approach. If the input lacks a batch dimension, one is added for the convolution and removed on return, an allowance made to enable writing single-example code.

Returns

The convolved data.
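The batch handling described for inputs can be sketched with shape arithmetic alone. The helper `flatten_batch` below is hypothetical: it assumes the number of spatial dimensions equals the length of kernel_size and that every leading axis beyond the spatial and feature axes counts as a batch dimension.

```python
import math


def flatten_batch(shape, num_spatial_dims):
    """Split an input shape into batch dims and the flattened conv shape."""
    # Everything before the spatial and feature axes is a batch dimension.
    num_batch_dims = len(shape) - (num_spatial_dims + 1)
    batch_shape = shape[:num_batch_dims]
    # math.prod(()) == 1, so a missing batch dimension becomes size 1.
    flat_shape = (math.prod(batch_shape),) + shape[num_batch_dims:]
    return batch_shape, flat_shape


two_batch_dims = flatten_batch((2, 5, 8, 3), num_spatial_dims=1)
no_batch_dim = flatten_batch((8, 3), num_spatial_dims=1)
unbatched_2d = flatten_batch((1, 8, 3), num_spatial_dims=2)
```

The last case matches the circular-padding example, where a (1, 8, 3) input combined with a 2D kernel is handled as one unbatched example. After the convolution, the leading axis of the result would be replaced by `batch_shape` again.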

Methods