Flax HOWTOs

Flax aims to be a thin library of composable primitives. You compose these primitives yourself into a training loop that you write and that is fully under your control. You are free to modify the behavior of your training loop, and you should generally not need special library support to implement the modifications you want.

To help you get started, we show some sample diffs, which we call “HOWTOs”. These HOWTOs show common modifications to training loops. For instance, the ensembling HOWTO demonstrates which changes to make to the standard MNIST example in order to train an ensemble of models on multiple devices.

Note that these HOWTOs do not require special library support; they simply demonstrate how assembling JAX and Flax primitives in different ways allows you to make various modifications to your training loop.

Currently the following HOWTOs are available:

Multi-device data-parallel training

--- a/examples/mnist/train.py
+++ b/examples/mnist/train.py
[...]
 from absl import flags
 from absl import logging
 
+import functools
+
+from flax import jax_utils
 from flax import nn
 from flax import optim
 from flax.metrics import tensorboard
[...]
 def create_optimizer(model, learning_rate, beta):
   optimizer_def = optim.Momentum(learning_rate=learning_rate, beta=beta)
   optimizer = optimizer_def.create(model)
+  optimizer = jax_utils.replicate(optimizer)
   return optimizer
 
[...]
   return -jnp.mean(jnp.sum(onehot(labels) * logits, axis=-1))
 
+def shard(xs):
+  return jax.tree_map(
+      lambda x: x.reshape((jax.device_count(), -1) + x.shape[1:]), xs)
+
 def compute_metrics(logits, labels):
   loss = cross_entropy_loss(logits, labels)
   accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
[...]
   return metrics
 
-@jax.jit
+@functools.partial(jax.pmap, axis_name='batch')
 def train_step(optimizer, batch):
   """Train for a single step."""
   def loss_fn(model):
[...]
     return loss, logits
   grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
   (_, logits), grad = grad_fn(optimizer.target)
+  grad = jax.lax.pmean(grad, axis_name='batch')
   optimizer = optimizer.apply_gradient(grad)
   metrics = compute_metrics(logits, batch['label'])
+  metrics = jax.lax.pmean(metrics, axis_name='batch')
   return optimizer, metrics
 
[...]
 
 def train_epoch(optimizer, train_ds, batch_size, epoch, rng):
   """Train for a single epoch."""
+  if batch_size % jax.device_count() > 0:
+    raise ValueError('Batch size must be divisible by the number of devices')
+
   train_ds_size = len(train_ds['image'])
   steps_per_epoch = train_ds_size // batch_size
 
[...]
   batch_metrics = []
   for perm in perms:
     batch = {k: v[perm] for k, v in train_ds.items()}
+    batch = shard(batch)
     optimizer, metrics = train_step(optimizer, batch)
     batch_metrics.append(metrics)
 
[...]
     rng, input_rng = random.split(rng)
     optimizer, train_metrics = train_epoch(
         optimizer, train_ds, batch_size, epoch, input_rng)
-    loss, accuracy = eval_model(optimizer.target, test_ds)
+    model = jax_utils.unreplicate(optimizer.target)  # Fetch from 1st device
+    loss, accuracy = eval_model(model, test_ds)
     logging.info('eval epoch: %d, loss: %.4f, accuracy: %.2f',
                  epoch, loss, accuracy * 100)
     summary_writer.scalar('train_loss', train_metrics['loss'], epoch)
diff --git a/examples/mnist/train_test.py b/examples/mnist/train_test.py
index 429bda4..3ebf534 100644
--- a/examples/mnist/train_test.py
+++ b/examples/mnist/train_test.py
[...]
 import jax
 from jax import random
 
+from flax import jax_utils
+
 # Parse absl flags test_srcdir and test_tmpdir.
 jax.config.parse_flags_with_absl()
 
[...]
     # test single train step.
     optimizer, train_metrics = train.train_step(
         optimizer=optimizer,
-        batch={k: v[:batch_size] for k, v in train_ds.items()})
+        batch=train.shard({k: v[:batch_size] for k, v in train_ds.items()}))
     self.assertLessEqual(train_metrics['loss'], 2.302)
     self.assertGreaterEqual(train_metrics['accuracy'], 0.0625)
 
     # Run eval model.
-    loss, accuracy = train.eval_model(optimizer.target, test_ds)
+    model = jax_utils.unreplicate(optimizer.target) # Fetch from 1st device
+    loss, accuracy = train.eval_model(model, test_ds)
     self.assertLess(loss, 2.252)
     self.assertGreater(accuracy, 0.2597)
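
The essence of this diff is the pmap/pmean pattern: replicate the optimizer, shard each batch so every device gets a slice, and average the gradients with a collective before applying them so all replicas stay in sync. As a rough illustration of just that pattern, here is a minimal, self-contained sketch using plain JAX and a toy linear model (the loss, parameter names, and learning rate are illustrative stand-ins, not the MNIST example code):

import functools

import jax
import jax.numpy as jnp
from flax import jax_utils

def loss_fn(params, batch):
  # Toy squared-error loss; stands in for the CNN's cross-entropy loss.
  preds = batch['x'] @ params['w'] + params['b']
  return jnp.mean((preds - batch['y']) ** 2)

def shard(xs):
  # Reshape the leading (batch) axis into (num_devices, per_device_batch, ...).
  return jax.tree_map(
      lambda x: x.reshape((jax.device_count(), -1) + x.shape[1:]), xs)

@functools.partial(jax.pmap, axis_name='batch')
def train_step(params, batch):
  grads = jax.grad(loss_fn)(params, batch)
  # Average gradients across devices so every replica applies the same update.
  grads = jax.lax.pmean(grads, axis_name='batch')
  return jax.tree_map(lambda p, g: p - 0.1 * g, params, grads)

params = jax_utils.replicate({'w': jnp.zeros((4, 1)), 'b': jnp.zeros((1,))})
# Assumes the global batch size (8) is divisible by jax.device_count().
batch = {'x': jnp.ones((8, 4)), 'y': jnp.ones((8, 1))}
params = train_step(params, shard(batch))
params_single = jax_utils.unreplicate(params)  # Fetch from the 1st device.

Because the averaging happens with `jax.lax.pmean` inside the pmapped step, no host round-trip is needed; the diff above applies the same idea to both the gradients and the training metrics.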
 

Ensembling on multiple devices

--- a/examples/mnist/train.py
+++ b/examples/mnist/train.py
[...]
 from absl import flags
 from absl import logging
 
+import functools
+
+from flax import jax_utils
 from flax import nn
 from flax import optim
 from flax.metrics import tensorboard
[...]
     return x
 
+@jax.pmap
 def create_model(key):
   _, initial_params = CNN.init_by_shape(key, [((1, 28, 28, 1), jnp.float32)])
   model = nn.Model(CNN, initial_params)
   return model
 
+@functools.partial(jax.pmap, static_broadcasted_argnums=(1, 2))
 def create_optimizer(model, learning_rate, beta):
   optimizer_def = optim.Momentum(learning_rate=learning_rate, beta=beta)
   optimizer = optimizer_def.create(model)
[...]
   return metrics
 
-@jax.jit
+@jax.pmap
 def train_step(optimizer, batch):
   """Train for a single step."""
   def loss_fn(model):
[...]
   return optimizer, metrics
 
-@jax.jit
+@jax.pmap
 def eval_step(model, batch):
   logits = model(batch['image'])
   return compute_metrics(logits, batch['label'])
[...]
   batch_metrics = []
   for perm in perms:
     batch = {k: v[perm] for k, v in train_ds.items()}
+    batch = jax_utils.replicate(batch)
     optimizer, metrics = train_step(optimizer, batch)
     batch_metrics.append(metrics)
 
   # compute mean of metrics across each batch in epoch.
   batch_metrics_np = jax.device_get(batch_metrics)
+  # Stack the metrics from every step into numpy arrays directly on a dict,
+  # such that, e.g., `batch_metrics_np['loss']` has shape
+  # (steps_per_epoch, jax.device_count()).
+  batch_metrics_np = jax.tree_multimap(lambda *xs: onp.array(xs),
+                                       *batch_metrics_np)
   epoch_metrics_np = {
-      k: onp.mean([metrics[k] for metrics in batch_metrics_np])
-      for k in batch_metrics_np[0]}
-
-  logging.info('train epoch: %d, loss: %.4f, accuracy: %.2f', epoch,
+      k: onp.mean(batch_metrics_np[k], axis=0)
+      for k in batch_metrics_np
+  }
+  # `epoch_metrics_np` now contains 1D arrays of length `jax.device_count()`
+  logging.info('train epoch: %d, loss: %s, accuracy: %s', epoch,
                epoch_metrics_np['loss'], epoch_metrics_np['accuracy'] * 100)
 
   return optimizer, epoch_metrics_np
[...]
 def eval_model(model, test_ds):
   metrics = eval_step(model, test_ds)
   metrics = jax.device_get(metrics)
-  summary = jax.tree_map(lambda x: x.item(), metrics)
+  summary = metrics
   return summary['loss'], summary['accuracy']
 
[...]
 
   summary_writer = tensorboard.SummaryWriter(model_dir)
 
-  rng, init_rng = random.split(rng)
+  rng, init_rng = random.split(rng)
+  # Use a different initialization key on each device.
+  init_rng = random.split(init_rng, jax.device_count())
   model = create_model(init_rng)
   optimizer = create_optimizer(model, FLAGS.learning_rate, FLAGS.momentum)
 
+  test_ds = jax_utils.replicate(test_ds)
+
   for epoch in range(1, num_epochs + 1):
     rng, input_rng = random.split(rng)
     optimizer, train_metrics = train_epoch(
         optimizer, train_ds, batch_size, epoch, input_rng)
     loss, accuracy = eval_model(optimizer.target, test_ds)
-    logging.info('eval epoch: %d, loss: %.4f, accuracy: %.2f',
+    # `loss` and `accuracy` are now 1D arrays of length `jax.device_count()`
+    logging.info('eval epoch: %d, loss: %s, accuracy: %s',
                  epoch, loss, accuracy * 100)
     summary_writer.scalar('train_loss', train_metrics['loss'], epoch)
     summary_writer.scalar('train_accuracy', train_metrics['accuracy'], epoch)
diff --git a/examples/mnist/train_test.py b/examples/mnist/train_test.py
index 429bda4..c1405d0 100644
--- a/examples/mnist/train_test.py
+++ b/examples/mnist/train_test.py
[...]
 import jax
 from jax import random
 
+from flax import jax_utils
+
 # Parse absl flags test_srcdir and test_tmpdir.
 jax.config.parse_flags_with_absl()
 
[...]
   def test_single_train_step(self):
     train_ds, test_ds = train.get_datasets()
     batch_size = 32
-    model = train.create_model(random.PRNGKey(0))
+    model = train.create_model(random.split(random.PRNGKey(0),
+                                            jax.device_count()))
+    test_ds = jax_utils.replicate(test_ds)
     optimizer = train.create_optimizer(model, 0.1, 0.9)
 
     # test single train step.
+    batch = jax_utils.replicate({k: v[:batch_size] for k, v in train_ds.items()})
     optimizer, train_metrics = train.train_step(
         optimizer=optimizer,
-        batch={k: v[:batch_size] for k, v in train_ds.items()})
-    self.assertLessEqual(train_metrics['loss'], 2.302)
-    self.assertGreaterEqual(train_metrics['accuracy'], 0.0625)
+        batch=batch)
+    self.assertLessEqual(train_metrics['loss'], 2.34)
+    self.assertGreaterEqual(train_metrics['accuracy'], 0.03125)
 
     # Run eval model.
     loss, accuracy = train.eval_model(optimizer.target, test_ds)
-    self.assertLess(loss, 2.252)
-    self.assertGreater(accuracy, 0.2597)
+    self.assertLessEqual(loss, 2.34)
+    self.assertGreaterEqual(accuracy, 0.03125)
 
 if __name__ == '__main__':
   absltest.main()
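
Compared to the data-parallel HOWTO, the ensembling diff changes what is replicated and what is kept independent: each device receives its own initialization key (so the ensemble members differ), every device sees the same replicated batch and test set, and there is no `jax.lax.pmean` over the gradients, so the members train independently. Here is a minimal sketch of that pattern with a toy model (the loss, shapes, and learning rate are illustrative, not the MNIST code):

import jax
import jax.numpy as jnp
from flax import jax_utils

def loss_fn(params, batch):
  preds = batch['x'] @ params['w']
  return jnp.mean((preds - batch['y']) ** 2)

@jax.pmap
def create_params(rng):
  # A different key per device gives each ensemble member its own init.
  return {'w': jax.random.normal(rng, (4, 1))}

@jax.pmap
def train_step(params, batch):
  grads = jax.grad(loss_fn)(params, batch)
  # No jax.lax.pmean here: each device keeps its own gradients and parameters.
  return jax.tree_map(lambda p, g: p - 0.1 * g, params, grads)

rngs = jax.random.split(jax.random.PRNGKey(0), jax.device_count())
params = create_params(rngs)
# The same batch goes to every device, so members differ only by their init.
batch = jax_utils.replicate({'x': jnp.ones((8, 4)), 'y': jnp.ones((8, 1))})
params = train_step(params, batch)  # One independent member per device.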

Polyak averaging

--- a/examples/mnist/train.py
+++ b/examples/mnist/train.py
[...]
 
 @jax.jit
-def train_step(optimizer, batch):
+def train_step(optimizer, params_ema, batch):
   """Train for a single step."""
   def loss_fn(model):
     logits = model(batch['image'])
[...]
   grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
   (_, logits), grad = grad_fn(optimizer.target)
   optimizer = optimizer.apply_gradient(grad)
+  params_ema = jax.tree_multimap(
+      lambda p_ema, p: p_ema * 0.99 + p * 0.01,
+      params_ema, optimizer.target.params)
   metrics = compute_metrics(logits, batch['label'])
-  return optimizer, metrics
+  return optimizer, params_ema, metrics
 
 @jax.jit
[...]
   return compute_metrics(logits, batch['label'])
 
-def train_epoch(optimizer, train_ds, batch_size, epoch, rng):
+def train_epoch(optimizer, params_ema, train_ds, batch_size, epoch, rng):
   """Train for a single epoch."""
   train_ds_size = len(train_ds['image'])
   steps_per_epoch = train_ds_size // batch_size
[...]
   batch_metrics = []
   for perm in perms:
     batch = {k: v[perm] for k, v in train_ds.items()}
-    optimizer, metrics = train_step(optimizer, batch)
+    optimizer, params_ema, metrics = train_step(optimizer, params_ema, batch)
     batch_metrics.append(metrics)
 
   # compute mean of metrics across each batch in epoch.
[...]
   rng, init_rng = random.split(rng)
   model = create_model(init_rng)
   optimizer = create_optimizer(model, FLAGS.learning_rate, FLAGS.momentum)
+  params_ema = model.params
 
   for epoch in range(1, num_epochs + 1):
     rng, input_rng = random.split(rng)
-    optimizer, train_metrics = train_epoch(
-        optimizer, train_ds, batch_size, epoch, input_rng)
+    optimizer, params_ema, train_metrics = train_epoch(
+        optimizer, params_ema, train_ds, batch_size, epoch, input_rng)
     loss, accuracy = eval_model(optimizer.target, test_ds)
     logging.info('eval epoch: %d, loss: %.4f, accuracy: %.2f',
                  epoch, loss, accuracy * 100)
+    model_ema = optimizer.target.replace(params=params_ema)
+    polyak_loss, polyak_accuracy = eval_model(model_ema, test_ds)
+    logging.info('polyak eval epoch: %d, loss: %.4f, accuracy: %.2f',
+                 epoch, polyak_loss, polyak_accuracy * 100)
     summary_writer.scalar('train_loss', train_metrics['loss'], epoch)
     summary_writer.scalar('train_accuracy', train_metrics['accuracy'], epoch)
     summary_writer.scalar('eval_loss', loss, epoch)
diff --git a/examples/mnist/train_test.py b/examples/mnist/train_test.py
index 429bda4..b441168 100644
--- a/examples/mnist/train_test.py
+++ b/examples/mnist/train_test.py
[...]
     optimizer = train.create_optimizer(model, 0.1, 0.9)
 
     # test single train step.
-    optimizer, train_metrics = train.train_step(
+    optimizer, params_ema, train_metrics = train.train_step(
         optimizer=optimizer,
+        params_ema=model.params,
         batch={k: v[:batch_size] for k, v in train_ds.items()})
     self.assertLessEqual(train_metrics['loss'], 2.302)
     self.assertGreaterEqual(train_metrics['accuracy'], 0.0625)
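
The core of the Polyak-averaging diff is a single tree map over the parameter pytree that blends the running average with the freshly updated parameters. The following minimal sketch shows that update in isolation on a toy parameter dict (the decay value 0.99 matches the diff; the parameter shapes and the stand-in update are illustrative):

import jax
import jax.numpy as jnp

def update_ema(params_ema, params, decay=0.99):
  # Leaf-wise: params_ema <- decay * params_ema + (1 - decay) * params.
  return jax.tree_map(
      lambda p_ema, p: p_ema * decay + p * (1. - decay), params_ema, params)

params = {'w': jnp.ones((3, 3)), 'b': jnp.zeros((3,))}
params_ema = params  # A common choice: start the average at the initial params.
for _ in range(100):
  # Stand-in for optimizer.apply_gradient(grad) in the real training step.
  params = jax.tree_map(lambda p: p - 0.01, params)
  params_ema = update_ema(params_ema, params)

At evaluation time the averaged parameters are swapped into the model, as the diff does with `optimizer.target.replace(params=params_ema)`.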

Scheduled sampling

--- a/examples/seq2seq/train.py
+++ b/examples/seq2seq/train.py
[...]
 class Decoder(nn.Module):
   """LSTM decoder."""
 
-  def apply(self, init_state, inputs, teacher_force=False):
+  def apply(self, init_state, inputs, sample_probability=0.0):
     # inputs.shape = (batch_size, seq_length, vocab_size).
-    vocab_size = inputs.shape[2]
+    batch_size, _, vocab_size = inputs.shape
     lstm_cell = nn.LSTMCell.partial(name='lstm')
     projection = nn.Dense.partial(features=vocab_size, name='projection')
 
     def decode_step_fn(carry, x):
       rng, lstm_state, last_prediction = carry
-      carry_rng, categorical_rng = jax.random.split(rng, 2)
-      if not teacher_force:
-        x = last_prediction
+      carry_rng, bernoulli_rng, categorical_rng = jax.random.split(rng, 3)
+      x = jnp.where(
+          jax.random.bernoulli(
+              bernoulli_rng, p=sample_probability, shape=(batch_size, 1)),
+          last_prediction, x)
       lstm_state, y = lstm_cell(lstm_state, x)
       logits = projection(y)
       predicted_tokens = jax.random.categorical(categorical_rng, logits)
[...]
   def apply(self,
             encoder_inputs,
             decoder_inputs,
-            teacher_force=True,
+            sample_probability=0.0,
             eos_id=1,
             hidden_size=512):
     """Run the seq2seq model.
[...]
         `[batch_size, max(encoder_input_lengths), vocab_size]`.
       decoder_inputs: padded batch of expected decoded sequences for teacher
         forcing, shaped `[batch_size, max(decoder_inputs_length), vocab_size]`.
-        When sampling (i.e., `teacher_force = False`), the initial time step is
-        forced into the model and samples are used for the following inputs. The
-        second dimension of this tensor determines how many steps will be
-        decoded, regardless of the value of `teacher_force`.
-      teacher_force: bool, whether to use `decoder_inputs` as input to the
-        decoder at every step. If False, only the first input is used, followed
-        by samples taken from the previous output logits.
-      eos_id: int, the token signaling when the end of a sequence is reached.
+        The initial time step is forced into the model and samples are used for
+        the following inputs as determined by `sample_probability`. The second
+        dimension of this tensor determines how many steps will be
+        decoded, regardless of the value of `sample_probability`.
+      sample_probability: float in [0, 1], the probability of using the previous
+        sample as the next input instead of the value in `decoder_inputs` when
+        sampling. A value of 0 is equivalent to teacher forcing and 1 indicates
+        full sampling.
+      eos_id: int, the token signalling when the end of a sequence is reached.
       hidden_size: int, the number of hidden dimensions in the encoder and
         decoder LSTMs.
     Returns:
[...]
     logits, predictions = decoder(
         init_decoder_state,
         decoder_inputs[:, :-1],
-        teacher_force=teacher_force)
+        sample_probability=sample_probability)
 
     return logits, predictions
 
[...]
 
 @jax.jit
-def train_step(optimizer, batch, rng):
+def train_step(optimizer, batch, sample_probability, rng):
   """Train one step."""
   labels = batch['answer'][:, 1:]  # remove '=' start token
 
   def loss_fn(model):
     """Compute cross-entropy loss."""
     with nn.stochastic(rng):
-      logits, _ = model(batch['query'], batch['answer'])
+      logits, _ = model(
+          batch['query'], batch['answer'],
+          sample_probability=sample_probability)
     loss = cross_entropy_loss(logits, labels, get_sequence_lengths(labels))
     return loss, logits
   grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
[...]
   init_decoder_inputs = jnp.tile(init_decoder_input,
                                  (inputs.shape[0], get_max_output_len(), 1))
   with nn.stochastic(rng):
-    _, predictions = model(inputs, init_decoder_inputs, teacher_force=False)
+    _, predictions = model(inputs, init_decoder_inputs, sample_probability=1.0)
   return predictions
 
[...]
 
 def train_model():
   """Train for a fixed number of steps and decode during training."""
+  def inv_sigmoid(i, k=500):
+    return k / (k + jnp.exp(i / k))
+
   with nn.stochastic(jax.random.PRNGKey(0)):
     model = create_model()
     optimizer = create_optimizer(model, FLAGS.learning_rate)
     for step in range(FLAGS.num_train_steps):
       batch = get_batch(FLAGS.batch_size)
-      optimizer, metrics = train_step(optimizer, batch, nn.make_rng())
+      sample_probability = 1 - inv_sigmoid(step)
+      optimizer, metrics = train_step(
+          optimizer, batch, sample_probability, nn.make_rng())
       if step % FLAGS.decode_frequency == 0:
-        logging.info('train step: %d, loss: %.4f, accuracy: %.2f', step,
-                     metrics['loss'], metrics['accuracy'] * 100)
+        logging.info(
+            'train step: %d, sample prob: %.4f, loss: %.4f, accuracy: %.2f',
+            step, sample_probability, metrics['loss'],
+            metrics['accuracy'] * 100)
         decode_batch(optimizer.target, 5)
   return optimizer.target
 
diff --git a/examples/seq2seq/train_test.py b/examples/seq2seq/train_test.py
index 7efc1b1..122659e 100644
--- a/examples/seq2seq/train_test.py
+++ b/examples/seq2seq/train_test.py
[...]
       model = train.create_model()
       optimizer = train.create_optimizer(model, 0.003)
       optimizer, train_metrics = train.train_step(
-          optimizer, batch, nn.make_rng())
+          optimizer, batch, 0.5, nn.make_rng())
 
     self.assertLessEqual(train_metrics['loss'], 5)
     self.assertGreaterEqual(train_metrics['accuracy'], 0)
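
Two ingredients do the work in the scheduled sampling diff: the inverse-sigmoid schedule that anneals the probability of teacher forcing, and a per-example Bernoulli draw that decides whether the decoder sees the ground-truth input or its own previous prediction. Here is a minimal sketch of both pieces outside the model (the shapes, step value, and one-hot placeholders are illustrative):

import jax
import jax.numpy as jnp

def inv_sigmoid(i, k=500.):
  # Teacher-forcing probability; decays from ~1 towards 0 as the step i grows.
  return k / (k + jnp.exp(i / k))

def mix_inputs(rng, teacher_input, last_prediction, sample_probability):
  # Per batch element: with probability `sample_probability`, feed the model's
  # previous prediction instead of the ground-truth (teacher-forced) input.
  batch_size = teacher_input.shape[0]
  use_sample = jax.random.bernoulli(
      rng, p=sample_probability, shape=(batch_size, 1))
  return jnp.where(use_sample, last_prediction, teacher_input)

rng = jax.random.PRNGKey(0)
step = 1000
sample_probability = 1. - inv_sigmoid(step)  # Grows towards 1 over training.
teacher_input = jnp.zeros((4, 8))   # (batch_size, vocab_size) one-hot input.
last_prediction = jnp.ones((4, 8))  # Previous step's one-hot model sample.
x = mix_inputs(rng, teacher_input, last_prediction, sample_probability)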

How do HOWTOs work?

Read the HOWTOs HOWTO to learn how we maintain HOWTOs.