
Flax Basics

This notebook will walk you through the following workflow:

  • Instantiating a model from Flax built-in layers or third-party models.

  • Initializing parameters of the model and manually writing a training loop.

  • Using optimizers provided by Flax to ease training.

  • Serialization of parameters and other objects.

  • Creating your own models and managing state.

Setting up our environment

Here we provide the code needed to set up the environment for our notebook.

[1]:
# Install the latest JAXlib version.
!pip install --upgrade -q pip jax jaxlib
# Install Flax at head:
!pip install --upgrade -q git+https://github.com/google/flax.git
[2]:
import jax
from typing import Any, Callable, Sequence, Optional
from jax import lax, random, numpy as jnp
import flax
from flax.core import freeze, unfreeze
from flax import linen as nn

from jax.config import config
config.enable_omnistaging() # Linen requires enabling omnistaging

Linear regression with Flax

In the previous JAX for the impatient notebook, we finished up with a linear regression example. As we know, linear regression can also be written as a single dense neural network layer, which we will show in the following so that we can compare how it’s done.

A dense layer is a layer that has a kernel parameter \(W\in\mathcal{M}_{m,n}(\mathbb{R})\), where \(m\) is the number of output features and \(n\) the dimensionality of the input, and a bias parameter \(b\in\mathbb{R}^m\). The dense layer returns \(Wx+b\) for an input \(x\in\mathbb{R}^n\).

This dense layer is already provided by Flax in the flax.linen module (here imported as nn).

[3]:
# We create one dense layer instance (taking 'features' parameter as input)
model = nn.Dense(features=5)

Layers (and models in general, we’ll use that word from now on) are subclasses of the linen.Module class.

Model parameters & initialization

Parameters are not stored with the models themselves. You need to initialize parameters by calling the init function, using a PRNGKey and a dummy input parameter.

[4]:
key1, key2 = random.split(random.PRNGKey(0))
x = random.normal(key1, (10,)) # Dummy input
params = model.init(key2, x) # Initialization call
jax.tree_map(lambda x: x.shape, params) # Checking output shapes
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
[4]:
FrozenDict({'params': {'bias': (5,), 'kernel': (10, 5)}})

Note: JAX and Flax, like NumPy, are row-based systems, meaning that vectors are represented as row vectors and not column vectors. This can be seen in the shape of the kernel here.

The result is what we expect: bias and kernel parameters of the correct size. Under the hood:

  • The dummy input variable x is used to trigger shape inference: we only declared the number of features we wanted in the output of the model, not the size of the input. Flax finds out by itself the correct size of the kernel.

  • The random PRNG key is used to trigger the initialization functions (those have default values provided by the module here).

  • Initialization functions are called to generate the initial set of parameters that the model will use. Those are functions that take as arguments (PRNG Key, shape, dtype) and return an Array of shape shape.

  • The init function returns the initialized set of parameters. (You can also get the output of the model evaluated on the dummy input with the same syntax by using the init_with_output method instead of init.)

We see in the output that parameters are stored in a FrozenDict instance which helps deal with the functional nature of JAX by preventing any mutation of the underlying dict and making the user aware of it. Read more about it in the Flax docs. As a consequence, the following doesn’t work:

[5]:
try:
    params['new_key'] = jnp.ones((2,2))
except ValueError as e:
    print("Error: ", e)
Error:  FrozenDict is immutable.

To evaluate the model with a given set of parameters (never stored with the model), we just use the apply method by providing it the parameters to use as well as the input:

[6]:
model.apply(params, x)
[6]:
DeviceArray([-0.7358944,  1.3583755, -0.7976871,  0.8168598,  0.6297793],            dtype=float32)

Gradient descent

If you jumped here directly without going through the JAX part, here is the linear regression formulation we’re going to use: from a set of data points \(\{(x_i,y_i), i\in \{1,\ldots, k\}, x_i\in\mathbb{R}^n,y_i\in\mathbb{R}^m\}\), we try to find a set of parameters \(W\in \mathcal{M}_{m,n}(\mathbb{R}), b\in\mathbb{R}^m\) such that the function \(f_{W,b}(x)=Wx+b\) minimizes the mean squared error:

