Pooling functions
- flax.nnx.avg_pool(inputs, window_shape, strides=None, padding='VALID', count_include_pad=True)
Pools the input by taking the average over a window.
- Parameters:
inputs – input data with dimensions (batch, window dims…, features).
window_shape – a shape tuple defining the window to reduce over.
strides – a sequence of n integers (one per window dimension), representing the inter-window strides (default: (1, ..., 1)).
padding – either the string 'SAME', the string 'VALID', or a sequence of n (low, high) integer pairs that give the padding to apply before and after each spatial dimension (default: 'VALID').
count_include_pad – a boolean indicating whether to include padded tokens in the average calculation (default: True).
- Returns:
The average for each window slice.
- flax.nnx.max_pool(inputs, window_shape, strides=None, padding='VALID')
Pools the input by taking the maximum of a window slice.
- Parameters:
inputs – input data with dimensions (batch, window dims…, features).
window_shape – a shape tuple defining the window to reduce over.
strides – a sequence of n integers (one per window dimension), representing the inter-window strides (default: (1, ..., 1)).
padding – either the string 'SAME', the string 'VALID', or a sequence of n (low, high) integer pairs that give the padding to apply before and after each spatial dimension (default: 'VALID').
- Returns:
The maximum for each window slice.
- flax.nnx.min_pool(inputs, window_shape, strides=None, padding='VALID')
Pools the input by taking the minimum of a window slice.
- Parameters:
inputs – Input data with dimensions (batch, window dims…, features).
window_shape – A shape tuple defining the window to reduce over.
strides – A sequence of n integers (one per window dimension), representing the inter-window strides (default: (1, ..., 1)).
padding – Either the string 'SAME', the string 'VALID', or a sequence of n (low, high) integer pairs that give the padding to apply before and after each spatial dimension (default: 'VALID').
- Returns:
The minimum for each window slice.
- flax.nnx.pool(inputs, init, reduce_fn, window_shape, strides, padding)
Helper function to define pooling functions.
Pooling functions are implemented using the ReduceWindow XLA op.
Note
Be aware that pooling is not generally differentiable: providing a differentiable reduce_fn does not imply that pool is differentiable.
- Parameters:
inputs – input data with dimensions (batch, window dims…, features).
init – the initial value for the reduction.
reduce_fn – a reduce function of the form (T, T) -> T.
window_shape – a shape tuple defining the window to reduce over.
strides – a sequence of n integers (one per window dimension), representing the inter-window strides (default: (1, ..., 1)).
padding – either the string 'SAME', the string 'VALID', or a sequence of n (low, high) integer pairs that give the padding to apply before and after each spatial dimension.
- Returns:
The output of the reduction for each window slice.