Scale up Flax Modules on multiple devices with pjit#
This guide shows how to scale up Flax Modules on multiple devices and hosts using JAX’s pjit and flax.linen.spmd.
Flax and pjit#
jax.experimental.pjit provides a way to automatically compile and scale up JAX computations. pjit has the following benefits:
- pjit has a similar interface to jax.jit and works as a decorator on a function that needs to be compiled (a minimal sketch of this interface follows the list).
- When using pjit, you can write code as if it runs on a single device, and pjit will automatically compile and run it on multiple devices using the Single Program Multi Data (SPMD) paradigm.
- With pjit you can state how the input and output of your code are partitioned across devices, and the compiler will figure out how to: 1) partition everything inside; and 2) compile inter-device communications.
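Here is that minimal sketch, hedged: it is not part of the original guide, it assumes the 8-device CPU emulation set up in the Setup section below, and the function and shapes are made up purely for illustration.
import jax
import jax.numpy as jnp
from jax.experimental.pjit import pjit
from jax.sharding import Mesh, PartitionSpec
from jax.experimental import mesh_utils
# Build a 2x4 mesh and name its axes (assumes 8 devices are available).
sketch_mesh = Mesh(mesh_utils.create_device_mesh((2, 4)), axis_names=('data', 'model'))
# Written as if it runs on a single device.
def f(x):
  return jnp.sin(x).sum()
# Shard the input's first dimension over the 'data' axis; replicate the scalar output.
pjit_f = pjit(f,
              in_axis_resources=PartitionSpec('data', None),
              out_axis_resources=None)
with sketch_mesh:
  out = pjit_f(jnp.ones((8, 16)))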
To learn more, refer to the JAX-101 pjit tutorial and JAX in multi-process environments.
Flax provides several functionalities that can help you use pjit on Flax Modules, including:
- An interface to specify partitions of your data when defining flax.linen.Module.
- Utility functions to generate the partition information that pjit requires to run.
- An interface to customize your axis names, called “logical axis annotations”, which decouples your Module code from the partition plan so you can experiment with different partition layouts more easily.
Setup#
Install Flax from HEAD:
# Once Flax v0.6.4 is released, use `pip3 install flax`.
! pip3 install -qq "git+https://github.com/google/flax.git@main#egg=flax"
Imports#
Import some necessary dependencies.
Note: This guide uses the --xla_force_host_platform_device_count=8 flag to emulate multiple devices in a CPU environment in a Google Colab/Jupyter Notebook. Check out the JAX-101 pjit tutorial to learn more about emulating a multi-device TPU environment (in which case you should ignore running os.environ).
import os
os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8'
import functools
import numpy as np
import jax
from jax import lax, random, numpy as jnp
import flax
from flax import struct, traverse_util, linen as nn
from flax.linen import spmd # Flax Linen SPMD.
from flax.core import freeze, unfreeze
from flax.training import train_state, checkpoints
import optax # Optax for common losses and optimizers.
Next, import all the pjit-related libraries.
Note: jax.experimental.pjit is still in the experimental package of JAX, so there may be changes in the API in the future.
Start a 2x4 device mesh (8 devices). This is the same as the layout of TPU v3-8.
Annotate each axis with a name. A typical way to annotate axis names is ('data', 'model'), where:
- 'data': the mesh dimension used for data-parallel sharding of the batch dimension of inputs and activations.
- 'model': the mesh dimension used for sharding parameters of the model across devices.
from jax.experimental.pjit import pjit, with_sharding_constraint
from jax.sharding import Mesh, PartitionSpec
from jax.experimental import mesh_utils
# Start a device mesh.
device_mesh = mesh_utils.create_device_mesh((2, 4))
print(device_mesh)
# Annotate each axis with a name.
mesh = Mesh(devices=device_mesh, axis_names=('data', 'model'))
mesh
[[CpuDevice(id=0) CpuDevice(id=1) CpuDevice(id=2) CpuDevice(id=3)]
[CpuDevice(id=4) CpuDevice(id=5) CpuDevice(id=6) CpuDevice(id=7)]]
Mesh(device_ids=array([[0, 1, 2, 3],
[4, 5, 6, 7]]), axis_names=('data', 'model'))
Define a layer#
Before defining a model, create an example layer called DotReluDot (by subclassing flax.linen.Module), which creates two parameters W1 and W2 for dot product multiplication, and uses the jax.nn.relu (ReLU) activation function in-between.
To use this layer in pjit efficiently, apply the following APIs to annotate the parameters and intermediate variables correctly:
- Use flax.linen.with_partitioning to decorate the initializer function when creating parameters W1 and W2.
- Apply pjit.with_sharding_constraint to annotate intermediate variables like y and z to force a particular sharding pattern under pjit when the ideal constraint is known. This step is optional, but can sometimes help auto-SPMD to partition efficiently. In the example below, the call is not required, because pjit will figure out the same sharding layout for y and z regardless.
class DotReluDot(nn.Module):
  depth: int
  @nn.compact
  def __call__(self, x):
    W1 = self.param(
        'W1',
        nn.with_partitioning(nn.initializers.xavier_normal(), (None, 'model')),
        (x.shape[-1], self.depth))

    y = jax.nn.relu(jnp.dot(x, W1))
    # Force a local sharding annotation.
    y = with_sharding_constraint(y, PartitionSpec('data', 'model'))

    W2 = self.param(
        'W2',
        nn.with_partitioning(nn.initializers.xavier_normal(), ('model', None)),
        (self.depth, x.shape[-1]))

    z = jnp.dot(y, W2)
    # Force a local sharding annotation.
    z = with_sharding_constraint(z, PartitionSpec('data', None))

    # Return a tuple to conform with the API `flax.linen.scan` as shown in the cell below.
    return z, None
Note that device axis names like 'data', 'model' or None are passed into both flax.linen.with_partitioning and pjit.with_sharding_constraint API calls. This refers to how each dimension of this data should be sharded: either across one of the device mesh dimensions, or not sharded at all.
For example:
- When you define W1 with shape (x.shape[-1], self.depth) and annotate it as (None, 'model'):
  - The first dimension (of length x.shape[-1]) will be replicated across all devices.
  - The second dimension (of length self.depth) will be sharded over the 'model' axis of the device mesh. This means W1 will be sharded 4-way on devices (0, 4), (1, 5), (2, 6) and (3, 7), on this dimension.
- When you annotate the output z as ('data', None):
  - The first dimension (the batch dimension) will be sharded over the 'data' axis. This means half of the batch will be processed on devices 0-3 (the first four devices), and the other half on devices 4-7 (the remaining four devices).
  - The second dimension (the data depth dimension) will be replicated across all devices.
The sketch after this list makes this mapping concrete.
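This is a small, hedged sketch (not part of the layer definition): it shards a plain array with the same PartitionSpec and inspects the result, using jax.device_put with a jax.sharding.NamedSharding purely for illustration. It assumes the 2x4 mesh created above.
from jax.sharding import NamedSharding
# An (8, 1024) array annotated like z: batch split over 'data', second dimension replicated.
batch = jnp.ones((8, 1024))
sharded_batch = jax.device_put(batch, NamedSharding(mesh, PartitionSpec('data', None)))
print(sharded_batch.sharding)
# The first dimension is split in two over the 'data' axis; each half is
# replicated across the four 'model' devices.
jax.debug.visualize_array_sharding(sharded_batch)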
Define a model with flax.linen.scan lifted transformation#
This guide uses flax.linen.scan to demonstrate how Flax lifted transforms, such as scan, can work together with JAX pjit.
Having created DotReluDot, define the MLP model (by subclassing flax.linen.Module) as multiple layers of DotReluDot.
To replicate identical layers, you can either use flax.linen.scan, or a for-loop:
- flax.linen.scan can offer faster compilation times.
- The for-loop can be faster at runtime.
The code below shows how to apply both methods.
Note: flax.linen.scan adds another dimension to the parameters (the dimension over which scan is applied). You need to use the metadata_params argument to annotate the partition of this dimension. Since the parameters inside your DotReluDot (a sub-Module) are already sharded along the model axis, you don’t need to partition multiple layers across the model dimension here, and therefore you should denote it as None.
class MLP(nn.Module):
  num_layers: int
  depth: int
  use_scan: bool
  @nn.compact
  def __call__(self, x):
    if self.use_scan:
      x, _ = nn.scan(DotReluDot, length=self.num_layers,
                     variable_axes={"params": 0},
                     split_rngs={"params": True},
                     metadata_params={nn.PARTITION_NAME: None}
                     )(self.depth)(x)
    else:
      for i in range(self.num_layers):
        x, _ = DotReluDot(self.depth)(x)
    return x
Specify sharding (includes initialization and TrainState creation)#
Next, generate the jax.sharding.PartitionSpec that pjit should receive as annotations of input and output data. PartitionSpec is a tuple of 2 axes (in a 2x4 mesh). To learn more, refer to JAX-101: Introduction to pjit.
Specify the input#
For data parallelism, you can shard the batched input x across the data axis by denoting the batch axis as data:
x_spec = PartitionSpec('data', None) # dimensions: (batch, length)
x_spec
PartitionSpec('data', None)
Generate a PartitionSpec for the output#
Next, generate a PartitionSpec for the output. To do this, you need some actual output as a reference:
1. Instantiate a model.
2. Evaluate model.init abstractly using jax.eval_shape.
3. Use flax.linen.get_partition_spec to automatically generate the PartitionSpec.
The code below shows how to get the output spec if you use flax.training.train_state to carry out your initialization and training steps, in which case your pjitted function will output a TrainState.
(In a simpler case, you might choose the variable dict, as in variables = model.init(k, x), as your pjitted function’s output. That works too; a hedged sketch of this variant follows.)
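In that sketch (not used in the rest of this guide; the sketch_* names are made up for illustration), you derive the spec directly from the abstract variable dict:
# Sketch: use the raw variable dict, rather than a TrainState, as the pjitted output.
sketch_model = MLP(4, 1024, True)
sketch_abstract_vars = jax.eval_shape(
    sketch_model.init, random.PRNGKey(0), jnp.ones((8, 1024)))
sketch_vars_spec = nn.get_partition_spec(sketch_abstract_vars)
# sketch_vars_spec['params'] mirrors the state_spec.params computed below; the full
# dict could be passed as out_axis_resources to a pjit-ed model.init.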
# MLP hyperparameters.
BATCH, LAYERS, DEPTH, USE_SCAN = 8, 4, 1024, True
# Create fake inputs.
x = jnp.ones((BATCH, DEPTH))
# Initialize a PRNG key.
k = random.PRNGKey(0)
# Create an Optax optimizer.
optimizer = optax.adam(learning_rate=0.001)
# Instantiate the model.
model = MLP(LAYERS, DEPTH, USE_SCAN)
# A functional way of model initialization.
def init_fn(k, x, model, optimizer):
  variables = model.init(k, x)  # Initialize the model.
  state = train_state.TrainState.create(  # Create a `TrainState`.
    apply_fn=model.apply,
    params=variables['params'],
    tx=optimizer)
  return state

with mesh:
  # Create an abstract closure to wrap the function before feeding it in
  # because `jax.eval_shape` only takes pytrees as arguments.
  abstract_variables = jax.eval_shape(
      functools.partial(init_fn, model=model, optimizer=optimizer), k, x)
  # This `state_spec` has the same pytree structure as the output
  # of the `init_fn`.
  state_spec = nn.get_partition_spec(abstract_variables)
state_spec
TrainState(step=PartitionSpec(), apply_fn=<bound method Module.apply of MLP(
# attributes
num_layers = 4
depth = 1024
use_scan = True
)>, params=FrozenDict({
ScanDotReluDot_0: {
W1: PartitionSpec(None, None, 'model'),
W2: PartitionSpec(None, 'model', None),
},
}), tx=GradientTransformationExtraArgs(init=<function chain.<locals>.init_fn at 0x7f3ba877d700>, update=<function chain.<locals>.update_fn at 0x7f3ba877d8b0>), opt_state=(ScaleByAdamState(count=PartitionSpec(), mu=FrozenDict({
ScanDotReluDot_0: {
W1: PartitionSpec(None, None, 'model'),
W2: PartitionSpec(None, 'model', None),
},
}), nu=FrozenDict({
ScanDotReluDot_0: {
W1: PartitionSpec(None, None, 'model'),
W2: PartitionSpec(None, 'model', None),
},
})), EmptyState()))
Apply pjit to compile the code#
Now you can apply JAX pjit to your init_fn in a similar fashion as jax.jit but with two extra arguments: in_axis_resources and out_axis_resources.
You need to add a with mesh: context when running a pjitted function, so that it can refer to mesh (an instance of jax.sharding.Mesh) to allocate data on devices correctly.
pjit_init_fn = pjit(init_fn,
                    static_argnums=(2, 3),
                    in_axis_resources=(PartitionSpec(None), x_spec),  # PRNG key and x
                    out_axis_resources=state_spec
                    )
with mesh:
  initialized_state = pjit_init_fn(k, x, model, optimizer)
jax.tree_map(jnp.shape, initialized_state)
TrainState(step=(), apply_fn=<bound method Module.apply of MLP(
# attributes
num_layers = 4
depth = 1024
use_scan = True
)>, params=FrozenDict({
ScanDotReluDot_0: {
W1: Partitioned(value=(4, 1024, 1024), names=(None, None, 'model'), mesh=None),
W2: Partitioned(value=(4, 1024, 1024), names=(None, 'model', None), mesh=None),
},
}), tx=GradientTransformationExtraArgs(init=<function chain.<locals>.init_fn at 0x7f3ba877d700>, update=<function chain.<locals>.update_fn at 0x7f3ba877d8b0>), opt_state=(ScaleByAdamState(count=(), mu=FrozenDict({
ScanDotReluDot_0: {
W1: Partitioned(value=(4, 1024, 1024), names=(None, None, 'model'), mesh=None),
W2: Partitioned(value=(4, 1024, 1024), names=(None, 'model', None), mesh=None),
},
}), nu=FrozenDict({
ScanDotReluDot_0: {
W1: Partitioned(value=(4, 1024, 1024), names=(None, None, 'model'), mesh=None),
W2: Partitioned(value=(4, 1024, 1024), names=(None, 'model', None), mesh=None),
},
})), EmptyState()))
Inspect the Module output#
Note that in the output of initialized_state, the params W1 and W2 are of type flax.linen.Partitioned. This is a wrapper around the actual jax.Array that allows Flax to record metadata associated with it. You can access the raw jax.Array by adding .value or running .unbox().
You can also check the underlying jax.sharding of the JAX array, which gives a hint about how it is partitioned.
print(type(initialized_state.params['ScanDotReluDot_0']['W1']))
print(type(initialized_state.params['ScanDotReluDot_0']['W1'].value))
print(initialized_state.params['ScanDotReluDot_0']['W1'].value.shape)
<class 'flax.core.meta.Partitioned'>
<class 'jaxlib.xla_extension.ArrayImpl'>
(4, 1024, 1024)
print(initialized_state.params['ScanDotReluDot_0']['W1'].value.sharding)
NamedSharding(mesh={'data': 2, 'model': 4}, spec=PartitionSpec(None, None, 'model'))
You can use jax.tree_map to perform mass computation on a dict of boxed params, in the same way as on a dict of JAX arrays.
diff = jax.tree_map(
lambda a, b: a - b,
initialized_state.params['ScanDotReluDot_0'], initialized_state.params['ScanDotReluDot_0'])
print(jax.tree_map(jnp.shape, diff))
diff_array = diff['W1'].unbox()
print(type(diff_array))
print(diff_array.shape)
FrozenDict({
W1: Partitioned(value=(4, 1024, 1024), names=(None, None, 'model'), mesh=None),
W2: Partitioned(value=(4, 1024, 1024), names=(None, 'model', None), mesh=None),
})
<class 'jaxlib.xla_extension.ArrayImpl'>
(4, 1024, 1024)
Apply pjit to the train step and inference#
Now, create a pjitted training step:
def train_step(state, x):
  # A fake loss function.
  def loss_unrolled(params):
    y = model.apply({'params': params}, x)
    return y.sum()
  grad_fn = jax.grad(loss_unrolled)
  grads = grad_fn(state.params)
  state = state.apply_gradients(grads=grads)
  return state

pjit_step_fn = pjit(train_step,
                    in_axis_resources=(state_spec, x_spec),  # input annotations
                    out_axis_resources=state_spec,           # output annotations
                    )
with mesh:
  new_state = pjit_step_fn(initialized_state, x)
Apply pjit to inference. Note that, similar to jax.jit, you can use a decorator like @functools.partial(pjit, ...) to directly compile your function.
@functools.partial(pjit, in_axis_resources=(state_spec, x_spec), out_axis_resources=x_spec)
def pjit_apply_fn(state, x):
  return state.apply_fn({'params': state.params}, x)

with mesh:
  y = pjit_apply_fn(new_state, x)
print(type(y))
print(y.dtype)
print(y.shape)
<class 'jaxlib.xla_extension.ArrayImpl'>
float32
(8, 1024)
Profiling#
If you are running on a TPU pod or a pod slice, you can use a custom block_all utility function, as defined below, to measure the performance:
%%timeit

def block_all(xs):
  jax.tree_map(lambda x: x.block_until_ready(), xs)
  return xs

with mesh:
  new_state = block_all(pjit_step_fn(initialized_state, x))
339 ms ± 13.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
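If you want a per-operation breakdown rather than wall-clock timing, one option (a hedged sketch, not part of the original guide; the log directory /tmp/jax-trace is just an example) is to capture a trace for TensorBoard with jax.profiler.trace:
# Sketch: capture a profiler trace of one pjitted step for inspection in TensorBoard.
with mesh:
  with jax.profiler.trace("/tmp/jax-trace"):
    stepped = pjit_step_fn(initialized_state, x)
    jax.tree_map(lambda a: a.block_until_ready(), stepped)  # ensure the step finishes inside the trace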
Logical axis annotation#
JAX auto SPMD encourages users to explore different sharding layouts to find the optimal one. To this end, in Flax you can annotate the dimensions of any data with more descriptive axis names (not just device mesh axis names like 'data' and 'model').
The LogicalDotReluDot and LogicalMLP Module definitions below are similar to the Modules you created earlier, except for the following:
- All axes are annotated with more concrete, meaningful names, such as 'embed', 'hidden', 'batch' and 'layer'. These names are referred to as logical axis names in Flax. They make the dimensional changes inside model definitions more readable.
- flax.linen.spmd.with_logical_partitioning replaces flax.linen.with_partitioning, and flax.linen.spmd.with_logical_constraint replaces pjit.with_sharding_constraint, so that the logical axis names are recognized.
class LogicalDotReluDot(nn.Module):
  depth: int
  @nn.compact
  def __call__(self, x):
    W1 = self.param(
        'W1',
        spmd.with_logical_partitioning(nn.initializers.xavier_normal(), ('embed', 'hidden')),
        (x.shape[-1], self.depth))

    y = jax.nn.relu(jnp.dot(x, W1))
    # Force a local sharding annotation.
    y = spmd.with_logical_constraint(y, ('batch', 'hidden'))

    W2 = self.param(
        'W2',
        spmd.with_logical_partitioning(nn.initializers.xavier_normal(), ('hidden', 'embed')),
        (self.depth, x.shape[-1]))

    z = jnp.dot(y, W2)
    # Force a local sharding annotation.
    z = spmd.with_logical_constraint(z, ('batch', 'embed'))
    return z, None
class LogicalMLP(nn.Module):
  num_layers: int
  depth: int
  use_scan: bool
  @nn.compact
  def __call__(self, x):
    if self.use_scan:
      x, _ = nn.scan(LogicalDotReluDot, length=self.num_layers,
                     variable_axes={"params": 0},
                     split_rngs={"params": True},
                     metadata_params={nn.PARTITION_NAME: 'layer'}
                     )(self.depth)(x)
    else:
      for i in range(self.num_layers):
        x, _ = LogicalDotReluDot(self.depth)(x)
    return x
The LogicalMLP model definition generates a set of PartitionSpec with logical axis names.
Repeat the steps from earlier: instantiate a model, evaluate the init_fn abstractly, and use flax.linen.get_partition_spec to automatically generate the PartitionSpec:
logical_model = LogicalMLP(LAYERS, DEPTH, USE_SCAN)
logical_abstract_variables = jax.eval_shape(
functools.partial(init_fn, model=logical_model, optimizer=optimizer), k, x)
logical_output_spec = nn.get_partition_spec(logical_abstract_variables)
logical_output_spec
TrainState(step=PartitionSpec(), apply_fn=<bound method Module.apply of LogicalMLP(
# attributes
num_layers = 4
depth = 1024
use_scan = True
)>, params=FrozenDict({
ScanLogicalDotReluDot_0: {
W1: PartitionSpec('layer', 'embed', 'hidden'),
W2: PartitionSpec('layer', 'hidden', 'embed'),
},
}), tx=GradientTransformationExtraArgs(init=<function chain.<locals>.init_fn at 0x7f3ba877d700>, update=<function chain.<locals>.update_fn at 0x7f3ba877d8b0>), opt_state=(ScaleByAdamState(count=PartitionSpec(), mu=FrozenDict({
ScanLogicalDotReluDot_0: {
W1: PartitionSpec('layer', 'embed', 'hidden'),
W2: PartitionSpec('layer', 'hidden', 'embed'),
},
}), nu=FrozenDict({
ScanLogicalDotReluDot_0: {
W1: PartitionSpec('layer', 'embed', 'hidden'),
W2: PartitionSpec('layer', 'hidden', 'embed'),
},
})), EmptyState()))
To map your model onto the device mesh correctly, you need to decide which of these logical axis names are mapped to the device axes 'data' or 'model'. This rule is a list of (logical_axis_name, device_axis_name) tuples, and flax.linen.spmd.logical_to_mesh will convert them to the spec that pjit accepts.
This allows you to change the rules and try out new partition layouts without modifying the model definition.
# Unspecified rule means unsharded by default, so no need to specify `('embed', None)` and `('layer', None)`.
rules = (('batch', 'data'),
('hidden', 'model'))
logical_state_spec = spmd.logical_to_mesh(logical_output_spec, rules)
logical_state_spec
TrainState(step=PartitionSpec(), apply_fn=<bound method Module.apply of LogicalMLP(
# attributes
num_layers = 4
depth = 1024
use_scan = True
)>, params=FrozenDict({
ScanLogicalDotReluDot_0: {
W1: PartitionSpec(None, None, 'model'),
W2: PartitionSpec(None, 'model', None),
},
}), tx=GradientTransformationExtraArgs(init=<function chain.<locals>.init_fn at 0x7f3ba877d700>, update=<function chain.<locals>.update_fn at 0x7f3ba877d8b0>), opt_state=(ScaleByAdamState(count=PartitionSpec(), mu=FrozenDict({
ScanLogicalDotReluDot_0: {
W1: PartitionSpec(None, None, 'model'),
W2: PartitionSpec(None, 'model', None),
},
}), nu=FrozenDict({
ScanLogicalDotReluDot_0: {
W1: PartitionSpec(None, None, 'model'),
W2: PartitionSpec(None, 'model', None),
},
})), EmptyState()))
You can verify that the logical_state_spec here has the same content as state_spec in the previous (“non-logical”) example. This will be the out_axis_resources you specify when creating pjitted functions.
state_spec.params['ScanDotReluDot_0'] == logical_state_spec.params['ScanLogicalDotReluDot_0']
True
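For example, here is a hedged sketch (not from the original guide) of a purely data-parallel layout: dropping the ('hidden', 'model') rule leaves the parameter axes unmapped, so they stay unsharded, while the batch is still split over 'data'. None of this requires touching LogicalMLP.
# Sketch: an alternative rule set, i.e. pure data parallelism with replicated parameters.
dp_rules = (('batch', 'data'),)
dp_state_spec = spmd.logical_to_mesh(logical_output_spec, dp_rules)
# Expected: every parameter dimension maps to None, e.g. PartitionSpec(None, None, None) for W1.
dp_state_spec.params['ScanLogicalDotReluDot_0']['W1']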
logical_pjit_init_fn = pjit(init_fn,
                            static_argnums=(2, 3),
                            in_axis_resources=(PartitionSpec(None), x_spec),  # RNG key and x
                            out_axis_resources=logical_state_spec
                            )
with mesh:
  logical_initialized_state = logical_pjit_init_fn(k, x, logical_model, optimizer)
jax.tree_map(jnp.shape, logical_initialized_state)
TrainState(step=(), apply_fn=<bound method Module.apply of LogicalMLP(
# attributes
num_layers = 4
depth = 1024
use_scan = True
)>, params=FrozenDict({
ScanLogicalDotReluDot_0: {
W1: LogicallyPartitioned(value=(4, 1024, 1024), names=('layer', 'embed', 'hidden'), mesh=None, rules=None),
W2: LogicallyPartitioned(value=(4, 1024, 1024), names=('layer', 'hidden', 'embed'), mesh=None, rules=None),
},
}), tx=GradientTransformationExtraArgs(init=<function chain.<locals>.init_fn at 0x7f3ba877d700>, update=<function chain.<locals>.update_fn at 0x7f3ba877d8b0>), opt_state=(ScaleByAdamState(count=(), mu=FrozenDict({
ScanLogicalDotReluDot_0: {
W1: LogicallyPartitioned(value=(4, 1024, 1024), names=('layer', 'embed', 'hidden'), mesh=None, rules=None),
W2: LogicallyPartitioned(value=(4, 1024, 1024), names=('layer', 'hidden', 'embed'), mesh=None, rules=None),
},
}), nu=FrozenDict({
ScanLogicalDotReluDot_0: {
W1: LogicallyPartitioned(value=(4, 1024, 1024), names=('layer', 'embed', 'hidden'), mesh=None, rules=None),
W2: LogicallyPartitioned(value=(4, 1024, 1024), names=('layer', 'hidden', 'embed'), mesh=None, rules=None),
},
})), EmptyState()))
When to use device axis / logical axis#
Choosing when to use a device or logical axis depends on how much control you want over the partitioning of your model.
If you have a very simple model, or you are very confident about how you want to partition it, defining the partitioning directly with device mesh axes can save you the few extra lines of code needed to convert logical names back to device names.
On the other hand, the logical naming helpers are useful for exploring different sharding layouts. Use them if you want to experiment and find the optimal partition layout for your model.
In really advanced use cases, you may have more complicated sharding patterns that require annotating activation dimension names differently from parameter dimension names. If you want more fine-grained control over manual mesh assignments, using device axis names directly can be more helpful.
Save the data#
You can use flax.training.checkpoints to save the cross-device array, as shown in the Save and load checkpoints guide - Multi-host/multi-process checkpointing. This is especially required if you are running in a multi-host environment (for example, a TPU pod).
Keep in mind that to restore the arrays to the desired partition, you need to provide a sample target pytree that has the same structure and has the desired PartitionSpec in place for each JAX array. The PartitionSpec you use to restore the array doesn’t necessarily need to be the same as the one you used to store the array.
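As a hedged, single-process sketch of this pattern (the checkpoint directory and step number are example values; a multi-host setup would use the multi-process variant described in the linked guide):
# Sketch: save the sharded TrainState, then restore it using a target pytree that
# carries the desired partitioning metadata.
CKPT_DIR = '/tmp/flax_pjit_ckpt'  # example path
checkpoints.save_checkpoint(CKPT_DIR, target=initialized_state, step=0, overwrite=True)
restored_state = checkpoints.restore_checkpoint(CKPT_DIR, target=initialized_state)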