Source code for flax.linen.linear

# Copyright 2023 The Flax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Linear modules."""

import dataclasses
from typing import (
  Any,
  Callable,
  Iterable,
  List,
  Optional,
  Sequence,
  Tuple,
  Union,
)

import jax
import jax.numpy as jnp
import numpy as np
from jax import eval_shape, lax
from jax.core import ShapedArray

from flax.core import meta
from flax.linen import initializers
from flax.linen.dtypes import promote_dtype
from flax.linen.module import Module, compact

PRNGKey = Any
Shape = Tuple[int, ...]
Dtype = Any  # this could be a real type?
Array = Any
PrecisionLike = Union[
  None,
  str,
  lax.Precision,
  Tuple[str, str],
  Tuple[lax.Precision, lax.Precision],
]
DotGeneralT = Callable[..., Array]
ConvGeneralDilatedT = Callable[..., Array]

default_kernel_init = initializers.lecun_normal()


def _normalize_axes(axes: Tuple[int, ...], ndim: int) -> Tuple[int, ...]:
  # A tuple by convention. len(axes_tuple) then also gives the rank efficiently.
  return tuple(sorted(ax if ax >= 0 else ndim + ax for ax in axes))


def _canonicalize_tuple(x: Union[Sequence[int], int]) -> Tuple[int, ...]:
  if isinstance(x, Iterable):
    return tuple(x)
  else:
    return (x,)
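
# Illustrative sketch (not part of the original file): the two helpers above
# turn user-provided axis specs into canonical tuples, e.g.:
#
#   _canonicalize_tuple(3)            # -> (3,)
#   _canonicalize_tuple([1, -1])      # -> (1, -1)
#   _normalize_axes((1, -1), ndim=4)  # -> (1, 3); negative axes wrap and sort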


class DenseGeneral(Module):
  """A linear transformation with flexible axes.

  Example usage::

    >>> import flax.linen as nn
    >>> import jax, jax.numpy as jnp

    >>> # equivalent to `nn.Dense(features=4)`
    >>> layer = nn.DenseGeneral(features=4)
    >>> # output features (4, 5)
    >>> layer = nn.DenseGeneral(features=(4, 5))
    >>> params = layer.init(jax.random.key(0), jnp.ones((1, 3)))
    >>> jax.tree_map(jnp.shape, params)
    {'params': {'bias': (4, 5), 'kernel': (3, 4, 5)}}
    >>> # apply transformation on the second and last axes
    >>> layer = nn.DenseGeneral(features=(4, 5), axis=(1, -1))
    >>> params = layer.init(jax.random.key(0), jnp.ones((1, 3, 6, 7)))
    >>> jax.tree_map(jnp.shape, params)
    {'params': {'bias': (4, 5), 'kernel': (3, 7, 4, 5)}}

  Attributes:
    features: int or tuple with number of output features.
    axis: int or tuple with axes to apply the transformation on. For instance,
      (-2, -1) will apply the transformation to the last two axes.
    batch_dims: tuple with batch axes.
    use_bias: whether to add a bias to the output (default: True).
    dtype: the dtype of the computation (default: infer from input and params).
    param_dtype: the dtype passed to parameter initializers (default: float32).
    kernel_init: initializer function for the weight matrix.
    bias_init: initializer function for the bias.
    precision: numerical precision of the computation see `jax.lax.Precision`
      for details.
  """

  features: Union[int, Sequence[int]]
  axis: Union[int, Sequence[int]] = -1
  batch_dims: Sequence[int] = ()
  use_bias: bool = True
  dtype: Optional[Dtype] = None
  param_dtype: Dtype = jnp.float32
  kernel_init: Callable[[PRNGKey, Shape, Dtype], Array] = default_kernel_init
  bias_init: Callable[
    [PRNGKey, Shape, Dtype], Array
  ] = initializers.zeros_init()
  precision: PrecisionLike = None
  # Deprecated. Will be removed.
  dot_general: Optional[DotGeneralT] = None
  dot_general_cls: Any = None
  @compact
  def __call__(self, inputs: Array) -> Array:
    """Applies a linear transformation to the inputs along multiple dimensions.

    Args:
      inputs: The nd-array to be transformed.

    Returns:
      The transformed input.
    """
    features = _canonicalize_tuple(self.features)
    axis = _canonicalize_tuple(self.axis)
    batch_dims = _canonicalize_tuple(self.batch_dims)
    if batch_dims:
      max_dim = np.max(batch_dims)
      if set(batch_dims) != set(range(max_dim + 1)):
        raise ValueError(
          'batch_dims %s must be consecutive leading '
          'dimensions starting from 0.' % str(batch_dims)
        )

    ndim = inputs.ndim
    n_batch_dims = len(batch_dims)
    axis = _normalize_axes(axis, ndim)
    batch_dims = _normalize_axes(batch_dims, ndim)
    n_axis, n_features = len(axis), len(features)

    def kernel_init_wrap(rng, shape, dtype=jnp.float32):
      flat_shape = (
        np.prod(shape[:n_batch_dims])
        * np.prod(shape[n_batch_dims : n_axis + n_batch_dims]),
        np.prod(shape[-n_features:]),
      )
      flat_shape = jax.tree_map(int, flat_shape)
      kernel = self.kernel_init(rng, flat_shape, dtype)
      if isinstance(kernel, meta.AxisMetadata):
        return meta.replace_boxed(kernel, jnp.reshape(kernel.unbox(), shape))
      return jnp.reshape(kernel, shape)

    batch_shape = tuple(inputs.shape[ax] for ax in batch_dims)
    # batch and non-contracting dims of input with 1s for batch dims.
    expanded_batch_shape = tuple(
      inputs.shape[ax] if ax in batch_dims else 1
      for ax in range(inputs.ndim)
      if ax not in axis
    )
    kernel_shape = tuple(inputs.shape[ax] for ax in axis) + features
    kernel = self.param(
      'kernel', kernel_init_wrap, batch_shape + kernel_shape, self.param_dtype
    )

    batch_ind = tuple(range(n_batch_dims))
    contract_ind = tuple(range(n_batch_dims, n_axis + n_batch_dims))

    if self.use_bias:

      def bias_init_wrap(rng, shape, dtype=jnp.float32):
        flat_shape = (
          np.prod(shape[:n_batch_dims]) * np.prod(shape[-n_features:]),
        )
        flat_shape = jax.tree_map(int, flat_shape)
        bias = self.bias_init(rng, flat_shape, dtype)
        if isinstance(bias, meta.AxisMetadata):
          return meta.replace_boxed(bias, jnp.reshape(bias.unbox(), shape))
        return jnp.reshape(bias, shape)

      bias = self.param(
        'bias', bias_init_wrap, batch_shape + features, self.param_dtype
      )
    else:
      bias = None

    inputs, kernel, bias = promote_dtype(inputs, kernel, bias, dtype=self.dtype)

    if self.dot_general_cls is not None:
      dot_general = self.dot_general_cls()
    elif self.dot_general is not None:
      dot_general = self.dot_general
    else:
      dot_general = lax.dot_general
    out = dot_general(
      inputs,
      kernel,
      ((axis, contract_ind), (batch_dims, batch_ind)),
      precision=self.precision,
    )
    # dot_general output has shape [batch_dims/group_dims] + [feature_dims]
    if self.use_bias:
      # expand bias shape to broadcast bias over batch dims.
      bias = jnp.reshape(bias, expanded_batch_shape + features)
      out += bias
    return out
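
# Illustrative sketch (not part of the original file): for inputs of shape
# (B, 3, 6, 7) and DenseGeneral(features=(4, 5), axis=(1, -1)), the kernel has
# shape (3, 7, 4, 5) and the dot_general call above is equivalent to
#
#   out = jnp.einsum('bixj,ijfg->bxfg', inputs, kernel) + bias  # (B, 6, 4, 5)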
class Dense(Module):
  """A linear transformation applied over the last dimension of the input.

  Example usage::

    >>> import flax.linen as nn
    >>> import jax, jax.numpy as jnp

    >>> layer = nn.Dense(features=4)
    >>> params = layer.init(jax.random.key(0), jnp.ones((1, 3)))
    >>> jax.tree_map(jnp.shape, params)
    {'params': {'bias': (4,), 'kernel': (3, 4)}}

  Attributes:
    features: the number of output features.
    use_bias: whether to add a bias to the output (default: True).
    dtype: the dtype of the computation (default: infer from input and params).
    param_dtype: the dtype passed to parameter initializers (default: float32).
    precision: numerical precision of the computation see `jax.lax.Precision`
      for details.
    kernel_init: initializer function for the weight matrix.
    bias_init: initializer function for the bias.
  """

  features: int
  use_bias: bool = True
  dtype: Optional[Dtype] = None
  param_dtype: Dtype = jnp.float32
  precision: PrecisionLike = None
  kernel_init: Callable[[PRNGKey, Shape, Dtype], Array] = default_kernel_init
  bias_init: Callable[
    [PRNGKey, Shape, Dtype], Array
  ] = initializers.zeros_init()
  # Deprecated. Will be removed.
  dot_general: Optional[DotGeneralT] = None
  dot_general_cls: Any = None
  @compact
  def __call__(self, inputs: Array) -> Array:
    """Applies a linear transformation to the inputs along the last dimension.

    Args:
      inputs: The nd-array to be transformed.

    Returns:
      The transformed input.
    """
    kernel = self.param(
      'kernel',
      self.kernel_init,
      (jnp.shape(inputs)[-1], self.features),
      self.param_dtype,
    )
    if self.use_bias:
      bias = self.param(
        'bias', self.bias_init, (self.features,), self.param_dtype
      )
    else:
      bias = None
    inputs, kernel, bias = promote_dtype(inputs, kernel, bias, dtype=self.dtype)

    if self.dot_general_cls is not None:
      dot_general = self.dot_general_cls()
    elif self.dot_general is not None:
      dot_general = self.dot_general
    else:
      dot_general = lax.dot_general
    y = dot_general(
      inputs,
      kernel,
      (((inputs.ndim - 1,), (0,)), ((), ())),
      precision=self.precision,
    )
    if bias is not None:
      y += jnp.reshape(bias, (1,) * (y.ndim - 1) + (-1,))
    return y
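
# Illustrative sketch (not part of the original file): for 2D inputs the
# dot_general call in Dense reduces to a plain affine map,
#
#   y = inputs @ kernel + bias  # (batch, in) @ (in, features) + (features,)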
def _conv_dimension_numbers(input_shape):
  """Computes the dimension numbers based on the input shape."""
  ndim = len(input_shape)
  lhs_spec = (0, ndim - 1) + tuple(range(1, ndim - 1))
  rhs_spec = (ndim - 1, ndim - 2) + tuple(range(0, ndim - 2))
  out_spec = lhs_spec
  return lax.ConvDimensionNumbers(lhs_spec, rhs_spec, out_spec)


PaddingLike = Union[str, int, Sequence[Union[int, Tuple[int, int]]]]
LaxPadding = Union[str, Sequence[Tuple[int, int]]]


def canonicalize_padding(padding: PaddingLike, rank: int) -> LaxPadding:
  """Canonicalizes conv padding to a jax.lax supported format."""
  if isinstance(padding, str):
    return padding
  if isinstance(padding, int):
    return [(padding, padding)] * rank
  if isinstance(padding, Sequence) and len(padding) == rank:
    new_pad = []
    for p in padding:
      if isinstance(p, int):
        new_pad.append((p, p))
      elif isinstance(p, tuple) and len(p) == 2:
        new_pad.append(p)
      else:
        break
    if len(new_pad) == rank:
      return new_pad
  raise ValueError(
    f'Invalid padding format: {padding}, should be str, int,'
    f' or a sequence of len {rank} where each element is an'
    ' int or pair of ints.'
  )


class _Conv(Module):
  """Convolution Module wrapping `lax.conv_general_dilated[_local]`.

  Attributes:
    features: number of convolution filters.
    kernel_size: shape of the convolutional kernel.
    strides: an integer or a sequence of `n` integers, representing the
      inter-window strides (default: 1).
    padding: either the string `'SAME'`, the string `'VALID'`, the string
      `'CIRCULAR'` (periodic boundary conditions), or a sequence of `n`
      `(low, high)` integer pairs that give the padding to apply before and
      after each spatial dimension. A single int is interpreted as applying the
      same padding in all dims, and assigning a single int in a sequence causes
      the same padding to be used on both sides. `'CAUSAL'` padding for a 1D
      convolution will left-pad the convolution axis, resulting in same-sized
      output.
    input_dilation: an integer or a sequence of `n` integers, giving the
      dilation factor to apply in each spatial dimension of `inputs`
      (default: 1). Convolution with input dilation `d` is equivalent to
      transposed convolution with stride `d`.
    kernel_dilation: an integer or a sequence of `n` integers, giving the
      dilation factor to apply in each spatial dimension of the convolution
      kernel (default: 1). Convolution with kernel dilation is also known as
      'atrous convolution'.
    feature_group_count: integer, default 1. If specified divides the input
      features into groups.
    use_bias: whether to add a bias to the output (default: True).
    mask: Optional mask for the weights during masked convolution. The mask
      must be the same shape as the convolution weight matrix.
    dtype: the dtype of the computation (default: infer from input and params).
    param_dtype: the dtype passed to parameter initializers (default: float32).
    precision: numerical precision of the computation see `jax.lax.Precision`
      for details.
    kernel_init: initializer for the convolutional kernel.
    bias_init: initializer for the bias.
  """

  features: int
  kernel_size: Sequence[int]
  strides: Union[None, int, Sequence[int]] = 1
  padding: PaddingLike = 'SAME'
  input_dilation: Union[None, int, Sequence[int]] = 1
  kernel_dilation: Union[None, int, Sequence[int]] = 1
  feature_group_count: int = 1
  use_bias: bool = True
  mask: Optional[Array] = None
  dtype: Optional[Dtype] = None
  param_dtype: Dtype = jnp.float32
  precision: PrecisionLike = None
  kernel_init: Callable[[PRNGKey, Shape, Dtype], Array] = default_kernel_init
  bias_init: Callable[
    [PRNGKey, Shape, Dtype], Array
  ] = initializers.zeros_init()
  # Deprecated. Will be removed.
  conv_general_dilated: Optional[ConvGeneralDilatedT] = None
  conv_general_dilated_cls: Any = None

  @property
  def shared_weights(self) -> bool:  # type: ignore
    """Defines whether weights are shared or not between different pixels.

    Returns:
      `True` to use shared weights in convolution (regular convolution).
      `False` to use different weights at different pixels, a.k.a.
      "locally connected layer", "unshared convolution", or "local convolution".
    """
    ...

  @compact
  def __call__(self, inputs: Array) -> Array:
    """Applies a (potentially unshared) convolution to the inputs.

    Args:
      inputs: input data with dimensions (*batch_dims, spatial_dims...,
        features). This is the channels-last convention, i.e. NHWC for a 2d
        convolution and NDHWC for a 3D convolution. Note: this is different
        from the input convention used by `lax.conv_general_dilated`, which
        puts the spatial dimensions last. Note: If the input has more than 1
        batch dimension, all batch dimensions are flattened into a single
        dimension for the convolution and restored before returning. In some
        cases directly vmap'ing the layer may yield better performance than
        this default flattening approach. If the input lacks a batch dimension
        it will be added for the convolution and removed on return, an
        allowance made to enable writing single-example code.

    Returns:
      The convolved data.
    """

    if isinstance(self.kernel_size, int):
      raise TypeError(
        'Expected Conv kernel_size to be a'
        ' tuple/list of integers (eg.: [3, 3]) but got'
        f' {self.kernel_size}.'
      )
    else:
      kernel_size = tuple(self.kernel_size)

    def maybe_broadcast(
      x: Optional[Union[int, Sequence[int]]]
    ) -> Tuple[int, ...]:
      if x is None:
        # backward compatibility with using None as sentinel for
        # broadcast 1
        x = 1
      if isinstance(x, int):
        return (x,) * len(kernel_size)
      return tuple(x)

    # Combine all input batch dimensions into a single leading batch axis.
    num_batch_dimensions = inputs.ndim - (len(kernel_size) + 1)
    if num_batch_dimensions != 1:
      input_batch_shape = inputs.shape[:num_batch_dimensions]
      total_batch_size = int(np.prod(input_batch_shape))
      flat_input_shape = (total_batch_size,) + inputs.shape[
        num_batch_dimensions:
      ]
      inputs = jnp.reshape(inputs, flat_input_shape)

    # self.strides or (1,) * (inputs.ndim - 2)
    strides = maybe_broadcast(self.strides)
    input_dilation = maybe_broadcast(self.input_dilation)
    kernel_dilation = maybe_broadcast(self.kernel_dilation)

    padding_lax = canonicalize_padding(self.padding, len(kernel_size))
    if padding_lax == 'CIRCULAR':
      kernel_size_dilated = [
        (k - 1) * d + 1 for k, d in zip(kernel_size, kernel_dilation)
      ]
      zero_pad: List[Tuple[int, int]] = [(0, 0)]
      pads = (
        zero_pad
        + [((k - 1) // 2, k // 2) for k in kernel_size_dilated]
        + [(0, 0)]
      )
      inputs = jnp.pad(inputs, pads, mode='wrap')
      padding_lax = 'VALID'
    elif padding_lax == 'CAUSAL':
      if len(kernel_size) != 1:
        raise ValueError(
          'Causal padding is only implemented for 1D convolutions.'
        )
      left_pad = kernel_dilation[0] * (kernel_size[0] - 1)
      pads = [(0, 0), (left_pad, 0), (0, 0)]
      inputs = jnp.pad(inputs, pads)
      padding_lax = 'VALID'

    dimension_numbers = _conv_dimension_numbers(inputs.shape)
    in_features = jnp.shape(inputs)[-1]

    if self.shared_weights:
      # One shared convolutional kernel for all pixels in the output.
      assert in_features % self.feature_group_count == 0
      kernel_shape = kernel_size + (
        in_features // self.feature_group_count,
        self.features,
      )

    else:
      if self.feature_group_count != 1:
        raise NotImplementedError(
          '`lax.conv_general_dilated_local` does not support '
          f'`feature_group_count != 1`, got `{self.feature_group_count}`.'
        )

      # Need to know the spatial output shape of a standard convolution to
      # create the unshared convolution kernel.
      if self.conv_general_dilated_cls is not None:
        conv_general_dilated = self.conv_general_dilated_cls()
      elif self.conv_general_dilated is not None:
        conv_general_dilated = self.conv_general_dilated
      else:
        conv_general_dilated = lax.conv_general_dilated
      conv_output_shape = eval_shape(
        lambda lhs, rhs: conv_general_dilated(  # pylint: disable=g-long-lambda
          lhs=lhs,
          rhs=rhs,
          window_strides=strides,
          padding=padding_lax,
          dimension_numbers=dimension_numbers,
          lhs_dilation=input_dilation,
          rhs_dilation=kernel_dilation,
        ),
        inputs,
        ShapedArray(kernel_size + (in_features, self.features), inputs.dtype),
      ).shape

      # One (unshared) convolutional kernel per each pixel in the output.
      kernel_shape = conv_output_shape[1:-1] + (
        np.prod(kernel_size) * in_features,
        self.features,
      )

    if self.mask is not None and self.mask.shape != kernel_shape:
      raise ValueError(
        'Mask needs to have the same shape as weights. '
        f'Shapes are: {self.mask.shape}, {kernel_shape}'
      )

    kernel = self.param(
      'kernel', self.kernel_init, kernel_shape, self.param_dtype
    )

    if self.mask is not None:
      kernel *= self.mask

    if self.use_bias:
      if self.shared_weights:
        # One bias weight per output channel, shared between pixels.
        bias_shape = (self.features,)
      else:
        # One bias weight per output entry, unshared between pixels.
        bias_shape = conv_output_shape[1:]

      bias = self.param('bias', self.bias_init, bias_shape, self.param_dtype)
    else:
      bias = None

    inputs, kernel, bias = promote_dtype(inputs, kernel, bias, dtype=self.dtype)
    if self.shared_weights:
      if self.conv_general_dilated_cls is not None:
        conv_general_dilated = self.conv_general_dilated_cls()
      elif self.conv_general_dilated is not None:
        conv_general_dilated = self.conv_general_dilated
      else:
        conv_general_dilated = lax.conv_general_dilated
      y = conv_general_dilated(
        inputs,
        kernel,
        strides,
        padding_lax,
        lhs_dilation=input_dilation,
        rhs_dilation=kernel_dilation,
        dimension_numbers=dimension_numbers,
        feature_group_count=self.feature_group_count,
        precision=self.precision,
      )
    else:
      y = lax.conv_general_dilated_local(
        lhs=inputs,
        rhs=kernel,
        window_strides=strides,
        padding=padding_lax,
        filter_shape=kernel_size,
        lhs_dilation=input_dilation,
        rhs_dilation=kernel_dilation,
        dimension_numbers=dimension_numbers,
        precision=self.precision,
      )

    if self.use_bias:
      bias = bias.reshape((1,) * (y.ndim - bias.ndim) + bias.shape)
      y += bias

    if num_batch_dimensions != 1:
      output_shape = input_batch_shape + y.shape[1:]
      y = jnp.reshape(y, output_shape)
    return y
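
# Illustrative sketch (not part of the original file): the padding formats
# accepted by _Conv are normalized by the canonicalize_padding helper above,
# e.g.:
#
#   canonicalize_padding('SAME', rank=2)       # -> 'SAME'
#   canonicalize_padding(1, rank=2)            # -> [(1, 1), (1, 1)]
#   canonicalize_padding([2, (1, 3)], rank=2)  # -> [(2, 2), (1, 3)]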
class Conv(_Conv):
  """Convolution Module wrapping `lax.conv_general_dilated`.

  Example usage::

    >>> import flax.linen as nn
    >>> import jax, jax.numpy as jnp

    >>> # valid padding
    >>> layer = nn.Conv(features=4, kernel_size=(3,), padding='VALID')
    >>> out, variables = layer.init_with_output(jax.random.key(0), jnp.ones((1, 8, 3)))
    >>> jax.tree_map(jnp.shape, variables)
    {'params': {'bias': (4,), 'kernel': (3, 3, 4)}}
    >>> out.shape
    (1, 6, 4)
    >>> # circular padding with stride 2
    >>> layer = nn.Conv(features=4, kernel_size=(3, 3), strides=2, padding='CIRCULAR')
    >>> out, variables = layer.init_with_output(jax.random.key(0), jnp.ones((1, 8, 3)))
    >>> jax.tree_map(jnp.shape, variables)
    {'params': {'bias': (4,), 'kernel': (3, 3, 3, 4)}}
    >>> out.shape
    (1, 4, 4)
    >>> # apply lower triangle mask
    >>> mask = jnp.tril(jnp.ones((3, 3, 4)))
    >>> layer = nn.Conv(features=4, kernel_size=(3,), mask=mask, padding='VALID')
    >>> variables = layer.init(jax.random.key(0), jnp.ones((1, 8, 3)))

  Attributes:
    features: number of convolution filters.
    kernel_size: shape of the convolutional kernel.
    strides: an integer or a sequence of `n` integers, representing the
      inter-window strides (default: 1).
    padding: either the string `'SAME'`, the string `'VALID'`, the string
      `'CIRCULAR'` (periodic boundary conditions), or a sequence of `n`
      `(low, high)` integer pairs that give the padding to apply before and
      after each spatial dimension. A single int is interpreted as applying the
      same padding in all dims, and assigning a single int in a sequence causes
      the same padding to be used on both sides. `'CAUSAL'` padding for a 1D
      convolution will left-pad the convolution axis, resulting in same-sized
      output.
    input_dilation: an integer or a sequence of `n` integers, giving the
      dilation factor to apply in each spatial dimension of `inputs`
      (default: 1). Convolution with input dilation `d` is equivalent to
      transposed convolution with stride `d`.
    kernel_dilation: an integer or a sequence of `n` integers, giving the
      dilation factor to apply in each spatial dimension of the convolution
      kernel (default: 1). Convolution with kernel dilation is also known as
      'atrous convolution'.
    feature_group_count: integer, default 1. If specified divides the input
      features into groups.
    use_bias: whether to add a bias to the output (default: True).
    mask: Optional mask for the weights during masked convolution. The mask
      must be the same shape as the convolution weight matrix.
    dtype: the dtype of the computation (default: infer from input and params).
    param_dtype: the dtype passed to parameter initializers (default: float32).
    precision: numerical precision of the computation see `jax.lax.Precision`
      for details.
    kernel_init: initializer for the convolutional kernel.
    bias_init: initializer for the bias.
  """

  @property
  def shared_weights(self) -> bool:
    return True
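
# Illustrative note (not part of the original file): in the Conv examples
# above, the kernel shape is kernel_size + (in_features, features), e.g.
# (3, 3, 4) for kernel_size=(3,) on 3-channel inputs with 4 filters, and
# 'VALID' padding maps a length-8 axis to 8 - 3 + 1 = 6 output positions,
# while 'CIRCULAR' padding preserves the size up to striding: 8 / 2 = 4.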
class ConvLocal(_Conv):
  """Local convolution Module wrapping `lax.conv_general_dilated_local`.

  Example usage::

    >>> import flax.linen as nn
    >>> import jax, jax.numpy as jnp

    >>> # valid padding
    >>> layer = nn.ConvLocal(features=4, kernel_size=(3,), padding='VALID')
    >>> out, variables = layer.init_with_output(jax.random.key(0), jnp.ones((1, 8, 3)))
    >>> jax.tree_map(jnp.shape, variables)
    {'params': {'bias': (6, 4), 'kernel': (6, 9, 4)}}
    >>> out.shape
    (1, 6, 4)
    >>> # circular padding with stride 2
    >>> layer = nn.ConvLocal(features=4, kernel_size=(3, 3), strides=2, padding='CIRCULAR')
    >>> out, variables = layer.init_with_output(jax.random.key(0), jnp.ones((1, 8, 3)))
    >>> jax.tree_map(jnp.shape, variables)
    {'params': {'bias': (1, 4, 4), 'kernel': (1, 4, 27, 4)}}
    >>> out.shape
    (1, 4, 4)
    >>> # apply lower triangle mask
    >>> mask = jnp.tril(jnp.ones((6, 9, 4)))
    >>> layer = nn.ConvLocal(features=4, kernel_size=(3,), mask=mask, padding='VALID')
    >>> variables = layer.init(jax.random.key(0), jnp.ones((1, 8, 3)))

  Attributes:
    features: number of convolution filters.
    kernel_size: shape of the convolutional kernel.
    strides: an integer or a sequence of `n` integers, representing the
      inter-window strides (default: 1).
    padding: either the string `'SAME'`, the string `'VALID'`, the string
      `'CIRCULAR'` (periodic boundary conditions), or a sequence of `n`
      `(low, high)` integer pairs that give the padding to apply before and
      after each spatial dimension. A single int is interpreted as applying the
      same padding in all dims, and assigning a single int in a sequence causes
      the same padding to be used on both sides. `'CAUSAL'` padding for a 1D
      convolution will left-pad the convolution axis, resulting in same-sized
      output.
    input_dilation: an integer or a sequence of `n` integers, giving the
      dilation factor to apply in each spatial dimension of `inputs`
      (default: 1). Convolution with input dilation `d` is equivalent to
      transposed convolution with stride `d`.
    kernel_dilation: an integer or a sequence of `n` integers, giving the
      dilation factor to apply in each spatial dimension of the convolution
      kernel (default: 1). Convolution with kernel dilation is also known as
      'atrous convolution'.
    feature_group_count: integer, default 1. If specified divides the input
      features into groups.
    use_bias: whether to add a bias to the output (default: True).
    mask: Optional mask for the weights during masked convolution. The mask
      must be the same shape as the convolution weight matrix.
    dtype: the dtype of the computation (default: infer from input and params).
    param_dtype: the dtype passed to parameter initializers (default: float32).
    precision: numerical precision of the computation see `jax.lax.Precision`
      for details.
    kernel_init: initializer for the convolutional kernel.
    bias_init: initializer for the bias.
  """

  @property
  def shared_weights(self) -> bool:
    return False
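
# Illustrative note (not part of the original file): in the first ConvLocal
# example above, a length-8 input with kernel_size=(3,) and 'VALID' padding
# gives 8 - 3 + 1 = 6 output positions, and each position owns its own
# flattened kernel of size prod(kernel_size) * in_features = 3 * 3 = 9,
# hence the (6, 9, 4) kernel and the per-position (6, 4) bias.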
class ConvTranspose(Module):
  """Convolution Module wrapping `lax.conv_transpose`.

  Example usage::

    >>> import flax.linen as nn
    >>> import jax, jax.numpy as jnp

    >>> # valid padding
    >>> layer = nn.ConvTranspose(features=4, kernel_size=(3,), padding='VALID')
    >>> out, variables = layer.init_with_output(jax.random.key(0), jnp.ones((1, 8, 3)))
    >>> jax.tree_map(jnp.shape, variables)
    {'params': {'bias': (4,), 'kernel': (3, 3, 4)}}
    >>> out.shape
    (1, 10, 4)
    >>> # circular padding with stride 2
    >>> layer = nn.ConvTranspose(features=4, kernel_size=(6, 6), strides=(2, 2), padding='CIRCULAR', transpose_kernel=True)
    >>> out, variables = layer.init_with_output(jax.random.key(0), jnp.ones((1, 15, 15, 3)))
    >>> jax.tree_map(jnp.shape, variables)
    {'params': {'bias': (4,), 'kernel': (6, 6, 4, 3)}}
    >>> out.shape
    (1, 30, 30, 4)
    >>> # apply lower triangle mask
    >>> mask = jnp.tril(jnp.ones((3, 3, 4)))
    >>> layer = nn.ConvTranspose(features=4, kernel_size=(3,), mask=mask, padding='VALID')
    >>> variables = layer.init(jax.random.key(0), jnp.ones((1, 8, 3)))

  Attributes:
    features: number of convolution filters.
    kernel_size: shape of the convolutional kernel. For 1D convolution,
      the kernel size can be passed as an integer. For all other cases, it must
      be a sequence of integers.
    strides: a sequence of `n` integers, representing the inter-window strides.
    padding: either the string `'SAME'`, the string `'VALID'`, the string
      `'CIRCULAR'` (periodic boundary conditions), or a sequence of `n`
      `(low, high)` integer pairs that give the padding to apply before and
      after each spatial dimension. A single int is interpreted as applying the
      same padding in all dims, and assigning a single int in a sequence causes
      the same padding to be used on both sides.
    kernel_dilation: `None`, or a sequence of `n` integers, giving the
      dilation factor to apply in each spatial dimension of the convolution
      kernel. Convolution with kernel dilation is also known as 'atrous
      convolution'.
    use_bias: whether to add a bias to the output (default: True).
    mask: Optional mask for the weights during masked convolution. The mask
      must be the same shape as the convolution weight matrix.
    dtype: the dtype of the computation (default: infer from input and params).
    param_dtype: the dtype passed to parameter initializers (default: float32).
    precision: numerical precision of the computation see `jax.lax.Precision`
      for details.
    kernel_init: initializer for the convolutional kernel.
    bias_init: initializer for the bias.
    transpose_kernel: if True flips spatial axes and swaps the input/output
      channel axes of the kernel.
  """

  features: int
  kernel_size: Union[int, Sequence[int]]
  strides: Optional[Sequence[int]] = None
  padding: PaddingLike = 'SAME'
  kernel_dilation: Optional[Sequence[int]] = None
  use_bias: bool = True
  mask: Optional[Array] = None
  dtype: Dtype = None
  param_dtype: Dtype = jnp.float32
  precision: PrecisionLike = None
  kernel_init: Callable[[PRNGKey, Shape, Dtype], Array] = default_kernel_init
  bias_init: Callable[
    [PRNGKey, Shape, Dtype], Array
  ] = initializers.zeros_init()
  transpose_kernel: bool = False
  @compact
  def __call__(self, inputs: Array) -> Array:
    """Applies a transposed convolution to the inputs.

    Behaviour mirrors that of `jax.lax.conv_transpose`.

    Args:
      inputs: input data with dimensions (*batch_dims, spatial_dims...,
        features). This is the channels-last convention, i.e. NHWC for a 2d
        convolution and NDHWC for a 3D convolution. Note: this is different
        from the input convention used by `lax.conv_general_dilated`, which
        puts the spatial dimensions last. Note: If the input has more than 1
        batch dimension, all batch dimensions are flattened into a single
        dimension for the convolution and restored before returning. In some
        cases directly vmap'ing the layer may yield better performance than
        this default flattening approach. If the input lacks a batch dimension
        it will be added for the convolution and removed on return, an
        allowance made to enable writing single-example code.

    Returns:
      The convolved data.
    """
    kernel_size: Tuple[int, ...]
    if isinstance(self.kernel_size, int):
      kernel_size = (self.kernel_size,)
    else:
      kernel_size = tuple(self.kernel_size)

    # Combine all input batch dimensions into a single leading batch axis.
    num_batch_dimensions = inputs.ndim - (len(kernel_size) + 1)
    if num_batch_dimensions != 1:
      input_batch_shape = inputs.shape[:num_batch_dimensions]
      total_batch_size = int(np.prod(input_batch_shape))
      flat_input_shape = (total_batch_size,) + inputs.shape[
        num_batch_dimensions:
      ]
      inputs = jnp.reshape(inputs, flat_input_shape)

    strides: Tuple[int, ...]
    if self.strides is None:
      strides = (1,) * (inputs.ndim - 2)
    else:
      strides = tuple(self.strides)

    in_features = jnp.shape(inputs)[-1]
    if self.transpose_kernel:
      kernel_shape = kernel_size + (self.features, in_features)
    else:
      kernel_shape = kernel_size + (in_features, self.features)

    if self.mask is not None and self.mask.shape != kernel_shape:
      raise ValueError(
        'Mask needs to have the same shape as weights. '
        f'Shapes are: {self.mask.shape}, {kernel_shape}'
      )

    kernel = self.param(
      'kernel', self.kernel_init, kernel_shape, self.param_dtype
    )

    if self.mask is not None:
      kernel *= self.mask

    padding_lax = canonicalize_padding(self.padding, len(kernel_size))
    if padding_lax == 'CIRCULAR':
      padding_lax = 'VALID'

    if self.use_bias:
      bias = self.param(
        'bias', self.bias_init, (self.features,), self.param_dtype
      )
    else:
      bias = None

    inputs, kernel, bias = promote_dtype(inputs, kernel, bias, dtype=self.dtype)

    y = lax.conv_transpose(
      inputs,
      kernel,
      strides,
      padding_lax,
      rhs_dilation=self.kernel_dilation,
      transpose_kernel=self.transpose_kernel,
      precision=self.precision,
    )

    if self.padding == 'CIRCULAR':
      # For circular padding, we need to identify the size of the final output
      # ("period") along each spatial dimension, pad each dimension to an
      # integer number of periods, and wrap the array periodically around each
      # dimension. Padding should be done in such a way that the start of the
      # original input data inside the padded array is located at integer
      # number of periods - otherwise the result would be circularly shifted.

      # Compute period along each spatial dimension - it's input size scaled
      # by the stride.
      scaled_x_dims = [
        x_dim * stride
        for x_dim, stride in zip(jnp.shape(inputs)[1:-1], strides)
      ]
      # Compute difference between the current size of y and the final output
      # size, and complement this difference to 2 * period - that gives how
      # much we need to pad.
      size_diffs = [
        -(y_dim - x_dim) % (2 * x_dim)
        for y_dim, x_dim in zip(y.shape[1:-1], scaled_x_dims)
      ]
      if self.transpose_kernel:
        # If the kernel is transposed, the "+1" is put on the right to
        # mirror the regular convolution. If the same kernel parameters are
        # used as for Conv, this layer then computes the proper transpose
        # convolution.
        total_pad = [
          (size_diff // 2, (size_diff + 1) // 2) for size_diff in size_diffs
        ]
      else:
        # Divide the padding equally between left and right. The choice to put
        # "+1" on the left (and not on the right) represents a convention for
        # aligning even-sized kernels.
        total_pad = [
          ((size_diff + 1) // 2, size_diff // 2) for size_diff in size_diffs
        ]
      y = jnp.pad(y, [(0, 0)] + total_pad + [(0, 0)])
      # Wrap the result periodically around each spatial dimension,
      # one by one.
      for i in range(1, y.ndim - 1):
        y = y.reshape(
          y.shape[:i] + (-1, scaled_x_dims[i - 1]) + y.shape[i + 1 :]
        )
        y = y.sum(axis=i)

    if self.use_bias:
      y += jnp.reshape(bias, (1,) * (y.ndim - 1) + (-1,))

    if num_batch_dimensions != 1:
      output_shape = input_batch_shape + y.shape[1:]
      y = jnp.reshape(y, output_shape)
    return y
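
# Illustrative note (not part of the original file): in the CIRCULAR example
# from the ConvTranspose docstring (15x15 input, stride 2, 6x6 kernel),
# conv_transpose with 'VALID' padding yields spatial size
# (15 - 1) * 2 + 6 = 34, the period is 15 * 2 = 30, so each spatial dimension
# is padded by -(34 - 30) % 60 = 56 up to 90 = 3 periods, then wrapped and
# summed back down to the final size of 30.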
default_embed_init = initializers.variance_scaling(
  1.0, 'fan_in', 'normal', out_axis=0
)
class Embed(Module):
  """Embedding Module.

  A parameterized function from integers [0, n) to d-dimensional vectors.

  Example usage::

    >>> import flax.linen as nn
    >>> import jax, jax.numpy as jnp

    >>> layer = nn.Embed(num_embeddings=4, features=3)
    >>> variables = layer.init(jax.random.key(0), jnp.ones((1, 5), dtype=int))
    >>> jax.tree_map(jnp.shape, variables)
    {'params': {'embedding': (4, 3)}}
    >>> layer.apply(variables, jnp.ones((5,), dtype=int)).shape
    (5, 3)
    >>> layer.apply(variables, jnp.ones((5, 6), dtype=int)).shape
    (5, 6, 3)

  Attributes:
    num_embeddings: number of embeddings.
    features: number of feature dimensions for each embedding.
    dtype: the dtype of the embedding vectors (default: same as embedding).
    param_dtype: the dtype passed to parameter initializers (default: float32).
    embedding_init: embedding initializer.
  """

  num_embeddings: int
  features: int
  dtype: Optional[Dtype] = None
  param_dtype: Dtype = jnp.float32
  embedding_init: Callable[[PRNGKey, Shape, Dtype], Array] = default_embed_init

  embedding: Array = dataclasses.field(init=False)

  def setup(self):
    self.embedding = self.param(
      'embedding',
      self.embedding_init,
      (self.num_embeddings, self.features),
      self.param_dtype,
    )
  def __call__(self, inputs: Array) -> Array:
    """Embeds the inputs along the last dimension.

    Args:
      inputs: input data, all dimensions are considered batch dimensions.

    Returns:
      Output which is embedded input data. The output shape follows the input,
      with an additional `features` dimension appended.
    """
    if not jnp.issubdtype(inputs.dtype, jnp.integer):
      raise ValueError('Input type must be an integer or unsigned integer.')
    # Use take because fancy indexing numpy arrays with JAX indices does not
    # work correctly.
    (embedding,) = promote_dtype(
      self.embedding, dtype=self.dtype, inexact=False
    )
    return jnp.take(embedding, inputs, axis=0)
  def attend(self, query: Array) -> Array:
    """Attend over the embedding using a query array.

    Args:
      query: array with last dimension equal to the feature depth `features`
        of the embedding.

    Returns:
      An array with final dim `num_embeddings` corresponding to the batched
      inner-product of the array of query vectors against each embedding.
      Commonly used for weight-sharing between embeddings and logit transform
      in NLP models.
    """
    query, embedding = promote_dtype(query, self.embedding, dtype=self.dtype)
    return jnp.dot(query, embedding.T)
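
# Illustrative sketch (not part of the original file): `attend` is commonly
# used to tie input embeddings to output logits in language models. Assuming
# hypothetical names `vocab_size`, `d_model`, and `token_ids`:
#
#   embed = nn.Embed(num_embeddings=vocab_size, features=d_model)
#   x = embed(token_ids)      # (..., seq_len, d_model)
#   logits = embed.attend(x)  # (..., seq_len, vocab_size), i.e. x @ embedding.T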