Managing Parameters and State

We will show you how to…

  • manage variables from initialization through updates.

  • split and re-assemble parameters and state.

  • use vmap with batch-dependent state.

import flax.linen as nn
import jax.numpy as jnp

class BiasAdderWithRunningMean(nn.Module):
  momentum: float = 0.9

  @nn.compact
  def __call__(self, x):
    is_initialized = self.has_variable('batch_stats', 'mean')  # False only during init.
    # The running mean is a state variable in the 'batch_stats' collection.
    mean = self.variable('batch_stats', 'mean', jnp.zeros, x.shape[1:])
    # The bias is a learnable parameter in the 'params' collection.
    bias = self.param('bias', lambda rng, shape: jnp.zeros(shape), x.shape[1:])
    if is_initialized:
      mean.value = (self.momentum * mean.value +
                    (1.0 - self.momentum) * jnp.mean(x, axis=0, keepdims=True))
    return mean.value + bias

This is a minimal example of a model that contains both parameters (declared with self.param) and state variables (declared with self.variable).
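
To make the distinction concrete, we can inspect what initialization returns. A minimal sketch (the input shape is an arbitrary assumption):

import jax
from jax import random

model = BiasAdderWithRunningMean()
variables = model.init(random.key(0), jnp.ones((10, 4)))  # Hypothetical input shape.
jax.tree_util.tree_map(jnp.shape, variables)
# {'batch_stats': {'mean': (4,)}, 'params': {'bias': (4,)}}

The 'params' collection holds what the optimizer will update, while 'batch_stats' holds the running mean.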

The tricky part with initialization here is that we need to split the state variables from the parameters we're going to optimize.

First we define update_step as follows (with a dummy loss that should be replaced with yours):

def update_step(apply_fn, x, opt_state, params, state):
  def loss(params):
    # Apply the model; `mutable` makes the listed collections updatable and
    # returns their new values alongside the output.
    y, updated_state = apply_fn({'params': params, **state},
                                x, mutable=list(state.keys()))
    l = ((x - y) ** 2).sum()  # Replace with your loss here.
    return l, updated_state

  (l, updated_state), grads = jax.value_and_grad(
      loss, has_aux=True)(params)
  updates, opt_state = tx.update(grads, opt_state)  # Defined below.
  params = optax.apply_updates(params, updates)
  return opt_state, params, updated_state
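
In practice you will usually want to jit this step. Since apply_fn is a function argument, it must be marked static; a sketch of how that could look:

from functools import partial
import jax

@partial(jax.jit, static_argnums=(0,))  # `apply_fn` is a function, so mark it static.
def update_step(apply_fn, x, opt_state, params, state):
  ...  # Body as above.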

Then we can write the actual training code.
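
The loop below additionally assumes imports of flax and optax, plus placeholder values for dummy_input and num_epochs. A minimal sketch of these assumptions (shapes and epoch count are arbitrary):

import flax
import optax
from jax import random
import jax.numpy as jnp

dummy_input = jnp.ones((10, 4))  # Hypothetical batch: 10 examples, 4 features.
num_epochs = 5                   # Arbitrary, for illustration only.

With those in place: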

model = BiasAdderWithRunningMean()
variables = model.init(random.key(0), dummy_input)
# Split state and params (which are updated by optimizer).
state, params = flax.core.pop(variables, 'params')
del variables  # Delete variables to avoid wasting resources
tx = optax.sgd(learning_rate=0.02)
opt_state = tx.init(params)

for _ in range(num_epochs):
  opt_state, params, state = update_step(
      model.apply, dummy_input, opt_state, params, state)
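
To use the trained model, simply re-assemble the two collections when calling apply. A minimal sketch:

# Re-combine params and state for a forward pass with the trained model.
y = model.apply({'params': params, **state}, dummy_input)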

vmap across the batch dimension

When using vmap and managing state that depends on the batch dimension, for example when using BatchNorm, the setup above must be modified slightly. This is because any layer whose state depends on the batch dimension is not strictly vectorizable. In the case of BatchNorm, lax.pmean() must be used to average the statistics over the batch dimension so that the state is in sync for each item in the batch.
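
To build intuition for what lax.pmean() does here, consider a toy sketch (not part of the model below): under vmap, each batch member sees only its own slice, and pmean over the named axis averages across all members:

import jax
import jax.numpy as jnp

def per_example_stat(x):
  # Inside vmap, `x` is a single batch element; pmean averages it
  # across the mapped 'batch' axis.
  return jax.lax.pmean(x, axis_name='batch')

print(jax.vmap(per_example_stat, axis_name='batch')(jnp.arange(4.0)))
# [1.5 1.5 1.5 1.5] -- every member ends up with the same batch mean.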

This requires two small changes. Firstly, we need to name the batch axis in our model definition. Here, this is done by specifying the axis_name argument of BatchNorm. In your own code this might require specifying the axis_name argument of lax.pmean() directly.

from functools import partial

class MLP(nn.Module):
  hidden_size: int
  out_size: int

  @nn.compact
  def __call__(self, x, train=False):
    norm = partial(
        nn.BatchNorm,
        use_running_average=not train,
        momentum=0.9,
        epsilon=1e-5,
        axis_name="batch", # Name batch dim
    )

    x = nn.Dense(self.hidden_size)(x)
    x = norm()(x)
    x = nn.relu(x)
    x = nn.Dense(self.hidden_size)(x)
    x = norm()(x)
    x = nn.relu(x)
    y = nn.Dense(self.out_size)(x)

    return y
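
Initializing this model shows that each BatchNorm layer indeed adds running statistics to the 'batch_stats' collection. A minimal sketch (the input shape is again an arbitrary assumption):

model = MLP(hidden_size=10, out_size=1)
variables = model.init(random.key(0), jnp.ones((10, 4)))
jax.tree_util.tree_map(jnp.shape, variables['batch_stats'])
# {'BatchNorm_0': {'mean': (10,), 'var': (10,)},
#  'BatchNorm_1': {'mean': (10,), 'var': (10,)}}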

Secondly, we need to specify the same name when calling vmap in our training code:

def update_step(apply_fn, x_batch, y_batch, opt_state, params, state):

  def batch_loss(params):
    def loss_fn(x, y):
      pred, updated_state = apply_fn(
        {'params': params, **state},
        x, train=True,  # Train mode so batch statistics are computed and synced.
        mutable=list(state.keys())
      )
      return (pred - y) ** 2, updated_state

    loss, updated_state = jax.vmap(
      loss_fn, out_axes=(0, None),  # Do not vmap `updated_state`.
      axis_name='batch'  # Name batch dim
    )(x_batch, y_batch)  # vmap only `x`, `y`, but not `state`.
    return jnp.mean(loss), updated_state

  (loss, updated_state), grads = jax.value_and_grad(
    batch_loss, has_aux=True
  )(params)

  updates, opt_state = tx.update(grads, opt_state)  # Defined below.
  params = optax.apply_updates(params, updates)
  return opt_state, params, updated_state, loss

Note that we also need to specify that the model state does not have a batch dimension; this is what out_axes=(0, None) does above.
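
The training snippet assumes some placeholders, here X, Y, and num_epochs (dummy_input from earlier is reused for initialization). One way to stub them out (shapes and targets are arbitrary assumptions):

from jax import random
import jax.numpy as jnp

X = random.normal(random.key(1), (10, 4))  # Hypothetical inputs.
Y = jnp.sum(X, axis=-1, keepdims=True)     # Hypothetical regression targets.
num_epochs = 5

Now we are able to train the model: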

model = MLP(hidden_size=10, out_size=1)
variables = model.init(random.key(0), dummy_input)
# Split state and params (which are updated by optimizer).
state, params = flax.core.pop(variables, 'params')
del variables  # Delete variables to avoid wasting resources
tx = optax.sgd(learning_rate=0.02)
opt_state = tx.init(params)

for _ in range(num_epochs):
  opt_state, params, state, loss = update_step(
      model.apply, X, Y, opt_state, params, state)
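
After training, the collections are re-assembled for evaluation just as before. Since train defaults to False here, BatchNorm uses the running averages and no collection needs to be mutable. A minimal sketch:

# Evaluate with the final parameters and running statistics.
predictions = model.apply({'params': params, **state}, X)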