\[\mathcal{L}(W,b)\rightarrow\frac{1}{k}\sum_{i=1}^{k} \frac{1}{2}\|y_i-f_{W,b}(x_i)\|^2_2\]

Here, we see that the tuple \((W,b)\) matches the parameters of the Dense layer. We’ll perform gradient descent using those. Let’s first generate the fake data we’ll use.

[7]:
# Set problem dimensions
nsamples = 20
xdim = 10
ydim = 5

# Generate random ground truth W and b
key = random.PRNGKey(0)
k1, k2 = random.split(key)
W = random.normal(k1, (xdim, ydim))
b = random.normal(k2, (ydim,))
true_params = freeze({'params': {'bias': b, 'kernel': W}})

# Generate samples with additional noise
ksample, knoise = random.split(k1)
x_samples = random.normal(ksample, (nsamples, xdim))
y_samples = jnp.dot(x_samples, W) + b
y_samples += 0.1*random.normal(knoise,(nsamples, ydim)) # Adding noise
print('x shape:', x_samples.shape, '; y shape:', y_samples.shape)
x shape: (20, 10) ; y shape: (20, 5)

Now let’s define the loss function (mean squared error) using that data.

[8]:
def make_mse_func(x_batched, y_batched):
  def mse(params):
    # Define the squared loss for a single pair (x,y)
    def squared_error(x, y):
      pred = model.apply(params, x)
      return jnp.inner(y-pred, y-pred)/2.0
    # We vectorize the previous function to compute the average loss over all samples.
    return jnp.mean(jax.vmap(squared_error)(x_batched, y_batched), axis=0)
  return jax.jit(mse) # And finally we jit the result.

# Get the sampled loss
loss = make_mse_func(x_samples, y_samples)

And finally perform the gradient descent.

[9]:
alpha = 0.3 # Gradient step size
print('Loss for "true" W,b: ', loss(true_params))
grad_fn = jax.value_and_grad(loss)

for i in range(101):
  # We perform one gradient update
  loss_val, grad = grad_fn(params)
  params = jax.tree_multimap(lambda old, grad: old - alpha * grad,
                            params, grad)
  if i % 10 == 0:
    print('Loss step {}: '.format(i), loss_val)
Loss for "true" W,b:  29.070158
Loss step 0:  23.618902
Loss step 10:  0.30728564
Loss step 20:  0.06495677
Loss step 30:  0.025215296
Loss step 40:  0.015619493
Loss step 50:  0.012849321
Loss step 60:  0.011984843
Loss step 70:  0.011705536
Loss step 80:  0.011613827
Loss step 90:  0.011583473
Loss step 100:  0.011573391

Built-in optimization API

Flax provides an optimization package in flax.optim to make your life easier when training models. The process is:

  1. You choose an optimization method (e.g. optim.GradientDescent, optim.Adam)

  2. From the previous optimization method, you create a wrapper around the parameters you’re going to optimize, using the create method. Your parameters are accessible through the target field.

  3. You compute the gradients of your loss with jax.value_and_grad().

  4. At every iteration, you compute the gradients at the current point, then use the apply_gradient() method on the optimizer to return a new optimizer with updated parameters.

[10]:
from flax import optim
optimizer_def = optim.GradientDescent(learning_rate=alpha) # Choose the method
optimizer = optimizer_def.create(params) # Create the wrapping optimizer with initial parameters
loss_grad_fn = jax.value_and_grad(loss)
[11]:
for i in range(101):
  loss_val, grad = loss_grad_fn(optimizer.target)
  optimizer = optimizer.apply_gradient(grad) # Returns a new optimizer with updated parameters.
  if i % 10 == 0:
    print('Loss step {}: '.format(i), loss_val)
Loss step 0:  0.011572863
Loss step 10:  0.011569859
Loss step 20:  0.011568859
Loss step 30:  0.011568523
Loss step 40:  0.011568412
Loss step 50:  0.011568374
Loss step 60:  0.011568364
Loss step 70:  0.011568359
Loss step 80:  0.01156836
Loss step 90:  0.011568356
Loss step 100:  0.011568357

