Ensembling on multiple devices

We show how to train an ensemble of CNNs on the MNIST dataset, where the size of the ensemble is equal to the number of available devices. In short, the changes can be described as follows:

  • make a number of functions parallel using jax.pmap,

  • replicate the inputs carefully,

  • make sure the parallel and non-parallel logic interacts correctly.

In this HOWTO we omit some of the code such as imports, the CNN module, and metrics computation, but they can be found in the MNIST example.

Parallel functions

We start by creating a parallel version of get_initial_params, which retrieves the initial parameters of the models. We do this using jax.pmap. The effect of “pmapping” a function is that it will compile the function with XLA (similar to jax.jit), but execute it in parallel on XLA devices (e.g., GPUs/TPUs).
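Before looking at the actual functions, here is a minimal standalone illustration of what jax.pmap does (the `square` function below is invented for this sketch and is not part of the MNIST example): the pmapped function runs once per device, and the leading axis of its input must equal the number of devices.

```python
import jax
import jax.numpy as jnp

# A pmapped function runs once per device; the leading axis of the
# input is split so that each device receives one slice.
square = jax.pmap(lambda x: x ** 2)

n = jax.device_count()
out = square(jnp.arange(n, dtype=jnp.float32))  # one scalar per device
print(out.shape)  # (n,)
```

On a machine with a single device this still works: the leading axis simply has size one.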

Single-model

Ensemble

@jax.jit
def get_initial_params(key):
  init_val = jnp.ones((1, 28, 28, 1), jnp.float32)
  initial_params = CNN().init(key, init_val)['params']
  return initial_params
@jax.pmap
def get_initial_params(key):
  init_val = jnp.ones((1, 28, 28, 1), jnp.float32)
  initial_params = CNN().init(key, init_val)['params']
  return initial_params

Note that for the single-model code above, we use jax.jit to lazily initialize the model (see Module.init’s documentation for more details). For the ensembling case, jax.pmap will map over the first axis of the provided argument key by default, so we should make sure that we provide one key for each device when we call this function later on.
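For example (a small sketch, not taken from the MNIST example), jax.random.split can produce exactly one PRNG key per device, yielding an array whose leading axis matches what jax.pmap expects:

```python
import jax

rng = jax.random.PRNGKey(0)
# One PRNG key per device: the result has shape (device_count, 2),
# so jax.pmap can map over the leading axis and hand each device
# its own independent key.
keys = jax.random.split(rng, jax.device_count())
print(keys.shape)  # (jax.device_count(), 2)
```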

Next we simply do the same for the functions create_optimizer, train_step, and eval_step. We also make a minor change to eval_model, which ensures the metrics are used correctly in the parallel setting.

Single-model

Ensemble

#
def create_optimizer(params, learning_rate=0.1, beta=0.9):
  optimizer_def = optim.Momentum(learning_rate=learning_rate,
                                 beta=beta)
  optimizer = optimizer_def.create(params)
  return optimizer

@jax.jit
def train_step(optimizer, batch):
  """Train for a single step."""
  def loss_fn(params):
    logits = CNN().apply({'params': params}, batch['image'])
    loss = cross_entropy_loss(logits, batch['label'])
    return loss, logits
  grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
  (_, logits), grad = grad_fn(optimizer.target)
  optimizer = optimizer.apply_gradient(grad)
  metrics = compute_metrics(logits, batch['label'])
  return optimizer, metrics

@jax.jit
def eval_step(params, batch):
  logits = CNN().apply({'params': params}, batch['image'])
  return compute_metrics(logits, batch['label'])

def eval_model(params, test_ds):
  metrics = eval_step(params, test_ds)
  metrics = jax.device_get(metrics)
  summary = jax.tree_map(lambda x: x.item(), metrics)
  return summary['loss'], summary['accuracy']
@functools.partial(jax.pmap, static_broadcasted_argnums=(1, 2))
def create_optimizer(params, learning_rate=0.1, beta=0.9):
  optimizer_def = optim.Momentum(learning_rate=learning_rate,
                                 beta=beta)
  optimizer = optimizer_def.create(params)
  return optimizer

@jax.pmap
def train_step(optimizer, batch):
  """Train for a single step."""
  def loss_fn(params):
    logits = CNN().apply({'params': params}, batch['image'])
    loss = cross_entropy_loss(logits, batch['label'])
    return loss, logits
  grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
  (_, logits), grad = grad_fn(optimizer.target)
  optimizer = optimizer.apply_gradient(grad)
  metrics = compute_metrics(logits, batch['label'])
  return optimizer, metrics

@jax.pmap
def eval_step(params, batch):
  logits = CNN().apply({'params': params}, batch['image'])
  return compute_metrics(logits, batch['label'])

def eval_model(params, test_ds):
  metrics = eval_step(params, test_ds)
  metrics = jax.device_get(metrics)
  summary = metrics
  return summary['loss'], summary['accuracy']

Note that for create_optimizer we also specify that learning_rate and beta are static arguments, which means the concrete values of these arguments will be used, rather than abstract shapes. This is necessary because the provided arguments will be scalar values. For more details see JIT mechanics: tracing and static variables.
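To make the mechanics concrete, here is a small standalone sketch (the `scale` function is invented for illustration): the static argument is passed by position, used as a concrete Python value at compile time, and broadcast to every device rather than being split along a leading axis.

```python
import functools

import jax
import jax.numpy as jnp

@functools.partial(jax.pmap, static_broadcasted_argnums=(1,))
def scale(x, factor):
  # `factor` is a concrete Python scalar, identical on every device;
  # calling with a different value triggers a re-compilation.
  return x * factor

n = jax.device_count()
out = scale(jnp.ones((n, 3)), 2.0)  # `factor` passed positionally
print(out.shape)  # (n, 3)
```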

Training the Ensemble

Next we transform the train_epoch function.

Single-model

Ensemble

def train_epoch(optimizer, train_ds, rng, batch_size=10):
  train_ds_size = len(train_ds['image'])
  steps_per_epoch = train_ds_size // batch_size

  perms = random.permutation(rng, len(train_ds['image']))
  perms = perms[:steps_per_epoch * batch_size]
  perms = perms.reshape((steps_per_epoch, batch_size))
  batch_metrics = []
  for perm in perms:
    batch = {k: v[perm, ...] for k, v in train_ds.items()}

    optimizer, metrics = train_step(optimizer, batch)
    batch_metrics.append(metrics)

  batch_metrics_np = jax.device_get(batch_metrics)


  epoch_metrics_np = {
      k: np.mean([metrics[k] for metrics in batch_metrics_np])
      for k in batch_metrics_np[0]}

  return optimizer, epoch_metrics_np
def train_epoch(optimizer, train_ds, rng, batch_size=10):
  train_ds_size = len(train_ds['image'])
  steps_per_epoch = train_ds_size // batch_size

  perms = random.permutation(rng, len(train_ds['image']))
  perms = perms[:steps_per_epoch * batch_size]
  perms = perms.reshape((steps_per_epoch, batch_size))
  batch_metrics = []
  for perm in perms:
    batch = {k: v[perm, ...] for k, v in train_ds.items()}
    batch = jax_utils.replicate(batch)
    optimizer, metrics = train_step(optimizer, batch)
    batch_metrics.append(metrics)

  batch_metrics_np = jax.device_get(batch_metrics)
  batch_metrics_np = jax.tree_multimap(lambda *xs: np.array(xs),
                                       *batch_metrics_np)
  epoch_metrics_np = {
         k: np.mean(batch_metrics_np[k], axis=0)
         for k in batch_metrics_np}

  return optimizer, epoch_metrics_np

As can be seen, we do not have to make any changes to the logic around the optimizer. This is because, as we will see in the training code below, the optimizer is already replicated, so passing it to the pmapped train_step works as-is. The train dataset, however, is not yet replicated, so we do that here. Since replicating the entire train dataset at once would be too memory intensive, we replicate it at the batch level instead.
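Conceptually, jax_utils.replicate adds a leading device axis to every leaf of a pytree, so that every ensemble member trains on the same batch. A rough sketch of that behavior (the real implementation places the copies on the devices directly; `replicate_sketch` is a name invented here):

```python
import jax
import jax.numpy as jnp

def replicate_sketch(tree, n=None):
  """Rough sketch of jax_utils.replicate: copy every leaf n times
  along a new leading axis, one copy per local device."""
  n = n or jax.local_device_count()
  return jax.tree_util.tree_map(
      lambda x: jnp.broadcast_to(x, (n,) + jnp.shape(x)), tree)

batch = {'image': jnp.zeros((10, 28, 28, 1)), 'label': jnp.zeros((10,))}
replicated = replicate_sketch(batch)
print(replicated['image'].shape)  # (local_device_count, 10, 28, 28, 1)
```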

The rest of the changes relate to making sure the batch metrics are stored correctly for all devices. We use jax.tree_multimap to stack all of the metrics from each device into numpy arrays, such that e.g., batch_metrics_np['loss'] has shape (steps_per_epoch, jax.device_count()).
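The stacking step can be illustrated with plain dictionaries and NumPy (the metric values below are made up): each per-step metrics dict holds one value per device, and stacking them yields arrays of shape (steps_per_epoch, jax.device_count()).

```python
import numpy as np

# Hypothetical metrics from two training steps on two devices.
batch_metrics = [
    {'loss': np.array([0.9, 1.1]), 'accuracy': np.array([0.50, 0.45])},
    {'loss': np.array([0.7, 0.8]), 'accuracy': np.array([0.60, 0.55])},
]
# jax.tree_multimap(lambda *xs: np.array(xs), *batch_metrics) stacks
# matching leaves; for these flat dicts that is equivalent to:
stacked = {k: np.array([m[k] for m in batch_metrics])
           for k in batch_metrics[0]}
print(stacked['loss'].shape)  # (steps_per_epoch, device_count) == (2, 2)
# Averaging over the step axis keeps one value per ensemble member.
epoch_metrics = {k: np.mean(v, axis=0) for k, v in stacked.items()}
print(epoch_metrics['loss'])  # [0.8, 0.95]
```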

We can now rewrite the actual training logic. This consists of two simple changes: providing one RNG key per device when calling get_initial_params, and replicating the test dataset. Since the test dataset is much smaller than the train dataset, we can replicate it in its entirety directly.

Single-model

Ensemble

train_ds, test_ds = get_datasets()


rng, init_rng = random.split(random.PRNGKey(0))
params = get_initial_params(init_rng)
optimizer = create_optimizer(params, learning_rate=0.1,
                             beta=0.9)

for epoch in range(num_epochs):
  rng, input_rng = random.split(rng)
  optimizer, _ = train_epoch(optimizer, train_ds, input_rng)
  loss, accuracy = eval_model(optimizer.target, test_ds)

  logging.info('eval epoch: %d, loss: %.4f, accuracy: %.2f',
              epoch, loss, accuracy * 100)
train_ds, test_ds = get_datasets()
test_ds = jax_utils.replicate(test_ds)

rng, init_rng = random.split(random.PRNGKey(0))
params = get_initial_params(random.split(rng,
                            jax.device_count()))
optimizer = create_optimizer(params, 0.1, 0.9)

for epoch in range(num_epochs):
  rng, input_rng = random.split(rng)
  optimizer, _ = train_epoch(optimizer, train_ds, input_rng)
  loss, accuracy = eval_model(optimizer.target, test_ds)

  logging.info('eval epoch: %d, loss: %s, accuracy: %s',
              epoch, loss, accuracy * 100)

Note that create_optimizer uses positional arguments in the ensembling case. This is because we defined learning_rate and beta as static broadcasted arguments, and those should be passed positionally rather than as keyword arguments.