Batch normalization

In this guide, you will learn how to apply batch normalization using flax.linen.BatchNorm.

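Batch normalization is a regularization technique used to speed up training and improve convergence. During training, it normalizes each feature using the mean and variance computed over the current batch. It also maintains running averages of these statistics, which are used in place of the batch statistics at inference time. These running averages are a new form of non-differentiable state that must be handled appropriately.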
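To make the idea of running averages concrete, here is a minimal sketch in plain jax.numpy of what a batch normalization layer computes for inputs of shape (batch, features). It assumes the defaults momentum=0.99 and epsilon=1e-5 used by flax.linen.BatchNorm. The helper names batch_norm_train and batch_norm_eval are hypothetical, and this is an illustration, not Flax's actual implementation.

```python
import jax.numpy as jnp


def batch_norm_train(x, scale, bias, ra_mean, ra_var, momentum=0.99, eps=1e-5):
    """Hypothetical helper: normalize x of shape (batch, features) with batch statistics.

    Returns the normalized output and the updated running averages.
    """
    # Statistics are reduced over the batch axis: one mean/variance per feature.
    mean = jnp.mean(x, axis=0)
    var = jnp.var(x, axis=0)
    y = (x - mean) / jnp.sqrt(var + eps) * scale + bias
    # Exponential moving averages: this is the non-differentiable state.
    new_ra_mean = momentum * ra_mean + (1.0 - momentum) * mean
    new_ra_var = momentum * ra_var + (1.0 - momentum) * var
    return y, new_ra_mean, new_ra_var


def batch_norm_eval(x, scale, bias, ra_mean, ra_var, eps=1e-5):
    """Hypothetical helper: at inference, the running averages replace batch statistics."""
    return (x - ra_mean) / jnp.sqrt(ra_var + eps) * scale + bias


# Usage: a batch of 4 examples with 3 features; running stats start at mean 0, var 1.
x = jnp.arange(12.0).reshape(4, 3)
y, m, v = batch_norm_train(x, jnp.ones(3), jnp.zeros(3), jnp.zeros(3), jnp.ones(3))
y_eval = batch_norm_eval(x, jnp.ones(3), jnp.zeros(3), m, v)
```

The running averages are never touched by gradients. They are simply overwritten on every training step, rather than learned.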

Throughout the guide, you can compare code examples with and without Flax BatchNorm. In each pair, the version without BatchNorm comes first, followed by the version with BatchNorm.

Defining the model with BatchNorm

In Flax, BatchNorm is a flax.linen.Module that exhibits different runtime behavior between training and inference. You explicitly specify it via the use_running_average argument, as demonstrated below.

A common pattern is to accept a train (training) argument in the parent Flax Module, and use it to define BatchNorm’s use_running_average argument.

Note: In other machine learning frameworks, like PyTorch or TensorFlow (Keras), this behavior is selected via mutable state or a call flag (for example, torch.nn.Module.eval in PyTorch, or the training argument when calling a tf.keras.Model).

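Without BatchNorm:

import flax.linen as nn

class MLP(nn.Module):
  @nn.compact
  def __call__(self, x):
    x = nn.Dense(features=4)(x)
    x = nn.relu(x)
    x = nn.Dense(features=1)(x)
    return x

With BatchNorm:

import flax.linen as nn

class MLP(nn.Module):
  @nn.compact
  def __call__(self, x, train: bool):
    x = nn.Dense(features=4)(x)
    # Training: normalize with batch statistics and update the running averages.
    # Inference: normalize with the stored running averages.
    x = nn.BatchNorm(use_running_average=not train)(x)
    x = nn.relu(x)
    x = nn.Dense(features=1)(x)
    return x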

Once you have created your model, initialize it by calling its init method (flax.linen.Module.init) to get the variables structure. The main difference between the code without and with BatchNorm is that the train argument must now be provided.

The batch_stats collection

In addition to the params collection, BatchNorm also adds a batch_stats collection that contains the running average of the batch statistics.

Note: You can learn more in the flax.linen variables API documentation.

The batch_stats collection must be extracted from the variables for later use.

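Without BatchNorm:

import jax
import jax.numpy as jnp

mlp = MLP()
x = jnp.ones((1, 3))
variables = mlp.init(jax.random.key(0), x)
params = variables['params']

jax.tree_util.tree_map(jnp.shape, variables)

With BatchNorm:

import jax
import jax.numpy as jnp

mlp = MLP()
x = jnp.ones((1, 3))
# The train argument must be provided; either value creates the same variables.
variables = mlp.init(jax.random.key(0), x, train=False)
params = variables['params']
# Extract the running statistics for later use.
batch_stats = variables['batch_stats']

jax.tree_util.tree_map(jnp.shape, variables)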

Flax BatchNorm adds a total of 4 variables: mean and var that live in the batch_stats collection, and scale and bias that live in the params collection.

Without BatchNorm:

FrozenDict({
  'params': {
    'Dense_0': {
        'bias': (4,),
        'kernel': (3, 4),
    },
    'Dense_1': {
        'bias': (1,),
        'kernel': (4, 1),
    },
  },
})

With BatchNorm:

FrozenDict({
  'batch_stats': {
    'BatchNorm_0': {
        'mean': (4,),
        'var': (4,),
    },
  },
  'params': {
    'BatchNorm_0': {
        'bias': (4,),
        'scale': (4,),
    },
    'Dense_0': {
        'bias': (4,),
        'kernel': (3, 4),
    },
    'Dense_1': {
        'bias': (1,),
        'kernel': (4, 1),
    },
  },
})

Modifying flax.linen.apply

When you call your model's apply method (flax.linen.Module.apply) with train=True (that is, with use_running_average=False in the call to BatchNorm), you need to consider the following:

  • batch_stats must be passed as an input variable.

  • The batch_stats collection must be marked as mutable by setting mutable=['batch_stats'].

  • The mutated variables are returned as a second output, and the updated batch_stats must be extracted from it.

Without BatchNorm:

y = mlp.apply(
  {'params': params},
  x,
)
...

With BatchNorm:

y, updates = mlp.apply(
  {'params': params, 'batch_stats': batch_stats},
  x,
  train=True, mutable=['batch_stats']
)
# With mutable set, apply returns (output, mutated_variables).
batch_stats = updates['batch_stats']

Training and evaluation

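When integrating models that use BatchNorm into a training loop, the main challenge is handling the additional batch_stats state. To do this, you need to:

  • Add a batch_stats field to a custom subclass of flax.training.train_state.TrainState.

  • Pass the batch_stats values to the TrainState.create method.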

Without BatchNorm:

import optax
from flax.training import train_state

state = train_state.TrainState.create(
  apply_fn=mlp.apply,
  params=params,
  tx=optax.adam(1e-3),
)

With BatchNorm:

from typing import Any

import optax
from flax.training import train_state

class TrainState(train_state.TrainState):
  # Extra field that carries the BatchNorm running statistics.
  batch_stats: Any

state = TrainState.create(
  apply_fn=mlp.apply,
  params=params,
  batch_stats=batch_stats,
  tx=optax.adam(1e-3),
)

In addition, update your train_step function to reflect these changes:

  • Pass the new arguments to flax.linen.apply (the batch_stats variables, train=True, and mutable=['batch_stats']), as previously discussed.

  • Propagate the updates to batch_stats out of loss_fn, for example as part of its auxiliary output.

  • Update the batch_stats stored in the TrainState.

Without BatchNorm:

import jax
import jax.numpy as jnp
import optax
from flax.training import train_state

@jax.jit
def train_step(state: train_state.TrainState, batch):
  """Train for a single step."""
  def loss_fn(params):
    logits = state.apply_fn(
      {'params': params},
      x=batch['image'])
    # Average the per-example losses so that the loss is a scalar.
    loss = optax.softmax_cross_entropy_with_integer_labels(
      logits=logits, labels=batch['label']).mean()
    return loss, logits
  grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
  (loss, logits), grads = grad_fn(state.params)
  state = state.apply_gradients(grads=grads)
  metrics = {
    'loss': loss,
    'accuracy': jnp.mean(jnp.argmax(logits, -1) == batch['label']),
  }
  return state, metrics

With BatchNorm:

import jax
import jax.numpy as jnp
import optax

@jax.jit
def train_step(state: TrainState, batch):
  """Train for a single step."""
  def loss_fn(params):
    # Run in training mode and collect the updated batch_stats.
    logits, updates = state.apply_fn(
      {'params': params, 'batch_stats': state.batch_stats},
      x=batch['image'], train=True, mutable=['batch_stats'])
    # Average the per-example losses so that the loss is a scalar.
    loss = optax.softmax_cross_entropy_with_integer_labels(
      logits=logits, labels=batch['label']).mean()
    # Return the updates as auxiliary output so they escape loss_fn.
    return loss, (logits, updates)
  grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
  (loss, (logits, updates)), grads = grad_fn(state.params)
  state = state.apply_gradients(grads=grads)
  # Store the new running statistics in the TrainState.
  state = state.replace(batch_stats=updates['batch_stats'])
  metrics = {
    'loss': loss,
    'accuracy': jnp.mean(jnp.argmax(logits, -1) == batch['label']),
  }
  return state, metrics

The eval_step is much simpler. Because batch_stats is not marked as mutable during evaluation, apply returns only the model output, and there are no updates to propagate. Make sure you pass batch_stats to flax.linen.apply and set the train argument to False:

Without BatchNorm:

@jax.jit
def eval_step(state: train_state.TrainState, batch):
  """Evaluate for a single step."""
  logits = state.apply_fn(
    {'params': state.params},
    x=batch['image'])
  loss = optax.softmax_cross_entropy_with_integer_labels(
    logits=logits, labels=batch['label']).mean()
  metrics = {
    'loss': loss,
    'accuracy': jnp.mean(jnp.argmax(logits, -1) == batch['label']),
  }
  return state, metrics

With BatchNorm:

@jax.jit
def eval_step(state: TrainState, batch):
  """Evaluate for a single step."""
  # batch_stats is read-only here, so apply returns only the logits.
  logits = state.apply_fn(
    {'params': state.params, 'batch_stats': state.batch_stats},
    x=batch['image'], train=False)
  loss = optax.softmax_cross_entropy_with_integer_labels(
    logits=logits, labels=batch['label']).mean()
  metrics = {
    'loss': loss,
    'accuracy': jnp.mean(jnp.argmax(logits, -1) == batch['label']),
  }
  return state, metrics