Model Surgery

This guide shows how to flatten a model's nested parameter dict into a flat dict of tensors, modify it, and then convert it back into a nested, frozen dict. The same approach is demonstrated for both Flax modules and optimizers.

Surgery with Flax Modules

Let’s create a small convolutional neural network model for our demo.

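import jax
import jax.numpy as jnp
import optax
from flax import linen as nn
from flax import traverse_util
from flax.core import freeze

class CNN(nn.Module):
  @nn.compact  # Required because the submodules are defined inline in __call__.
  def __call__(self, x):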
    x = nn.Conv(features=32, kernel_size=(3, 3))(x)
    x = nn.relu(x)
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = nn.Conv(features=64, kernel_size=(3, 3))(x)
    x = nn.relu(x)
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = x.reshape((x.shape[0], -1))
    x = nn.Dense(features=256)(x)
    x = nn.relu(x)
    x = nn.Dense(features=10)(x)
    x = nn.log_softmax(x)
    return x

def get_initial_params(rng):
  init_shape = jnp.ones((1, 28, 28, 1), jnp.float32)
  initial_params = CNN().init(rng, init_shape)['params']
  return initial_params

key = jax.random.PRNGKey(0)
params = get_initial_params(key)

print(jax.tree_util.tree_map(jnp.shape, params))
FrozenDict({
    Conv_0: {
        bias: (32,),
        kernel: (3, 3, 1, 32),
    },
    Conv_1: {
        bias: (64,),
        kernel: (3, 3, 32, 64),
    },
    Dense_0: {
        bias: (256,),
        kernel: (3136, 256),
    },
    Dense_1: {
        bias: (10,),
        kernel: (256, 10),
    },
})

Next, get a flat dict for doing model surgery as follows:

# Flatten the nested params into a dict keyed by tuple paths.
flat_params = traverse_util.flatten_dict(params)
print(jax.tree_util.tree_map(jnp.shape, flat_params))
{('Conv_0', 'bias'): (32,),
 ('Conv_0', 'kernel'): (3, 3, 1, 32),
 ('Conv_1', 'bias'): (64,),
 ('Conv_1', 'kernel'): (3, 3, 32, 64),
 ('Dense_0', 'bias'): (256,),
 ('Dense_0', 'kernel'): (3136, 256),
 ('Dense_1', 'bias'): (10,),
 ('Dense_1', 'kernel'): (256, 10)}

After making your modifications, unflatten the dict back into its nested structure:

# Unflatten.
unflat_params = traverse_util.unflatten_dict(flat_params)
# Refreeze.
unflat_params = freeze(unflat_params)
print(jax.tree_util.tree_map(jnp.shape, unflat_params))
FrozenDict({
    Conv_0: {
        bias: (32,),
        kernel: (3, 3, 1, 32),
    },
    Conv_1: {
        bias: (64,),
        kernel: (3, 3, 32, 64),
    },
    Dense_0: {
        bias: (256,),
        kernel: (3136, 256),
    },
    Dense_1: {
        bias: (10,),
        kernel: (256, 10),
    },
})

Surgery with Optimizers

When using Optax as an optimizer, the opt_state is actually a nested tuple of the states of individual gradient transformations that compose the optimizer. These states contain pytrees that mirror the parameter tree, and can be modified the same way: flattening, modifying, unflattening, and then recreating a new optimizer state that mirrors the original state.

tx = optax.adam(1.0)
opt_state = tx.init(params)

# The optimizer state is a tuple of gradient transformation states.
print(jax.tree_util.tree_map(jnp.shape, opt_state))
(ScaleByAdamState(count=(), mu=FrozenDict({
    Conv_0: { bias: (32,), kernel: (3, 3, 1, 32), },
    Conv_1: { bias: (64,), kernel: (3, 3, 32, 64), },
    Dense_0: { bias: (256,), kernel: (3136, 256), },
    Dense_1: { bias: (10,), kernel: (256, 10), },
}), nu=FrozenDict({
    Conv_0: { bias: (32,), kernel: (3, 3, 1, 32), },
    Conv_1: { bias: (64,), kernel: (3, 3, 32, 64), },
    Dense_0: { bias: (256,), kernel: (3136, 256), },
    Dense_1: { bias: (10,), kernel: (256, 10), },
})), EmptyState())

The pytrees inside the optimizer state follow the same structure as the parameters and can be flattened and modified in exactly the same way:

flat_mu = traverse_util.flatten_dict(opt_state[0].mu)
flat_nu = traverse_util.flatten_dict(opt_state[0].nu)

print(jax.tree_util.tree_map(jnp.shape, flat_mu))
{('Conv_0', 'bias'): (32,),
 ('Conv_0', 'kernel'): (3, 3, 1, 32),
 ('Conv_1', 'bias'): (64,),
 ('Conv_1', 'kernel'): (3, 3, 32, 64),
 ('Dense_0', 'bias'): (256,),
 ('Dense_0', 'kernel'): (3136, 256),
 ('Dense_1', 'bias'): (10,),
 ('Dense_1', 'kernel'): (256, 10)}

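After modification, re-create the optimizer state:

opt_state = (
    opt_state[0]._replace(
        # Refreeze so the moments keep the pytree type shown above.
        mu=freeze(traverse_util.unflatten_dict(flat_mu)),
        nu=freeze(traverse_util.unflatten_dict(flat_nu)),
    ),
) + opt_state[1:]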