flax.linen.pool(inputs, init, reduce_fn, window_shape, strides, padding)

Helper function to define pooling functions.

Pooling functions are implemented using the ReduceWindow XLA op (jax.lax.reduce_window).

NOTE: Pooling is not differentiable in general. Providing a differentiable reduce_fn does not imply that pool is differentiable.
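
The mapping onto ReduceWindow can be sketched directly with jax.lax.reduce_window. The helper `pool_sketch` below is hypothetical: it assumes exactly one leading batch dimension and is an approximation, not Flax's source. The batch and feature axes get window size 1 and stride 1, so only the spatial axes are reduced. The second part checks that the standard sum reduction (lax.add with init 0.0) can be differentiated with jax.grad. As we understand it, this is one of the built-in reductions JAX knows how to differentiate, unlike an arbitrary reduce_fn.

```python
import jax
import jax.numpy as jnp
from jax import lax


def pool_sketch(inputs, init, reduce_fn, window_shape, strides, padding):
    """Hypothetical approximation of pool for inputs of shape (batch, spatial..., features)."""
    strides = strides or (1,) * len(window_shape)
    # Window size 1 and stride 1 on the batch and feature axes.
    dims = (1,) + tuple(window_shape) + (1,)
    full_strides = (1,) + tuple(strides) + (1,)
    if not isinstance(padding, str):
        # Explicit (low, high) pairs: no padding on the batch and feature axes.
        padding = ((0, 0),) + tuple(map(tuple, padding)) + ((0, 0),)
    return lax.reduce_window(inputs, init, reduce_fn, dims, full_strides, padding)


# Sum pooling over windows of length 2 with stride 1.
x = jnp.arange(4.0).reshape(1, 4, 1)  # values [0, 1, 2, 3]
y = pool_sketch(x, 0.0, lax.add, (2,), (1,), 'VALID')
# y[0, :, 0] holds [1., 3., 5.]

# Each input element contributes once per window that covers it.
g = jax.grad(lambda v: pool_sketch(v, 0.0, lax.add, (2,), (1,), 'VALID').sum())(x)
# g[0, :, 0] holds [1., 2., 2., 1.]
```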

Parameters:

  • inputs – input data with dimensions (batch, window dims…, features).

  • init – the initial value for the reduction, normally the identity element of reduce_fn (for example 0 for lax.add or -inf for lax.max).

  • reduce_fn – a reduce function of the form (T, T) -> T.

  • window_shape – a shape tuple defining the window to reduce over.

  • strides – a sequence of n integers, where n = len(window_shape), giving the inter-window strides. If None, it defaults to (1, …, 1).

  • padding – either the string ‘SAME’, the string ‘VALID’, or a sequence of n (low, high) integer pairs, one per spatial dimension, giving the padding to apply before and after that dimension. Padded positions take the value init.
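
Returns:
  The output of the reduction for each window slice. The batch and feature dimensions are kept; the size of each spatial dimension is determined by window_shape, strides, and padding.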
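
To make the output size concrete, the hypothetical helper `pooled_length` below computes the length of one spatial axis after pooling for each padding mode, and compares it with the shape that jax.lax.reduce_window produces. It assumes n >= window for ‘VALID’ padding. The formulas are the standard ones for strided windows ('SAME' yields ceil(n / stride) outputs) and are checked here on a single example rather than derived from Flax's source.

```python
import math

import jax.numpy as jnp
from jax import lax


def pooled_length(n, window, stride, padding):
    """Hypothetical helper: length of one spatial axis after pooling.

    Assumes n >= window when padding is 'VALID'.
    """
    if padding == 'SAME':
        # Padding is chosen so that every stride step produces one window.
        return math.ceil(n / stride)
    low, high = (0, 0) if padding == 'VALID' else padding
    return (n + low + high - window) // stride + 1


# Input of shape (batch=1, length=7, features=1), window 3, stride 2.
x = jnp.ones((1, 7, 1))
for pad in ['VALID', 'SAME', (1, 1)]:
    lax_pad = pad if isinstance(pad, str) else ((0, 0), pad, (0, 0))
    y = lax.reduce_window(x, 0.0, lax.add, (1, 3, 1), (1, 2, 1), lax_pad)
    # Lengths: VALID -> 3, SAME -> 4, (1, 1) -> 4.
    print(pad, y.shape[1], pooled_length(7, 3, 2, pad))
```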