Serializing the result

Now that we’re happy with the result of our training, we might want to save the model parameters to load them back later. Flax provides a serialization package to enable you to do that.

[12]:
from flax import serialization
bytes_output = serialization.to_bytes(optimizer.target)
dict_output = serialization.to_state_dict(optimizer.target)
print('Dict output')
print(dict_output)
print('Bytes output')
print(bytes_output)
Dict output
{'params': {'bias': DeviceArray([-3.023082 ,  0.5307182,  3.7256303,  1.4638226, -3.2100437],            dtype=float32), 'kernel': DeviceArray([[-1.4092493e-02,  4.8609809e-03,  1.1460093e-02,
              -6.0927689e-02,  2.0413438e-05],
             [-3.3569761e-02, -1.5614161e-03,  4.3190460e-04,
              -7.9035060e-03, -1.9851506e-02],
             [-1.8882388e-02, -2.1366426e-03, -1.8663550e-02,
              -3.0001188e-02,  5.1880259e-02],
             [-4.8119370e-02, -2.9280247e-02, -1.1992223e-02,
              -1.0111435e-02, -8.3459895e-03],
             [-1.7368369e-02, -1.7084973e-02,  6.0279824e-02,
               9.2046618e-02, -1.5414236e-02],
             [-3.0089449e-02, -5.5370983e-03, -9.1237156e-03,
               2.1827107e-02, -2.0405082e-02],
             [-5.6748122e-02, -4.2654604e-02, -1.1436724e-02,
               7.5801805e-02, -2.0075133e-02],
             [-1.4368590e-03, -1.6048675e-02,  1.5781123e-02,
               2.8437756e-03, -8.5009886e-03],
             [ 1.7892396e-02,  5.7572998e-02,  4.1483097e-02,
              -9.9685444e-03, -2.1875760e-02],
             [-2.1158390e-02, -1.3853005e-02,  2.5077526e-02,
               3.2925244e-02,  3.8115401e-02]], dtype=float32)}}
Bytes output
b'\x81\xa6params\x82\xa4bias\xc7!\x01\x93\x91\x05\xa7float32\xc4\x14-zA\xc0&\xdd\x07?\xbapn@\x8a^\xbb?[qM\xc0\xa6kernel\xc7\xd6\x01\x93\x92\n\x05\xa7float32\xc4\xc83\xe4f\xbc\xddH\x9f;\x1d\xc3;<P\x8fy\xbd\x86=\xab7r\x80\t\xbdn\xa8\xcc\xbaAq\xe29\xb5}\x01\xbc\xa0\x9f\xa2\xbc=\xaf\x9a\xbc\xea\x06\x0c\xbbM\xe4\x98\xbc\r\xc5\xf5\xbce\x80T=\xd1\x18E\xbd!\xdd\xef\xbc\x07{D\xbco\xaa%\xbc\x9e\xbd\x08\xbc\x1cH\x8e\xbc\xc9\xf5\x8b\xbc\xfa\xe7v=\xf0\x82\xbc=\xfe\x8b|\xbc&~\xf6\xbc\x8cp\xb5\xbb\xa3{\x15\xbc\xc3\xce\xb2<\x8f(\xa7\xbc\xb8ph\xbd\x98\xb6.\xbd\x19a;\xbc\xfa=\x9b=\x9bt\xa4\xbc\xfdT\xbc\xba\x83x\x83\xbcjG\x81<\xa3^:;\xbbG\x0b\xbc\x13\x93\x92<\xaa\xd1k=.\xea)=\x1bS#\xbc\xcb4\xb3\xbc\\T\xad\xbc\xb7\xf7b\xbcbo\xcd<\x9f\xdc\x06=\xe5\x1e\x1c='

To load the model back, you’ll need a template of the model parameter structure, like the one you would get from model initialization. Here, we use the previously generated params as a template. Note that this produces a new variable structure and does not mutate in place.

