flax.linen.pool
- flax.linen.pool(inputs, init, reduce_fn, window_shape, strides, padding)
Helper function to define pooling functions.
Pooling functions are implemented using the ReduceWindow XLA op. NOTE: Pooling is not differentiable in general, so providing a differentiable reduce_fn does not guarantee that pool itself is differentiable.
- Parameters
inputs – input data with dimensions (batch, window dims…, features).
init – the initial value for the reduction.
reduce_fn – a reduce function of the form (T, T) -> T.
window_shape – a shape tuple defining the window to reduce over.
strides – a sequence of n integers, one per window dimension, giving the inter-window strides (default: (1, …, 1)).
padding – either the string 'SAME', the string 'VALID', or a sequence of n (low, high) integer pairs that give the padding to apply before and after each spatial dimension.
- Returns
The output of the reduction for each window slice.