flax.linen.avg_pool

flax.linen.avg_pool(inputs, window_shape, strides=None, padding='VALID', count_include_pad=True)

Pools the input by taking the average over a window.
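For intuition, here is a minimal pure-Python sketch of what the operation computes on a single 1D sequence with one feature and no batch dimension, using 'VALID' padding. `avg_pool_1d` is a hypothetical illustration helper, not part of Flax. The real function operates on arrays shaped (batch, window dims…, features).

```python
def avg_pool_1d(xs, window, stride=1):
    """Hypothetical helper: average each length-`window` slice of `xs`,
    stepping by `stride`, with 'VALID' padding (no padding at all)."""
    out = []
    # With 'VALID' padding, a window is emitted only if it fits entirely inside xs.
    for start in range(0, len(xs) - window + 1, stride):
        out.append(sum(xs[start:start + window]) / window)
    return out

print(avg_pool_1d([1.0, 2.0, 3.0, 4.0, 5.0], window=2))            # [1.5, 2.5, 3.5, 4.5]
print(avg_pool_1d([1.0, 2.0, 3.0, 4.0, 5.0], window=2, stride=2))  # [1.5, 3.5]
```

With stride 2 the windows no longer overlap, so the output is roughly half as long.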

Parameters
  • inputs – input data with dimensions (batch, window dims…, features).

  • window_shape – a shape tuple of n integers defining the window to reduce over, one entry per window dimension.

  • strides – a sequence of n integers giving the step between successive windows along each window dimension (default: (1, ..., 1)).

  • padding – either the string 'SAME', the string 'VALID', or a sequence of n (low, high) integer pairs that give the padding to apply before and after each spatial dimension (default: 'VALID').

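  • count_include_pad – a boolean indicating whether padded positions count toward the divisor of each average (default: True). If False, each window's sum is divided only by the number of non-padded elements it covers.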
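The interaction between `padding` and `count_include_pad` can be sketched in pure Python for one 1D sequence (one feature, no batch). This is a hypothetical reference sketch, not Flax's code. It assumes padded positions contribute zeros to the sum, and that 'SAME' splits the padding the way `jax.lax.padtype_to_pads` does (any odd extra element goes on the high side).

```python
import math

def same_pads(n, window, stride):
    """Hypothetical helper: (low, high) padding for 'SAME', assuming the
    jax.lax.padtype_to_pads convention (output length = ceil(n / stride))."""
    out_len = math.ceil(n / stride)
    total = max((out_len - 1) * stride + window - n, 0)
    return (total // 2, total - total // 2)

def avg_pool_1d(xs, window, stride=1, padding="VALID", count_include_pad=True):
    """Hypothetical 1D average pool supporting 'VALID', 'SAME' or a (low, high) pair."""
    n = len(xs)
    if padding == "VALID":
        low, high = 0, 0
    elif padding == "SAME":
        low, high = same_pads(n, window, stride)
    else:
        low, high = padding  # explicit (low, high) pair
    padded = [0.0] * low + list(xs) + [0.0] * high
    mask = [0] * low + [1] * n + [0] * high  # 1 = real element, 0 = padding
    out = []
    for start in range(0, len(padded) - window + 1, stride):
        total = sum(padded[start:start + window])
        if count_include_pad:
            denom = window  # padded zeros count toward the window size
        else:
            # Only real elements count; a window made entirely of padding
            # would divide by zero in this sketch.
            denom = sum(mask[start:start + window])
        out.append(total / denom)
    return out

xs = [1.0, 2.0, 3.0]
print(avg_pool_1d(xs, 2, padding="SAME"))                           # [1.5, 2.5, 1.5]
print(avg_pool_1d(xs, 2, padding="SAME", count_include_pad=False))  # [1.5, 2.5, 3.0]
```

Only the last window touches padding, so only its value changes: it is 3.0 / 2 when the padded zero is counted and 3.0 / 1 when it is not.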

Returns

The average for each window slice. The batch and feature dimensions are preserved; only the window dimensions change size.
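The size of each window dimension in the result follows from the window, stride, and padding. The helper below is a hypothetical sketch of that arithmetic. It assumes the usual `jax.lax` conventions ('VALID' adds no padding, and 'SAME' yields ceil(n / stride) outputs) and assumes the padded axis is at least as long as the window.

```python
import math

def pooled_length(n, window, stride=1, padding="VALID"):
    """Hypothetical helper: length of one window dimension after pooling."""
    if padding == "VALID":
        return (n - window) // stride + 1
    if padding == "SAME":
        return math.ceil(n / stride)
    low, high = padding  # explicit (low, high) pair
    return (n + low + high - window) // stride + 1

# A (batch, 32, 32, features) input with window_shape=(2, 2), strides=(2, 2)
# and 'VALID' padding would come out as (batch, 16, 16, features).
print(pooled_length(32, 2, 2))  # 16
```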