The point of enforcing structure through a template is to avoid issues for users downstream, so you first need the right model to generate the parameter structure.

[13]:
serialization.from_bytes(params, bytes_output)
[13]:
FrozenDict({'params': {'bias': array([-3.023082 ,  0.5307182,  3.7256303,  1.4638226, -3.2100437],
      dtype=float32), 'kernel': array([[-1.4092493e-02,  4.8609809e-03,  1.1460093e-02, -6.0927689e-02,
         2.0413438e-05],
       [-3.3569761e-02, -1.5614161e-03,  4.3190460e-04, -7.9035060e-03,
        -1.9851506e-02],
       [-1.8882388e-02, -2.1366426e-03, -1.8663550e-02, -3.0001188e-02,
         5.1880259e-02],
       [-4.8119370e-02, -2.9280247e-02, -1.1992223e-02, -1.0111435e-02,
        -8.3459895e-03],
       [-1.7368369e-02, -1.7084973e-02,  6.0279824e-02,  9.2046618e-02,
        -1.5414236e-02],
       [-3.0089449e-02, -5.5370983e-03, -9.1237156e-03,  2.1827107e-02,
        -2.0405082e-02],
       [-5.6748122e-02, -4.2654604e-02, -1.1436724e-02,  7.5801805e-02,
        -2.0075133e-02],
       [-1.4368590e-03, -1.6048675e-02,  1.5781123e-02,  2.8437756e-03,
        -8.5009886e-03],
       [ 1.7892396e-02,  5.7572998e-02,  4.1483097e-02, -9.9685444e-03,
        -2.1875760e-02],
       [-2.1158390e-02, -1.3853005e-02,  2.5077526e-02,  3.2925244e-02,
         3.8115401e-02]], dtype=float32)}})

The serialization utilities provided by Flax work on objects beyond parameters; for example, you might want to serialize the optimizer and its state, which we show in the following cell:

[14]:
serialization.to_state_dict(optimizer)
[14]:
{'target': {'params': {'bias': DeviceArray([-3.023082 ,  0.5307182,  3.7256303,  1.4638226, -3.2100437],            dtype=float32),
   'kernel': DeviceArray([[-1.4092493e-02,  4.8609809e-03,  1.1460093e-02,
                 -6.0927689e-02,  2.0413438e-05],
                [-3.3569761e-02, -1.5614161e-03,  4.3190460e-04,
                 -7.9035060e-03, -1.9851506e-02],
                [-1.8882388e-02, -2.1366426e-03, -1.8663550e-02,
                 -3.0001188e-02,  5.1880259e-02],
                [-4.8119370e-02, -2.9280247e-02, -1.1992223e-02,
                 -1.0111435e-02, -8.3459895e-03],
                [-1.7368369e-02, -1.7084973e-02,  6.0279824e-02,
                  9.2046618e-02, -1.5414236e-02],
                [-3.0089449e-02, -5.5370983e-03, -9.1237156e-03,
                  2.1827107e-02, -2.0405082e-02],
                [-5.6748122e-02, -4.2654604e-02, -1.1436724e-02,
                  7.5801805e-02, -2.0075133e-02],
                [-1.4368590e-03, -1.6048675e-02,  1.5781123e-02,
                  2.8437756e-03, -8.5009886e-03],
                [ 1.7892396e-02,  5.7572998e-02,  4.1483097e-02,
                 -9.9685444e-03, -2.1875760e-02],
                [-2.1158390e-02, -1.3853005e-02,  2.5077526e-02,
                  3.2925244e-02,  3.8115401e-02]], dtype=float32)}},
 'state': {'step': DeviceArray(101, dtype=int32),
  'param_states': {'params': {'bias': {}, 'kernel': {}}}}}

Defining your own models

Flax allows you to define your own models, which can be much more complicated than a linear regression. In this section, we’ll show you how to build simple models. To do so, you’ll need to create subclasses of the base nn.Module class.

Keep in mind that we imported linen as nn; this only works with the new linen API.

Module basics

