Pooling functions


flax.nnx.avg_pool(inputs, window_shape, strides=None, padding='VALID', count_include_pad=True)

Pools the input by taking the average over a window.

Parameters:
  • inputs – input data with dimensions (batch, window dims…, features).

  • window_shape – a shape tuple defining the window to reduce over.

  • strides – a sequence of n integers, representing the inter-window strides (default: (1, ..., 1)).

  • padding – either the string 'SAME', the string 'VALID', or a sequence of n (low, high) integer pairs that give the padding to apply before and after each spatial dimension (default: 'VALID').

  • count_include_pad – a boolean indicating whether padded elements are counted in the average (default: True).

Returns:

The average for each window slice.

flax.nnx.max_pool(inputs, window_shape, strides=None, padding='VALID')

Pools the input by taking the maximum of a window slice.

Parameters:
  • inputs – input data with dimensions (batch, window dims…, features).

  • window_shape – a shape tuple defining the window to reduce over.

  • strides – a sequence of n integers, representing the inter-window strides (default: (1, ..., 1)).

  • padding – either the string 'SAME', the string 'VALID', or a sequence of n (low, high) integer pairs that give the padding to apply before and after each spatial dimension (default: 'VALID').

Returns:

The maximum for each window slice.

flax.nnx.min_pool(inputs, window_shape, strides=None, padding='VALID')

Pools the input by taking the minimum of a window slice.

Parameters:
  • inputs – input data with dimensions (batch, window dims…, features).

  • window_shape – a shape tuple defining the window to reduce over.

  • strides – a sequence of n integers, representing the inter-window strides (default: (1, ..., 1)).

  • padding – either the string 'SAME', the string 'VALID', or a sequence of n (low, high) integer pairs that give the padding to apply before and after each spatial dimension (default: 'VALID').

Returns:

The minimum for each window slice.

flax.nnx.pool(inputs, init, reduce_fn, window_shape, strides, padding)

Helper function to define pooling functions.

Pooling functions are implemented using the ReduceWindow XLA op.

Note

Be aware that pooling is not generally differentiable. In particular, a differentiable reduce_fn does not imply that pool itself is differentiable.

Parameters:
  • inputs – input data with dimensions (batch, window dims…, features).

  • init – the initial value for the reduction.

  • reduce_fn – a reduce function of the form (T, T) -> T.

  • window_shape – a shape tuple defining the window to reduce over.

  • strides – a sequence of n integers, representing the inter-window strides. The argument has no default in the signature; passing None selects (1, ..., 1).

  • padding – either the string 'SAME', the string 'VALID', or a sequence of n (low, high) integer pairs that give the padding to apply before and after each spatial dimension.

Returns:

The output of the reduction for each window slice.