Scale up Flax Modules on multiple devices with pjit#

This guide shows how to scale up Flax Modules on multiple devices and hosts using JAX’s pjit and flax.linen.spmd.

Flax and pjit#

jax.experimental.pjit provides a way to automatically compile and scale up JAX computations. pjit has the following benefits:

  • pjit has an interface similar to jax.jit and works as a decorator on a function that needs to be compiled.

  • When using pjit, you can write code as if it runs on a single device, and pjit will automatically compile and run it on multiple devices using the Single Program, Multiple Data (SPMD) paradigm.

  • With pjit you can state how the input and output of your code are partitioned across devices, and the compiler will figure out how to: 1) partition everything inside; and 2) compile the inter-device communication.

To learn more, refer to JAX-101 pjit tutorial and JAX in multi-process environments.
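For example, here is a minimal, self-contained sketch of this workflow (it assumes 8 available devices, for instance emulated on CPU as in the Setup section below; the function and variable names are illustrative, not part of this guide's model):

import jax.numpy as jnp
from jax.experimental.pjit import pjit
from jax.sharding import Mesh, PartitionSpec
from jax.experimental import mesh_utils

# A 2x4 mesh of 8 devices with named axes.
mesh = Mesh(mesh_utils.create_device_mesh((2, 4)), axis_names=('data', 'model'))

def matmul(a, b):
  # Written as ordinary single-device code.
  return a @ b

# Annotate how the inputs and the output are partitioned across the mesh;
# the compiler figures out the rest, including inter-device communication.
pjit_matmul = pjit(matmul,
                   in_axis_resources=(PartitionSpec('data', None),
                                      PartitionSpec(None, 'model')),
                   out_axis_resources=PartitionSpec('data', 'model'))

with mesh:
  out = pjit_matmul(jnp.ones((8, 16)), jnp.ones((16, 32)))
print(out.shape)  # (8, 32), sharded over the 'data' and 'model' mesh axes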

Flax provides several functionalities that can help you use pjit on Flax Modules, including:

  1. An interface to specify partitions of your data when defining a flax.linen.Module.

  2. Utility functions to generate the partition information that pjit requires to run.

  3. An interface called “logical axis annotations” that lets you customize your axis names, decoupling your Module code from the partition plan so you can experiment with different partition layouts more easily.

Setup#

Install Flax from HEAD:

# Once Flax v0.6.4 is released, use `pip3 install flax`.
! pip3 install -qq "git+https://github.com/google/flax.git@main#egg=flax"

Imports#

Import some necessary dependencies.

Note: This guide uses the --xla_force_host_platform_device_count=8 flag to emulate multiple devices in a CPU environment in a Google Colab/Jupyter Notebook. Check out the JAX-101 pjit tutorial to learn more about emulating a multi-device TPU environment (in which case you should ignore running os.environ).

import os
os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8'
import functools
import numpy as np
import jax

from jax import lax, random, numpy as jnp

import flax
from flax import struct, traverse_util, linen as nn
from flax.linen import spmd # Flax Linen SPMD.
from flax.core import freeze, unfreeze
from flax.training import train_state, checkpoints

import optax # Optax for common losses and optimizers. 

Next, import all the pjit-related libraries.

Note: jax.experimental.pjit is still in the experimental package of JAX, so the API may change in the future.

  1. Start a 2x4 device mesh (8 devices)—this is the same as the layout of TPU v3-8.

  2. Annotate each axis with a name. A typical way to annotate axis names is ('data', 'model'), where:

  • 'data': the mesh dimension used for data-parallel sharding of the batch dimension of inputs and activations.

  • 'model': the mesh dimension used for sharding parameters of the model across devices.

from jax.experimental.pjit import pjit, with_sharding_constraint
from jax.sharding import Mesh, PartitionSpec
from jax.experimental import mesh_utils

# Start a device mesh.
device_mesh = mesh_utils.create_device_mesh((2, 4))
print(device_mesh)
# Annotate each axis with a name.
mesh = Mesh(devices=device_mesh, axis_names=('data', 'model'))
mesh
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
[[CpuDevice(id=0) CpuDevice(id=1) CpuDevice(id=2) CpuDevice(id=3)]
 [CpuDevice(id=4) CpuDevice(id=5) CpuDevice(id=6) CpuDevice(id=7)]]
Mesh(device_ids=array([[0, 1, 2, 3],
       [4, 5, 6, 7]]), axis_names=('data', 'model'))

Define a layer#

Before defining a model, create an example layer called DotReluDot (by subclassing flax.linen.Module), which creates two parameters W1 and W2 for dot product multiplication, and uses the jax.nn.relu (ReLU) activation function in-between.

To use this layer in pjit efficiently, apply the following APIs to annotate the parameters and intermediate variables correctly:

  1. Use flax.linen.with_partitioning to decorate the initializer function when creating parameters W1 and W2.

  2. Apply pjit.with_sharding_constraint to annotate intermediate variables like y and z to force a particular sharding pattern under pjit when the ideal constraint is known.

  • This step is optional, but can sometimes help auto-SPMD to partition efficiently. In the example below, the call is not required, because pjit will figure out the same sharding layout for y and z regardless.

class DotReluDot(nn.Module):
  depth: int
  @nn.compact
  def __call__(self, x):
    W1 = self.param(
        'W1', 
        nn.with_partitioning(nn.initializers.xavier_normal(), (None, 'model')),
        (x.shape[-1], self.depth))

    y = jax.nn.relu(jnp.dot(x, W1))
    # Force a local sharding annotation.
    y = with_sharding_constraint(y, PartitionSpec('data', 'model'))

    W2 = self.param(
        'W2', 
        nn.with_partitioning(nn.initializers.xavier_normal(), ('model', None)),
        (self.depth, x.shape[-1]))

    z = jnp.dot(y, W2)
    # Force a local sharding annotation.
    z = with_sharding_constraint(z, PartitionSpec('data', None))

    # Return a tuple to conform with the API `flax.linen.scan` as shown in the cell below.
    return z, None

Note that device axis names like 'data', 'model' or None are passed into both the flax.linen.with_partitioning and pjit.with_sharding_constraint API calls. They describe how each dimension of the data should be sharded: either across one of the device mesh dimensions, or not sharded at all.

For example:

  • When you define W1 with shape (x.shape[-1], self.depth) and annotate as (None, 'model'):

    • The first dimension (of length x.shape[-1]) will be replicated across all devices.

    • The second dimension (of length self.depth) will be sharded over the 'model' axis of the device mesh. This means W1 will be sharded 4-way on devices (0, 4), (1, 5), (2, 6) and (3, 7), on this dimension.

  • When you annotate the output z as ('data', None):

    • The first dimension — the batch dimension — will be sharded over the 'data' axis. This means half of the batch will be processed on devices 0-3 (first four devices), and another half on devices 4-7 (the remaining four devices).

    • The second dimension — the data depth dimension — will be replicated across all devices.
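To see such a layout concretely, you can place an array with a given PartitionSpec and visualize it. The snippet below is a small sketch: it assumes jax.debug.visualize_array_sharding is available in your JAX version, and z_example is an illustrative name rather than a value produced elsewhere in this guide.

# Place an (8, 1024) array with PartitionSpec('data', None) and inspect its layout.
from jax.sharding import NamedSharding

z_example = jax.device_put(
    jnp.ones((8, 1024)), NamedSharding(mesh, PartitionSpec('data', None)))
jax.debug.visualize_array_sharding(z_example)
# The first half of the batch lives on devices 0-3 (replicated across the 'model'
# axis), and the second half on devices 4-7.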

Define a model with flax.linen.scan lifted transformation#

This guide uses flax.linen.scan to demonstrate how Flax lifted transforms, such as scan, can work together with JAX pjit.

Having created DotReluDot, define the MLP model (by subclassing flax.linen.Module) as multiple layers of DotReluDot.

To replicate identical layers, you can either use flax.linen.scan, or a for-loop:

  • flax.linen.scan can offer faster compilation times.

  • A for-loop can be faster at runtime.

The code below shows how to apply both methods.

Note: flax.linen.scan has another dimension for the parameters (the dimension over which scan is applied). You need to use the metadata_params argument to annotate the partition of this dimension. Since the parameters inside your DotReluDot (a sub-Module) are already sharded along the model axis, you don’t need to partition multiple layers across the model dimension here, and therefore you should denote it as None.

class MLP(nn.Module):
  num_layers: int
  depth: int
  use_scan: bool
  @nn.compact
  def __call__(self, x):
    if self.use_scan:
      x, _ = nn.scan(DotReluDot, length=self.num_layers, 
                     variable_axes={"params": 0},
                     split_rngs={"params": True},
                     metadata_params={nn.PARTITION_NAME: None}
                     )(self.depth)(x)
    else:
      for i in range(self.num_layers):
        x, _ = DotReluDot(self.depth)(x)
    return x

Specify sharding (includes initialization and TrainState creation)#

Next, generate the jax.sharding.PartitionSpec that pjit should receive as annotations of input and output data. PartitionSpec is a tuple of 2 axes (in a 2x4 mesh). To learn more, refer to JAX-101: Introduction to pjit.

Specify the input#

For data parallelism, you can shard the batched input x across the 'data' axis by denoting the batch axis as 'data':

x_spec = PartitionSpec('data', None)  # dimensions: (batch, length)
x_spec
PartitionSpec('data', None)

Generate a PartitionSpec for the output#

To generate a PartitionSpec for the output, you need to use some actual output as a reference:

  1. Instantiate a model.

  2. Evaluate model.init abstractly using jax.eval_shape.

  3. Use flax.linen.get_partition_spec to automatically generate the PartitionSpec.

The code below shows how to get the output spec if you use flax.training.train_state to carry out your initialization and training steps, in which case your pjitted function will output a TrainState.

(In a simpler case, you might choose the variable dict, as in variables = model.init(k, x), as your pjitted function’s output. That works too.)

# MLP hyperparameters.
BATCH, LAYERS, DEPTH, USE_SCAN = 8, 4, 1024, True
# Create fake inputs.
x = jnp.ones((BATCH, DEPTH))
# Initialize a PRNG key.
k = random.PRNGKey(0)

# Create an Optax optimizer.
optimizer = optax.adam(learning_rate=0.001)
# Instantiate the model.
model = MLP(LAYERS, DEPTH, USE_SCAN)

# A functional way of model initialization.
def init_fn(k, x, model, optimizer):
  variables = model.init(k, x) # Initialize the model.
  state = train_state.TrainState.create( # Create a `TrainState`.
    apply_fn=model.apply,
    params=variables['params'],
    tx=optimizer)
  return state

with mesh:
  # Create an abstract closure to wrap the function before feeding it in
  # because `jax.eval_shape` only takes pytrees as arguments.
  abstract_variables = jax.eval_shape(
      functools.partial(init_fn, model=model, optimizer=optimizer), k, x)
# This `state_spec` has the same pytree structure as the output
# of the `init_fn`.
state_spec = nn.get_partition_spec(abstract_variables)
state_spec
TrainState(step=PartitionSpec(), apply_fn=<bound method Module.apply of MLP(
    # attributes
    num_layers = 4
    depth = 1024
    use_scan = True
)>, params=FrozenDict({
    ScanDotReluDot_0: {
        W1: PartitionSpec(None, None, 'model'),
        W2: PartitionSpec(None, 'model', None),
    },
}), tx=GradientTransformationExtraArgs(init=<function chain.<locals>.init_fn at 0x7f3ba877d700>, update=<function chain.<locals>.update_fn at 0x7f3ba877d8b0>), opt_state=(ScaleByAdamState(count=PartitionSpec(), mu=FrozenDict({
    ScanDotReluDot_0: {
        W1: PartitionSpec(None, None, 'model'),
        W2: PartitionSpec(None, 'model', None),
    },
}), nu=FrozenDict({
    ScanDotReluDot_0: {
        W1: PartitionSpec(None, None, 'model'),
        W2: PartitionSpec(None, 'model', None),
    },
})), EmptyState()))

Apply pjit to compile the code#

Now you can apply JAX pjit to your init_fn in a similar fashion to jax.jit, but with two extra arguments: in_axis_resources and out_axis_resources.

You need to add a with mesh: context when running a pjitted function, so that it can refer to mesh (an instance of jax.sharding.Mesh) to allocate data on devices correctly.

pjit_init_fn = pjit(init_fn,
                    static_argnums=(2, 3),
                    in_axis_resources=(PartitionSpec(None), x_spec),  # PRNG key and x
                    out_axis_resources=state_spec
                    )
with mesh:
  initialized_state = pjit_init_fn(k, x, model, optimizer)
jax.tree_map(jnp.shape, initialized_state)
TrainState(step=(), apply_fn=<bound method Module.apply of MLP(
    # attributes
    num_layers = 4
    depth = 1024
    use_scan = True
)>, params=FrozenDict({
    ScanDotReluDot_0: {
        W1: Partitioned(value=(4, 1024, 1024), names=(None, None, 'model'), mesh=None),
        W2: Partitioned(value=(4, 1024, 1024), names=(None, 'model', None), mesh=None),
    },
}), tx=GradientTransformationExtraArgs(init=<function chain.<locals>.init_fn at 0x7f3ba877d700>, update=<function chain.<locals>.update_fn at 0x7f3ba877d8b0>), opt_state=(ScaleByAdamState(count=(), mu=FrozenDict({
    ScanDotReluDot_0: {
        W1: Partitioned(value=(4, 1024, 1024), names=(None, None, 'model'), mesh=None),
        W2: Partitioned(value=(4, 1024, 1024), names=(None, 'model', None), mesh=None),
    },
}), nu=FrozenDict({
    ScanDotReluDot_0: {
        W1: Partitioned(value=(4, 1024, 1024), names=(None, None, 'model'), mesh=None),
        W2: Partitioned(value=(4, 1024, 1024), names=(None, 'model', None), mesh=None),
    },
})), EmptyState()))

Inspect the Module output#

Note that in the output of initialized_state, the params W1 and W2 are of type flax.linen.Partitioned. This is a wrapper around the actual jax.Array that allows Flax to record metadata associated with it. You can access the raw jax.Array by adding .value or running .unbox().

You can also check the underlying jax.sharding of the JAX array, which gives a hint about how it is partitioned.

print(type(initialized_state.params['ScanDotReluDot_0']['W1']))
print(type(initialized_state.params['ScanDotReluDot_0']['W1'].value))
print(initialized_state.params['ScanDotReluDot_0']['W1'].value.shape)
<class 'flax.core.meta.Partitioned'>
<class 'jaxlib.xla_extension.ArrayImpl'>
(4, 1024, 1024)
print(initialized_state.params['ScanDotReluDot_0']['W1'].value.sharding)
NamedSharding(mesh={'data': 2, 'model': 4}, spec=PartitionSpec(None, None, 'model'))

You can use jax.tree_map to perform bulk computation on a dict of boxed params, in the same way as on a dict of JAX arrays.

diff = jax.tree_map(
    lambda a, b: a - b, 
    initialized_state.params['ScanDotReluDot_0'], initialized_state.params['ScanDotReluDot_0'])
print(jax.tree_map(jnp.shape, diff))
diff_array = diff['W1'].unbox()
print(type(diff_array))
print(diff_array.shape)
FrozenDict({
    W1: Partitioned(value=(4, 1024, 1024), names=(None, None, 'model'), mesh=None),
    W2: Partitioned(value=(4, 1024, 1024), names=(None, 'model', None), mesh=None),
})
<class 'jaxlib.xla_extension.ArrayImpl'>
(4, 1024, 1024)

Apply pjit to the train step and inference#

Now, create a pjitted training step:

def train_step(state, x):
  # A fake loss function.
  def loss_unrolled(params):
    y = model.apply({'params': params}, x)
    return y.sum()
  grad_fn = jax.grad(loss_unrolled)
  grads = grad_fn(state.params)
  state = state.apply_gradients(grads=grads)
  return state

pjit_step_fn = pjit(train_step,
                    in_axis_resources=(state_spec, x_spec),  # input annotations
                    out_axis_resources=state_spec,           # output annotations
                    )
with mesh:
  new_state = pjit_step_fn(initialized_state, x)

Apply pjit to inference. Note that, similar to jax.jit, you can use a decorator like @functools.partial(pjit, ...) to directly compile your function.

@functools.partial(pjit, in_axis_resources=(state_spec, x_spec), out_axis_resources=x_spec)
def pjit_apply_fn(state, x):
  return state.apply_fn({'params': state.params}, x)

with mesh:
  y = pjit_apply_fn(new_state, x)
print(type(y))
print(y.dtype)
print(y.shape)
<class 'jaxlib.xla_extension.ArrayImpl'>
float32
(8, 1024)

Profiling#

If you are running on a TPU pod or a pod slice, you can use a custom block_all utility function, as defined below, to measure the performance:

%%timeit

def block_all(xs):
  jax.tree_map(lambda x: x.block_until_ready(), xs)
  return xs

with mesh:
  new_state = block_all(pjit_step_fn(initialized_state, x))
339 ms ± 13.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

Logical axis annotation#

JAX auto-SPMD encourages users to explore different sharding layouts to find the optimal one. To this end, Flax lets you annotate the dimensions of any data with more descriptive axis names (not just device mesh axis names like 'data' and 'model').

The LogicalDotReluDot and LogicalMLP Module definition below are similar to the Modules you created earlier, except for the following:

  1. All axes are annotated with more concrete, meaningful names, such as 'embed', 'hidden', 'batch' and 'layer'. These names are referred to as logical axis names in Flax. They make the dimensional changes inside model definitions more readable.

  2. flax.linen.spmd.with_logical_partitioning replaces flax.linen.with_partitioning; and flax.linen.spmd.with_logical_constraint replaces pjit.with_sharding_constraint, to recognize the logical axis names.

class LogicalDotReluDot(nn.Module):
  depth: int
  @nn.compact
  def __call__(self, x):
    W1 = self.param(
        'W1', 
        spmd.with_logical_partitioning(nn.initializers.xavier_normal(), ('embed', 'hidden')),
        (x.shape[-1], self.depth)) 

    y = jax.nn.relu(jnp.dot(x, W1))
    # Force a local sharding annotation.
    y = spmd.with_logical_constraint(y, ('batch', 'hidden'))

    W2 = self.param(
        'W2', 
        spmd.with_logical_partitioning(nn.initializers.xavier_normal(), ('hidden', 'embed')),
        (self.depth, x.shape[-1]))

    z = jnp.dot(y, W2)
    # Force a local sharding annotation.
    z = spmd.with_logical_constraint(z, ('batch', 'embed'))
    return z, None

class LogicalMLP(nn.Module):
  num_layers: int
  depth: int
  use_scan: bool
  @nn.compact
  def __call__(self, x):
    if self.use_scan:
      x, _ = nn.scan(LogicalDotReluDot, length=self.num_layers, 
                    variable_axes={"params": 0},
                    split_rngs={"params": True},
                    metadata_params={nn.PARTITION_NAME: 'layer'}
                    )(self.depth)(x)
    else:
      for i in range(self.num_layers):
        x, _ = LogicalDotReluDot(self.depth)(x)
    return x

The LogicalMLP model definition generates a set of PartitionSpec with logical axis names.

Repeat the steps from earlier: instantiate a model, evaluate the init_fn abstractly, and use flax.linen.get_partition_spec to automatically generate the PartitionSpec:

logical_model = LogicalMLP(LAYERS, DEPTH, USE_SCAN)
logical_abstract_variables = jax.eval_shape(
    functools.partial(init_fn, model=logical_model, optimizer=optimizer), k, x)
logical_output_spec = nn.get_partition_spec(logical_abstract_variables)
logical_output_spec
TrainState(step=PartitionSpec(), apply_fn=<bound method Module.apply of LogicalMLP(
    # attributes
    num_layers = 4
    depth = 1024
    use_scan = True
)>, params=FrozenDict({
    ScanLogicalDotReluDot_0: {
        W1: PartitionSpec('layer', 'embed', 'hidden'),
        W2: PartitionSpec('layer', 'hidden', 'embed'),
    },
}), tx=GradientTransformationExtraArgs(init=<function chain.<locals>.init_fn at 0x7f3ba877d700>, update=<function chain.<locals>.update_fn at 0x7f3ba877d8b0>), opt_state=(ScaleByAdamState(count=PartitionSpec(), mu=FrozenDict({
    ScanLogicalDotReluDot_0: {
        W1: PartitionSpec('layer', 'embed', 'hidden'),
        W2: PartitionSpec('layer', 'hidden', 'embed'),
    },
}), nu=FrozenDict({
    ScanLogicalDotReluDot_0: {
        W1: PartitionSpec('layer', 'embed', 'hidden'),
        W2: PartitionSpec('layer', 'hidden', 'embed'),
    },
})), EmptyState()))

To map your model onto the device mesh correctly, you need to decide which of these logical axis names are mapped to the device axes 'data' and 'model'. This mapping is a list of (logical_axis_name, device_axis_name) tuples, and flax.linen.spmd.logical_to_mesh will convert them to the spec that pjit accepts.

This allows you to change the rules and try out new partition layouts without modifying the model definition.

# Unspecified rule means unsharded by default, so no need to specify `('embed', None)` and `('layer', None)`.
rules = (('batch', 'data'),
         ('hidden', 'model'))

logical_state_spec = spmd.logical_to_mesh(logical_output_spec, rules)
logical_state_spec
TrainState(step=PartitionSpec(), apply_fn=<bound method Module.apply of LogicalMLP(
    # attributes
    num_layers = 4
    depth = 1024
    use_scan = True
)>, params=FrozenDict({
    ScanLogicalDotReluDot_0: {
        W1: PartitionSpec(None, None, 'model'),
        W2: PartitionSpec(None, 'model', None),
    },
}), tx=GradientTransformationExtraArgs(init=<function chain.<locals>.init_fn at 0x7f3ba877d700>, update=<function chain.<locals>.update_fn at 0x7f3ba877d8b0>), opt_state=(ScaleByAdamState(count=PartitionSpec(), mu=FrozenDict({
    ScanLogicalDotReluDot_0: {
        W1: PartitionSpec(None, None, 'model'),
        W2: PartitionSpec(None, 'model', None),
    },
}), nu=FrozenDict({
    ScanLogicalDotReluDot_0: {
        W1: PartitionSpec(None, None, 'model'),
        W2: PartitionSpec(None, 'model', None),
    },
})), EmptyState()))

You can verify that the logical_state_spec here has the same content as state_spec in the previous (“non-logical”) example. This will be the out_axis_resources you specify when creating pjitted functions.

state_spec.params['ScanDotReluDot_0'] == logical_state_spec.params['ScanLogicalDotReluDot_0']
True
logical_pjit_init_fn = pjit(init_fn,
                            static_argnums=(2, 3),
                            in_axis_resources=(PartitionSpec(None), x_spec),  # RNG key and x
                            out_axis_resources=logical_state_spec
                            )
with mesh:
  logical_initialized_state = logical_pjit_init_fn(k, x, logical_model, optimizer)
jax.tree_map(jnp.shape, logical_initialized_state)
TrainState(step=(), apply_fn=<bound method Module.apply of LogicalMLP(
    # attributes
    num_layers = 4
    depth = 1024
    use_scan = True
)>, params=FrozenDict({
    ScanLogicalDotReluDot_0: {
        W1: LogicallyPartitioned(value=(4, 1024, 1024), names=('layer', 'embed', 'hidden'), mesh=None, rules=None),
        W2: LogicallyPartitioned(value=(4, 1024, 1024), names=('layer', 'hidden', 'embed'), mesh=None, rules=None),
    },
}), tx=GradientTransformationExtraArgs(init=<function chain.<locals>.init_fn at 0x7f3ba877d700>, update=<function chain.<locals>.update_fn at 0x7f3ba877d8b0>), opt_state=(ScaleByAdamState(count=(), mu=FrozenDict({
    ScanLogicalDotReluDot_0: {
        W1: LogicallyPartitioned(value=(4, 1024, 1024), names=('layer', 'embed', 'hidden'), mesh=None, rules=None),
        W2: LogicallyPartitioned(value=(4, 1024, 1024), names=('layer', 'hidden', 'embed'), mesh=None, rules=None),
    },
}), nu=FrozenDict({
    ScanLogicalDotReluDot_0: {
        W1: LogicallyPartitioned(value=(4, 1024, 1024), names=('layer', 'embed', 'hidden'), mesh=None, rules=None),
        W2: LogicallyPartitioned(value=(4, 1024, 1024), names=('layer', 'hidden', 'embed'), mesh=None, rules=None),
    },
})), EmptyState()))

When to use device axis / logical axis#

Choosing when to use a device or logical axis depends on how much you want to control the partitioning of your model.

If you want a very simple model, or you are very confident of your way of partitioning, defining it directly with device mesh axes can save you the few extra lines of code needed to convert logical naming back to device naming.

On the other hand, the logical naming helpers are useful for exploring different sharding layouts. Use them if you want to experiment and find the optimal partition layout for your model.
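For example, trying a different layout is just a matter of editing the rules and regenerating the spec, without touching the Module definition. The rules below describe a hypothetical alternative layout, shown only for illustration:

# Hypothetical alternative layout: shard 'embed' (instead of 'hidden') over the 'model' axis.
alt_rules = (('batch', 'data'),
             ('embed', 'model'))
alt_state_spec = spmd.logical_to_mesh(logical_output_spec, alt_rules)
# W1 now maps to PartitionSpec(None, 'model', None) instead of PartitionSpec(None, None, 'model').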

In advanced use cases, you may have more complicated sharding patterns that require annotating activation dimension names differently from parameter dimension names. If you want more fine-grained control over manual mesh assignments, using device axis names directly could be more helpful.

Save the data#

You can use flax.training.checkpoints to save the cross-device array, as shown in the Save and load checkpoints guide - Multi-host/multi-process checkpointing. This is especially necessary if you are running in a multi-host environment (for example, a TPU pod).

Keep in mind that to restore the arrays to the desired partition, you need to provide a sample target pytree that has the same structure and has the desired PartitionSpec in place for each JAX array. The PartitionSpec you use to restore the array doesn’t necessarily need to be the same as the ones you used to store the array.
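For example, a minimal sketch of saving and restoring under these assumptions might look like the following (the checkpoint directory is hypothetical, and the exact arguments depend on your Flax version):

CKPT_DIR = '/tmp/flax_ckpt'  # hypothetical path

# Save: in a multi-host setup, every process participates and the arrays stay sharded.
checkpoints.save_checkpoint_multiprocess(
    ckpt_dir=CKPT_DIR, target=new_state, step=0, overwrite=True)

# Restore: pass a target pytree with the same structure and the desired sharding
# (for example, a freshly initialized, pjitted TrainState), so that each array is
# restored with the intended partitioning.
restored_state = checkpoints.restore_checkpoint(
    ckpt_dir=CKPT_DIR, target=initialized_state)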