The base abstraction for models is the nn.Module class, and all predefined layers in Flax (like the previous Dense) are subclasses of nn.Module. Let’s start by defining a simple but custom multi-layer perceptron, i.e., a sequence of Dense layers interleaved with calls to a non-linear activation function.

[15]:
class ExplicitMLP(nn.Module):
  features: Sequence[int]

  def setup(self):
    # we automatically know what to do with lists, dicts of submodules
    self.layers = [nn.Dense(feat) for feat in self.features]
    # for single submodules, we would just write:
    # self.layer1 = nn.Dense(feat1)

  def __call__(self, inputs):
    x = inputs
    for i, lyr in enumerate(self.layers):
      x = lyr(x)
      if i != len(self.layers) - 1:
        x = nn.relu(x)
    return x

key1, key2 = random.split(random.PRNGKey(0), 2)
x = random.uniform(key1, (4,4))

model = ExplicitMLP(features=[3,4,5])
params = model.init(key2, x)
y = model.apply(params, x)

print('initialized parameter shapes:\n', jax.tree_map(jnp.shape, unfreeze(params)))
print('output:\n', y)
initialized parameter shapes:
 {'params': {'layers_0': {'bias': (3,), 'kernel': (4, 3)}, 'layers_1': {'bias': (4,), 'kernel': (3, 4)}, 'layers_2': {'bias': (5,), 'kernel': (4, 5)}}}
output:
 [[ 4.2292822e-02 -4.3807123e-02  2.9323796e-02  6.5492545e-03
  -1.7147183e-02]
 [ 1.2967804e-01 -1.4551792e-01  9.4432175e-02  1.2521386e-02
  -4.5417294e-02]
 [ 0.0000000e+00  0.0000000e+00  0.0000000e+00  0.0000000e+00
   0.0000000e+00]
 [ 9.3024032e-04  2.7864395e-05  2.4478821e-04  8.1344310e-04
  -1.0110770e-03]]

As we can see, a nn.Module subclass is made of:

  • A collection of data fields (nn.Modules are Python dataclasses) - here we only have the features field of type Sequence[int].

  • A setup() method that is called at the end of __post_init__, where you can register submodules, variables, and parameters you will need in your model.

  • A __call__ function that returns the output of the model from a given input.

  • The model structure defines a pytree of parameters following the same tree structure as the model: the params tree contains one layers_n sub-dict per layer, and each of those contains the parameters of the associated Dense layer. The layout is very explicit.

Note: lists are mostly managed as you would expect (WIP); there are corner cases you should be aware of, as pointed out here.

Since the module structure and its parameters are not tied to each other, you can’t directly call model(x) on a given input; doing so raises an error. The __call__ method is wrapped by apply, which is the method to call on an input:

[16]:
try:
    y = model(x) # Returns an error
except ValueError as e:
    print(e)
Can't call methods on orphaned modules

Since here we have a very simple model, we could have used an alternative (but equivalent) way of declaring the submodules inline in the __call__ using the @nn.compact annotation like so:

[17]:
class SimpleMLP(nn.Module):
  features: Sequence[int]

  @nn.compact
  def __call__(self, inputs):
    x = inputs
    for i, feat in enumerate(self.features):
      x = nn.Dense(feat, name=f'layers_{i}')(x)
      if i != len(self.features) - 1:
        x = nn.relu(x)
      # providing a name is optional though!
      # the default autonames would be "Dense_0", "Dense_1", ...
    return x

key1, key2 = random.split(random.PRNGKey(0), 2)
x = random.uniform(key1, (4,4))

model = SimpleMLP(features=[3,4,5])
params = model.init(key2, x)
y = model.apply(params, x)

print('initialized parameter shapes:\n', jax.tree_map(jnp.shape, unfreeze(params)))
print('output:\n', y)
initialized parameter shapes:
 {'params': {'layers_0': {'bias': (3,), 'kernel': (4, 3)}, 'layers_1': {'bias': (4,), 'kernel': (3, 4)}, 'layers_2': {'bias': (5,), 'kernel': (4, 5)}}}
output:
 [[ 4.2292822e-02 -4.3807123e-02  2.9323796e-02  6.5492545e-03
  -1.7147183e-02]
 [ 1.2967804e-01 -1.4551792e-01  9.4432175e-02  1.2521386e-02
  -4.5417294e-02]
 [ 0.0000000e+00  0.0000000e+00  0.0000000e+00  0.0000000e+00
   0.0000000e+00]
 [ 9.3024032e-04  2.7864395e-05  2.4478821e-04  8.1344310e-04
  -1.0110770e-03]]

There are, however, a few differences you should be aware of between the two declaration modes:

  • In setup, you are able to name some sublayers and keep them around for further use (e.g. encoder/decoder methods in autoencoders).

  • If you want to have multiple methods, then you need to declare the module using setup, as the @nn.compact annotation only allows one method to be annotated.

  • The last initialization will be handled differently; see these notes for more details (TODO: add notes link).

Module parameters

In the previous MLP example, we relied only on predefined layers and operators (Dense, relu). Let’s imagine that you didn’t have a Dense layer provided by Flax and wanted to write it on your own. Here is what it would look like, using the @nn.compact way to declare a new module:

[18]:
class SimpleDense(nn.Module):
  features: int
  kernel_init: Callable = nn.initializers.lecun_normal()
  bias_init: Callable = nn.initializers.zeros

  @nn.compact
  def __call__(self, inputs):
    kernel = self.param('kernel',
                        self.kernel_init, # Initialization function
                        (inputs.shape[-1], self.features))  # shape info.
    y = lax.dot_general(inputs, kernel,
                        (((inputs.ndim - 1,), (0,)), ((), ())),) # TODO Why not jnp.dot?
    bias = self.param('bias', self.bias_init, (self.features,))
    y = y + bias
    return y

key1, key2 = random.split(random.PRNGKey(0), 2)
x = random.uniform(key1, (4,4))

model = SimpleDense(features=3)
params = model.init(key2, x)
y = model.apply(params, x)

print('initialized parameters:\n', params)
print('output:\n', y)
initialized parameters:
 FrozenDict({'params': {'kernel': DeviceArray([[ 0.6503669 ,  0.8678979 ,  0.46042678],
             [ 0.05673932,  0.9909285 , -0.63536596],
             [ 0.76134115, -0.3250529 , -0.6522163 ],
             [-0.8243032 ,  0.4150194 ,  0.19405058]], dtype=float32), 'bias': DeviceArray([0., 0., 0.], dtype=float32)}})
output:
 [[ 0.50355184  1.8548559  -0.4270196 ]
 [ 0.02790972  0.5589246  -0.43061778]
 [ 0.35471287  1.5740999  -0.32865524]
 [ 0.52648634  1.2928859   0.10089307]]

Here, we see how to both declare and assign a parameter to the model using the self.param method. It takes as input (name, init_fn, *init_args):

  • name is simply the name of the parameter that will end up in the parameter structure.

  • init_fn is a function with signature (PRNGKey, *init_args) that returns an Array;

  • init_args are the arguments to provide to the initialization function.

Such params can also be declared in the setup method; there, shape inference is not available, because Flax uses lazy initialization at the first call site.

Variables and collections of variables

As we’ve seen so far, working with models means working with:

  • A subclass of nn.Module;

  • A pytree of parameters for the model (typically from model.init());

However, this is not enough to cover everything that we would need for machine learning, especially neural networks. In some cases, you might want your neural network to keep track of some internal state while it runs (e.g. batch normalization layers). There is a way to declare variables beyond the parameters of the model with the variable method.

For demonstration purposes, we’ll implement a simplified but similar mechanism to batch normalization: we’ll store running averages and subtract them from the input at training time. For proper batchnorm, you should use (and look at) the implementation here.

[19]:
class BiasAdderWithRunningMean(nn.Module):
  decay: float = 0.99

  @nn.compact
  def __call__(self, x):
    # easy pattern to detect if we're initializing via empty variable tree
    is_initialized = self.has_variable('batch_stats', 'mean')
    ra_mean = self.variable('batch_stats', 'mean',
                            lambda s: jnp.zeros(s),
                            x.shape[1:])
    mean = ra_mean.value # This will get either the value, or trigger init
    bias = self.param('bias', lambda rng, shape: jnp.zeros(shape), x.shape[1:])
    if is_initialized:
      ra_mean.value = self.decay * ra_mean.value + (1.0 - self.decay) * jnp.mean(x, axis=0, keepdims=True)

    return x - ra_mean.value + bias


key1, key2 = random.split(random.PRNGKey(0), 2)
x = jnp.ones((10,5))
model = BiasAdderWithRunningMean()
variables = model.init(key1, x)
print('initialized variables:\n', variables)
y, updated_state = model.apply(variables, x, mutable=['batch_stats'])
print('updated state:\n', updated_state)
initialized variables:
 FrozenDict({'batch_stats': {'mean': DeviceArray([0., 0., 0., 0., 0.], dtype=float32)}, 'params': {'bias': DeviceArray([0., 0., 0., 0., 0.], dtype=float32)}})
updated state:
 FrozenDict({'batch_stats': {'mean': DeviceArray([[0.01, 0.01, 0.01, 0.01, 0.01]], dtype=float32)}})

Here, updated_state returns only the state variables that are being mutated by the model while applying it on data. To update the variables and get the new parameters of the model, we can use the following pattern:

[20]:
for val in [1.0, 2.0, 3.0]:
  x = val * jnp.ones((10,5))
  y, updated_state = model.apply(variables, x, mutable=['batch_stats'])
  old_state, params = variables.pop('params')
  variables = freeze({'params': params, **updated_state})
  print('updated state:\n', updated_state) # Shows only the mutable part
updated state:
 FrozenDict({'batch_stats': {'mean': DeviceArray([[0.01, 0.01, 0.01, 0.01, 0.01]], dtype=float32)}})
updated state:
 FrozenDict({'batch_stats': {'mean': DeviceArray([[0.0299, 0.0299, 0.0299, 0.0299, 0.0299]], dtype=float32)}})
updated state:
 FrozenDict({'batch_stats': {'mean': DeviceArray([[0.059601, 0.059601, 0.059601, 0.059601, 0.059601]], dtype=float32)}})

From this simplified example, you should be able to derive a full BatchNorm implementation, or any layer involving a state. To finish, let’s add an optimizer to see how to play with both parameters updated by an optimizer and state variables.

This example isn’t doing anything and is only for demonstration purposes.

[21]:
def update_step(apply_fun, x, optimizer, state):

  def loss(params):
    y, updated_state = apply_fun({'params': params, **state},
                                 x, mutable=list(state.keys()))
    l = ((x - y) ** 2).sum()
    return l, updated_state

  (l, updated_state), grads = jax.value_and_grad(
      loss, has_aux=True)(optimizer.target)
  optimizer = optimizer.apply_gradient(grads)
  return optimizer, updated_state

variables = model.init(random.PRNGKey(0), x)
state, params = variables.pop('params')
del variables
optimizer = optim.sgd.GradientDescent(learning_rate=0.02).create(params)
x = jnp.ones((10,5))

for _ in range(3):
  optimizer, state = update_step(model.apply, x, optimizer, state)
  print('Updated state: ', state)
Updated state:  FrozenDict({'batch_stats': {'mean': DeviceArray([[0.01, 0.01, 0.01, 0.01, 0.01]], dtype=float32)}})
Updated state:  FrozenDict({'batch_stats': {'mean': DeviceArray([[0.0199, 0.0199, 0.0199, 0.0199, 0.0199]], dtype=float32)}})
Updated state:  FrozenDict({'batch_stats': {'mean': DeviceArray([[0.029701, 0.029701, 0.029701, 0.029701, 0.029701]], dtype=float32)}})

Exporting to TensorFlow’s SavedModel with jax2tf

JAX released an experimental converter called jax2tf, which allows converting trained Flax models into TensorFlow’s SavedModel format (so they can be used for TF Hub, TF.lite, TF.js, or other downstream applications). The repository contains more documentation and various examples for Flax.