
JAX for the Impatient

JAX is NumPy on the CPU, GPU, and TPU, with great automatic differentiation for high-performance machine learning research.

Here we cover the basics of JAX so that you can get started with Flax. However, we strongly recommend that you go through JAX’s documentation after going over the basics here.


Let’s start by exploring the NumPy API coming from JAX and the main differences you should be aware of.

import jax
from jax import numpy as jnp, random

import numpy as np # We import the standard NumPy library

jax.numpy is the NumPy-like API that needs to be imported, and we will also use jax.random to generate some data to work on.

Let’s start by generating some matrices, and then try matrix multiplication.

m = jnp.ones((4,4)) # We're generating one 4 by 4 matrix filled with ones.
n = jnp.array([[1.0, 2.0, 3.0, 4.0],
               [5.0, 6.0, 7.0, 8.0]]) # An explicit 2 by 4 array
m
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
DeviceArray([[1., 1., 1., 1.],
             [1., 1., 1., 1.],
             [1., 1., 1., 1.],
             [1., 1., 1., 1.]], dtype=float32)

Arrays in JAX are represented as DeviceArray instances and are agnostic to the place where the array lives (CPU, GPU, or TPU). This is why we’re getting the warning that no GPU/TPU was found and JAX is falling back to a CPU (unless you’re running it in an environment that has a GPU/TPU available).

We can multiply matrices just as we would in NumPy.

jnp.dot(n, m).block_until_ready() # Note: yields the same result as np.dot(n, m)
DeviceArray([[10., 10., 10., 10.],
             [26., 26., 26., 26.]], dtype=float32)

DeviceArray instances are actually futures, due to JAX’s default asynchronous execution (see the asynchronous dispatch section of JAX’s documentation). For that reason, the Python call might return before the computation actually ends, hence we use the block_until_ready() method to ensure we return the final result.
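A minimal sketch of timing a JAX operation under asynchronous dispatch: the timer is stopped only after block_until_ready() returns, so it measures the computation rather than just the dispatch. The matrix size and variable names are illustrative assumptions, and the first call may also include one-time compilation overhead.

```python
import time

import jax.numpy as jnp

a = jnp.ones((1000, 1000))

start = time.perf_counter()
# jnp.dot may return as soon as the work is dispatched; block_until_ready()
# waits for the device to finish and returns the same array.
result = jnp.dot(a, a).block_until_ready()
elapsed = time.perf_counter() - start

print(f"result shape: {result.shape}, elapsed: {elapsed:.4f} s")
```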

JAX is highly compatible with NumPy and can transparently process arrays from one library in the other.

x = np.random.normal(size=(4,4)) # Creating one standard NumPy array instance
jnp.dot(x,m)
DeviceArray([[-0.8318497, -0.8318497, -0.8318497, -0.8318497],
             [ 2.4768949,  2.4768949,  2.4768949,  2.4768949],
             [-1.0424521, -1.0424521, -1.0424521, -1.0424521],
             [-3.4560933, -3.4560933, -3.4560933, -3.4560933]],            dtype=float32)

If you’re using accelerators, using NumPy arrays directly will result in multiple transfers from CPU to GPU/TPU memory. You can save that transfer bandwidth, either by creating a DeviceArray directly or by using jax.device_put on the NumPy array. With DeviceArrays, computation is done on device, so no additional data transfer is required: e.g., jnp.dot(long_vector, long_vector) will only transfer a single scalar (the result of the computation) back from device to host.
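A minimal sketch of the long-vector case described above, assuming a vector of one million float32 ones (chosen so the expected result is exact): the vector is placed on the default device once, the dot product runs there, and only the scalar result is copied back when it is converted to a Python float.

```python
import numpy as np
import jax
import jax.numpy as jnp

# Transfer a NumPy vector to the default device once (the CPU if no GPU/TPU is found).
long_vector = jax.device_put(np.ones(1_000_000, dtype=np.float32))

# The dot product is computed on device and yields a 0-d array.
dot = jnp.dot(long_vector, long_vector)
print(dot.shape)  # ()

# Converting to a Python float copies only this single scalar back to the host.
value = float(dot)
```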

x = np.random.normal(size=(4,4))
x = jax.device_put(x)
x
DeviceArray([[ 0.08149499,  0.07987174,  1.1451471 , -0.59535813],
             [ 0.86550283,  0.6078417 ,  0.7539637 ,  1.5923587 ],
             [ 0.8374219 , -0.07827665,  1.4592382 ,  1.4161737 ],
             [ 0.37525675, -0.8032943 ,  2.062778  , -0.15352985]],            dtype=float32)

Conversely, if you want to get back a NumPy array from a JAX array, you can simply do so by using it in the NumPy API.

x = jnp.array([[1.0, 2.0, 3.0, 4.0],
               [5.0, 6.0, 7.0, 8.0]])
np.array(x)
array([[1., 2., 3., 4.],
       [5., 6., 7., 8.]], dtype=float32)
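A short sketch of what the conversion gives you, assuming np.array is used (it copies the data): the result is an ordinary, writable NumPy array, so editing it leaves the original JAX array untouched.

```python
import numpy as np
import jax.numpy as jnp

x = jnp.array([[1.0, 2.0, 3.0, 4.0],
               [5.0, 6.0, 7.0, 8.0]])

# np.array copies the data into a regular, writable NumPy array on the host.
x_np = np.array(x)
x_np[0, 0] = 10.0  # allowed: NumPy arrays are mutable

print(type(x_np).__name__)  # ndarray
print(x[0, 0])              # the JAX array is unaffected by the NumPy edit
```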


JAX is functional by design; one practical consequence is that JAX arrays are immutable. This means no in-place operations and no sliced assignments. More generally, functions should not read their inputs from, or write their outputs to, global state.

x = jnp.array([[1.0, 2.0, 3.0, 4.0],
               [5.0, 6.0, 7.0, 8.0]])
updated = jax.ops.index_update(x,(0,0), 3.0) # whereas x[0,0] = 3.0 would fail
print("x: \n", x) # Note that x didn't change, no in-place mutation.
print("updated: \n", updated)
x:
 [[1. 2. 3. 4.]
 [5. 6. 7. 8.]]
updated:
 [[3. 2. 3. 4.]
 [5. 6. 7. 8.]]
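To see the immutability for yourself, here is a small sketch that attempts the in-place assignment and catches the resulting TypeError, then performs the same update functionally with the .at property discussed in this section (available in current JAX versions).

```python
import jax.numpy as jnp

x = jnp.array([[1.0, 2.0, 3.0, 4.0],
               [5.0, 6.0, 7.0, 8.0]])

try:
    x[0, 0] = 3.0  # in-place assignment: JAX arrays do not support this
    assignment_failed = False
except TypeError as err:
    assignment_failed = True
    print("TypeError:", err)

# Functional alternative: build a new array; x itself is left unchanged.
updated = x.at[0, 0].set(3.0)
```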

Index operators can be found in jax.ops and follow the index_* pattern. To create an index with that syntax, you can use the jax.ops.index syntactic sugar.

x = jnp.array([[1.0, 2.0, 3.0, 4.0],
               [5.0, 6.0, 7.0, 8.0]])
jax.ops.index_update(x,jax.ops.index[0,:], 3.0) # Same as x[0,:] = 3.0 in NumPy.
DeviceArray([[3., 3., 3., 3.],
             [5., 6., 7., 8.]], dtype=float32)

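Finally, a more concise (and modern) way is to use the .at property, which plays the same syntactic-sugar role. The jax.ops.index_* functions shown above have since been deprecated and removed from recent JAX releases, so .at is the way to go with current versions: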

x.at[0,:].set(3.0) # Note: this returns a new array and doesn't mutate in place.
DeviceArray([[3., 3., 3., 3.],
             [5., 6., 7., 8.]], dtype=float32)

The .at property supports several index update operations, including set, add, mul, min, and max.
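A brief sketch of a few of these update operations on the same 2 by 4 array; each call returns a new array and leaves x unchanged.

```python
import jax.numpy as jnp

x = jnp.array([[1.0, 2.0, 3.0, 4.0],
               [5.0, 6.0, 7.0, 8.0]])

added = x.at[0, 0].add(10.0)    # adds 10 to x[0, 0]
lowered = x.at[:, 0].min(3.0)   # x[:, 0] becomes minimum(x[:, 0], 3.0)
lifted = x.at[:, 0].max(3.0)    # x[:, 0] becomes maximum(x[:, 0], 3.0)
```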

Managing randomness

In JAX, randomness is managed in a very specific way, and you can read more about it in JAX’s documentation on pseudorandom numbers (we borrow content from there!). As the JAX team puts it:

JAX implements an explicit PRNG where entropy production and consumption are handled by explicitly passing and iterating a PRNG state. JAX uses a modern Threefry counter-based PRNG that’s splittable. That is, its design allows us to fork the PRNG state into new PRNGs for use with parallel stochastic generation.

In short, you need to explicitly manage the PRNGs (pseudorandom number generators) and their states. In JAX’s PRNGs, the state is represented as a pair of unsigned 32-bit integers called a key (there is no special meaning to the two uint32 values; it’s just a way of representing a uint64).

key = random.PRNGKey(0)
key
DeviceArray([0, 0], dtype=uint32)

If you use this key multiple times, you’ll get the same “random” output each time. To generate further entries in the sequence, you’ll need to split the PRNG and thus generate a new pair of keys.

for i in range(3):
    print("Printing the random number using key: ", key, " gives: ", random.normal(key,shape=(1,))) # Boringly not that random since we use the same key
Printing the random number using key:  [0 0]  gives:  [-0.20584235]
Printing the random number using key:  [0 0]  gives:  [-0.20584235]
Printing the random number using key:  [0 0]  gives:  [-0.20584235]
print("old key", key, "--> normal", random.normal(key, shape=(1,)))
key, subkey = random.split(key)
print("    \---SPLIT --> new key   ", key, "--> normal", random.normal(key, shape=(1,)) )
print("             \--> new subkey", subkey, "--> normal", random.normal(subkey, shape=(1,)) )
old key [0 0] --> normal [-0.20584235]
    \---SPLIT --> new key    [4146024105  967050713] --> normal [0.14389044]
             \--> new subkey [2718843009 1272950319] --> normal [-1.2515389]

You can also generate multiple subkeys at once if needed:

key, *subkeys = random.split(key, 4)
key, subkeys
(array([3306097435, 3899823266], dtype=uint32),
 [array([147607341, 367236428], dtype=uint32),
  array([2280136339, 1907318301], dtype=uint32),
  array([ 781391491, 1939998335], dtype=uint32)])

You can think about those PRNGs as trees of keys that match the structure of your models, which is important for reproducibility and soundness of the random behavior that you expect.
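As an illustration of matching keys to model structure, here is a sketch assuming a hypothetical two-layer parameter layout and a hypothetical helper init_params: one subkey is split off the root key per named parameter, so the same root key reproduces the same parameters and a different root key gives different ones.

```python
from jax import random

def init_params(root_key, shapes):
    """Hypothetical initializer: one subkey per named parameter, split from a root key."""
    keys = random.split(root_key, len(shapes))
    return {name: random.normal(k, shape)
            for (name, shape), k in zip(shapes.items(), keys)}

# Hypothetical parameter layout for a two-layer model.
shapes = {"dense1": (3, 4), "dense2": (4, 2)}

params = init_params(random.PRNGKey(0), shapes)
params_again = init_params(random.PRNGKey(0), shapes)  # same root key
params_other = init_params(random.PRNGKey(1), shapes)  # different root key
```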

Gradients and autodiff

For a full overview of JAX’s automatic differentiation system, you can check the Autodiff Cookbook.

Even though, theoretically, a VJP (Vector-Jacobian product, reverse-mode autodiff) and a JVP (Jacobian-Vector product, forward-mode autodiff) are similar (they both compute a product of a Jacobian and a vector), they differ in the computational complexity of the operation. In short, when you have a large number of parameters (hence a wide matrix), a JVP is less efficient computationally than a VJP; conversely, a JVP is more efficient when the Jacobian is a tall matrix. You can read more in the Autodiff Cookbook mentioned above.
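One way to see both modes side by side is through jax.jacfwd and jax.jacrev, which build the full Jacobian from JVPs and VJPs respectively. The sketch below uses an illustrative map from R^3 to R^2; both functions agree on the result, and the difference lies in how many products each needs (conceptually one per input dimension for jacfwd, one per output dimension for jacrev), which is why jacfwd tends to suit tall Jacobians and jacrev wide ones.

```python
import jax
import jax.numpy as jnp

def g(x):
    # Illustrative map from R^3 to R^2.
    return jnp.stack([x[0] * x[1], jnp.sin(x[2])])

x = jnp.array([1.0, 2.0, 3.0])

J_fwd = jax.jacfwd(g)(x)  # forward mode: conceptually one JVP per input dimension (3 here)
J_rev = jax.jacrev(g)(x)  # reverse mode: conceptually one VJP per output dimension (2 here)
print(J_fwd.shape)  # (2, 3)
```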


JAX provides first-class support for gradients and automatic differentiation in functions. This is also where the functional paradigm shines, since gradients on functions are essentially stateless operations. If we consider a simple function \(f:\mathbb{R}^n\rightarrow\mathbb{R}\)

\[f(x) = \frac{1}{2} x^T x\]

with the (known) gradient:

\[\nabla f(x) = x\]
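Expanding \(f\) around a point \(x\) shows where this gradient comes from: for a small perturbation \(h\), and using \(x^T h = h^T x\),

\[f(x+h) = \frac{1}{2}(x+h)^T(x+h) = f(x) + x^T h + \frac{1}{2} h^T h = f(x) + x^T h + o(h)\]

so \(df(x)\bullet h = x^T h = \langle x, h\rangle\), which means \(\nabla f(x) = x\).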
key = random.PRNGKey(0)
def f(x):
  return jnp.dot(x.T,x)/2.0

v = jnp.ones((4,))
f(v)
DeviceArray(2., dtype=float32)

JAX computes the gradient as an operator acting on functions with jax.grad. Note that this only works for scalar-valued functions.
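A quick sketch of that restriction, using an illustrative vector-valued function: jax.grad raises a TypeError on a non-scalar output, while reducing the output to a scalar first (here with a sum) makes it differentiable with jax.grad.

```python
import jax
import jax.numpy as jnp

def vector_valued(x):
    return x * 2.0  # output has the same shape as x, so it is not a scalar

x = jnp.ones((3,))
try:
    jax.grad(vector_valued)(x)
    raised = False
except TypeError as err:
    raised = True
    print("TypeError:", err)

# Reducing to a scalar first makes jax.grad applicable.
grad_of_sum = jax.grad(lambda x: jnp.sum(vector_valued(x)))(x)
```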

Let’s take the gradient of f and make sure it matches the identity map.

v = random.normal(key,(4,))
print("Original v:")
print(v)
print("Gradient of f taken at point v")
print(jax.grad(f)(v)) # should be equal to v !
Original v:
[ 1.8160859  -0.7548852   0.33988902 -0.5348355 ]
Gradient of f taken at point v
[ 1.8160859  -0.7548852   0.33988902 -0.5348355 ]

As previously mentioned, jax.grad only works for scalar-valued functions. JAX can also handle general vector-valued functions. The most useful primitives are the Jacobian-Vector product (jax.jvp) and the Vector-Jacobian product (jax.vjp).

Jacobian-Vector product

Let’s consider a map \(f:\mathbb{R}^n\rightarrow\mathbb{R}^m\). As a reminder, the differential of f is the map \(df:\mathbb{R}^n \rightarrow \mathcal{L}(\mathbb{R}^n,\mathbb{R}^m)\) where \(\mathcal{L}(\mathbb{R}^n,\mathbb{R}^m)\) is the space of linear maps from \(\mathbb{R}^n\) to \(\mathbb{R}^m\) (hence \(df(x)\) is often represented as a Jacobian matrix). The linear approximation of f at point \(x\) reads:

\[f(x+v) = f(x) + df(x)\bullet v + o(v)\]

The \(\bullet\) operator means you are applying the linear map \(df(x)\) to the vector v.

Even though you are rarely interested in computing the full Jacobian matrix representing the linear map \(df(x)\) in a standard basis, you are often interested in the quantity \(df(x)\bullet v\). This is exactly what jax.jvp is for, and jax.jvp(f, (x,), (v,)) returns the tuple:

\[(f(x), df(x)\bullet v)\]

Let’s use a simple function as an example: \(f(x) = \frac{1}{2}({x_1}^2, {x_2}^2, \ldots, {x_n}^2)\) where we know that \(df(x)\bullet h = (x_1h_1, x_2h_2,\ldots,x_nh_n)\). Hence using jax.jvp with \(h= (1,1,\ldots,1)\) should return \(x\) as the second element of its output tuple.

def f(x):
  return jnp.multiply(x,x)/2.0

x = random.normal(key, (5,))
v = jnp.ones(5)
print("(x,f(x))")
print((x,f(x)))
print("jax.jvp(f, (x,),(v,))")
print(jax.jvp(f, (x,),(v,)))
(x,f(x))
(DeviceArray([ 0.18784378, -1.2833427 , -0.27109176,  1.2490592 ,
              0.24446994], dtype=float32), DeviceArray([0.01764264, 0.82348424, 0.03674537, 0.7800744 , 0.02988278],            dtype=float32))
jax.jvp(f, (x,),(v,))
(DeviceArray([0.01764264, 0.82348424, 0.03674537, 0.7800744 , 0.02988278],            dtype=float32), DeviceArray([ 0.18784378, -1.2833427 , -0.27109176,  1.2490592 ,
              0.24446994], dtype=float32))
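To connect this with the Jacobian picture, the sketch below recomputes the same JVP and compares the tangent output with an explicit Jacobian-times-vector product. Materializing the Jacobian with jax.jacfwd is done here for comparison only; jax.jvp does not need to build it.

```python
import jax
import jax.numpy as jnp
from jax import random

def f(x):
    return jnp.multiply(x, x) / 2.0

x = random.normal(random.PRNGKey(0), (5,))
v = jnp.ones(5)

primal_out, tangent_out = jax.jvp(f, (x,), (v,))

# For comparison only: for this f, the Jacobian is the diagonal matrix diag(x).
J = jax.jacfwd(f)(x)
```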

Vector-Jacobian product

Keeping our \(f:\mathbb{R}^n\rightarrow\mathbb{R}^m\), it’s often the case (for example, when you are working with a scalar loss function) that you are interested in the composition \(x\rightarrow\phi\circ f(x)\) where \(\phi :\mathbb{R}^m\rightarrow\mathbb{R}\). In that case, the gradient reads:

\[\nabla(\phi\circ f)(x) = J_f(x)^T\nabla\phi(f(x))\]

Where \(J_f(x)\) is the Jacobian matrix of f evaluated at x, meaning that \(df(x)\bullet v = J_f(x)v\).

jax.vjp(f,x) returns the tuple:

\[(f(x),v\rightarrow v^TJ_f(x))\]

Keeping the same example as previously, using \(v=(1,\ldots,1)\), applying the VJP function returned by JAX should return the \(x\) value:

(val, vjp_fun) = jax.vjp(f,x)
print("x = ", x)
print("v^T Jf(x) = ", vjp_fun(jnp.ones((5,)))[0])
x =  [ 0.18784378 -1.2833427  -0.27109176  1.2490592   0.24446994]
v^T Jf(x) =  [ 0.18784378 -1.2833427  -0.27109176  1.2490592   0.24446994]
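The gradient formula for \(\phi\circ f\) can be checked directly. In this sketch \(\phi\) is assumed to be a plain sum (so \(\nabla\phi\) is a vector of ones): feeding \(\nabla\phi(f(x))\) to the VJP function gives \(J_f(x)^T\nabla\phi(f(x))\), which should match jax.grad applied to the composition.

```python
import jax
import jax.numpy as jnp
from jax import random

def f(x):
    return jnp.multiply(x, x) / 2.0

def phi(y):
    # Illustrative scalar function on the outputs; its gradient is all ones.
    return jnp.sum(y)

x = random.normal(random.PRNGKey(0), (5,))
_, vjp_fun = jax.vjp(f, x)
(vjp_result,) = vjp_fun(jax.grad(phi)(f(x)))  # J_f(x)^T grad(phi)(f(x))
direct = jax.grad(lambda x: phi(f(x)))(x)     # gradient of the composition
```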

Accelerating code with jit & ops vectorization

We borrow the following example from the JAX quickstart.


JAX uses the XLA compiler under the hood and lets you just-in-time (jit) compile your code to make it faster and more efficient. This is the purpose of the @jit annotation.

def selu(x, alpha=1.67, lmbda=1.05):
  return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

v = random.normal(key, (1000000,))
%timeit selu(v).block_until_ready()
1.96 ms ± 86.5 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

Now let’s use jit (here called as the function jax.jit rather than as a decorator) to speed things up:

selu_jit = jax.jit(selu)
%timeit selu_jit(v).block_until_ready()
405 µs ± 32.3 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

jit compilation composes transparently with autodiff in your code.
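A minimal sketch of that composition, reusing \(f(x) = \frac{1}{2}x^Tx\) from the gradient section: you can jit the gradient function, or differentiate through an already jitted function, and both give \(\nabla f(x) = x\).

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.dot(x, x) / 2.0  # f(x) = x^T x / 2, whose gradient is x

grad_f_jit = jax.jit(jax.grad(f))      # compile the gradient function
grad_of_jit_f = jax.grad(jax.jit(f))   # differentiate through a jitted function

x = jnp.arange(4.0)
```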


Finally, JAX enables you to write code that applies to a single example, and then vectorize it to handle batch dimensions transparently.

mat = random.normal(key, (15, 10))
batched_x = random.normal(key, (5, 10)) # Batch size on axis 0
single = random.normal(key, (10,))

def apply_matrix(v):
  return jnp.dot(mat, v)

print("Single apply shape: ", apply_matrix(single).shape)
print("Batched example shape: ", jax.vmap(apply_matrix)(batched_x).shape)
Single apply shape:  (15,)
Batched example shape:  (5, 15)
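As a sanity check, the sketch below compares the vmapped function with a hand-written batched matrix product, and shows in_axes for inputs whose batch dimension sits on axis 1. Unlike the example above, it splits the key so that mat and batched_x are independent draws; that choice is an assumption for illustration.

```python
import jax
import jax.numpy as jnp
from jax import random

k_mat, k_x = random.split(random.PRNGKey(0))
mat = random.normal(k_mat, (15, 10))
batched_x = random.normal(k_x, (5, 10))  # batch size on axis 0

def apply_matrix(v):
    return jnp.dot(mat, v)

vmapped = jax.vmap(apply_matrix)(batched_x)  # shape (5, 15)
manual = jnp.dot(batched_x, mat.T)           # the same computation batched by hand
# in_axes=1: the batch dimension lives on axis 1 of the input instead of axis 0.
vmapped_cols = jax.vmap(apply_matrix, in_axes=1)(batched_x.T)
```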

Full example: linear regression

Let’s implement one of the simplest models using everything we have seen so far: a linear regression. From a set of data points \(\{(x_i,y_i), i\in \{1,\ldots, k\}, x_i\in\mathbb{R}^n,y_i\in\mathbb{R}^m\}\), we try to find a set of parameters \(W\in \mathcal{M}_{m,n}(\mathbb{R}), b\in\mathbb{R}^m\) such that the function \(f_{W,b}(x)=Wx+b\) minimizes the mean squared error:

\[\mathcal{L}(W,b) = \frac{1}{k}\sum_{i=1}^{k} \frac{1}{2}\|y_i-f_{W,b}(x_i)\|^2_2\]

(Note: depending on how you cast the regression problem, you might end up with different setups. Theoretically, we should be minimizing the expectation of the loss with respect to the data distribution; however, for the sake of simplicity, here we only consider the sampled loss.)
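For this loss, the gradients can also be derived by hand. Writing predictions in the row-vector form \(x_iW + b\) used by the code below and residuals \(r_i = y_i - x_iW - b\), we get \(\nabla_W\mathcal{L} = -\frac{1}{k}X^TR\) and \(\nabla_b\mathcal{L} = -\frac{1}{k}\sum_i r_i\), where \(X\) and \(R\) stack the \(x_i\) and \(r_i\) as rows. The sketch below checks this against jax.grad, using a vectorized variant of the mse function (written without vmap) and random data with the same dimensions; both are assumptions for illustration. Passing argnums=(0, 1) returns both gradients in a single call.

```python
import jax
import jax.numpy as jnp
from jax import random

def mse(W, b, x_batched, y_batched):
    # Mean over samples of half the squared residual norm (row-vector convention x W + b).
    residuals = y_batched - (jnp.dot(x_batched, W) + b)
    return jnp.mean(jnp.sum(residuals ** 2, axis=1) / 2.0)

k, n, m = 20, 10, 5
kx, ky, kw, kb = random.split(random.PRNGKey(0), 4)
X = random.normal(kx, (k, n))
Y = random.normal(ky, (k, m))
W = random.normal(kw, (n, m))
b = random.normal(kb, (m,))

# Closed-form gradients derived by hand.
R = Y - (X @ W + b)
grad_W_manual = -(X.T @ R) / k
grad_b_manual = -jnp.mean(R, axis=0)

# Autodiff gradients with respect to the first two arguments (W and b).
grad_W, grad_b = jax.grad(mse, argnums=(0, 1))(W, b, X, Y)
```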

# Linear feed-forward.
def predict(W, b, x):
  return jnp.dot(x, W) + b

# Loss function: Mean squared error.
def mse(W, b, x_batched, y_batched):
  # Define the squared loss for a single pair (x,y)
  def squared_error(x, y):
    y_pred = predict(W, b, x)
    return jnp.inner(y-y_pred, y-y_pred) / 2.0
  # We vectorize the previous to compute the average of the loss on all samples.
  return jnp.mean(jax.vmap(squared_error)(x_batched, y_batched), axis=0)
# Set problem dimensions.
n_samples = 20
x_dim = 10
y_dim = 5

# Generate random ground truth W and b.
key = random.PRNGKey(0)
k1, k2 = random.split(key)
W = random.normal(k1, (x_dim, y_dim))
b = random.normal(k2, (y_dim,))

# Generate samples with additional noise.
key_sample, key_noise = random.split(k1)
x_samples = random.normal(key_sample, (n_samples, x_dim))
y_samples = predict(W, b, x_samples) + 0.1 * random.normal(key_noise,(n_samples, y_dim))
print('x shape:', x_samples.shape, '; y shape:', y_samples.shape)
# Initialize estimated W and b with zeros.
W_hat = jnp.zeros_like(W)
b_hat = jnp.zeros_like(b)

# Ensure we jit the largest-possible jittable block.
@jax.jit
def update_params(W, b, x, y, lr):
  W, b = W - lr * jax.grad(mse, 0)(W, b, x, y), b - lr * jax.grad(mse, 1)(W, b, x, y)
  return W, b

learning_rate = 0.3  # Gradient step size.
print('Loss for "true" W,b: ', mse(W, b, x_samples, y_samples))
for i in range(101):
  # Perform one gradient update.
  W_hat, b_hat = update_params(W_hat, b_hat, x_samples, y_samples, learning_rate)
  if (i % 5 == 0):
    print(f"Loss step {i}: ", mse(W_hat, b_hat, x_samples, y_samples))

This is obviously an approximate solution to the linear regression problem (solving it would require a bit more work!), but here you have all the tools you would need if you wanted to do it the proper way.
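For comparison, here is a sketch of solving the least-squares problem directly rather than by gradient descent (a different technique from the loop above), using jnp.linalg.lstsq on inputs augmented with a column of ones so that the bias is estimated together with W. The data are regenerated here without noise, with the same dimensions as above, so the recovered parameters can be checked against the ground truth; this setup is an assumption for illustration.

```python
import jax.numpy as jnp
from jax import random

n_samples, x_dim, y_dim = 20, 10, 5
k_w, k_b, k_x = random.split(random.PRNGKey(0), 3)
W = random.normal(k_w, (x_dim, y_dim))
b = random.normal(k_b, (y_dim,))
x_samples = random.normal(k_x, (n_samples, x_dim))
y_samples = jnp.dot(x_samples, W) + b  # noise-free targets, so the exact solution is known

# Append a column of ones so that the last row of the solution plays the role of b.
x_aug = jnp.concatenate([x_samples, jnp.ones((n_samples, 1))], axis=1)
solution, residuals, rank, singular_values = jnp.linalg.lstsq(x_aug, y_samples, rcond=None)
W_exact, b_exact = solution[:-1], solution[-1]
```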

Refining a bit with pytrees

Here we’re going to elaborate on our previous example using JAX’s pytree data structure.

Pytrees basics

The JAX ecosystem uses pytrees everywhere, and so does Flax. For a complete overview, we suggest that you take a look at the pytree page in JAX’s documentation.

In JAX, a pytree is a container of leaf elements and/or more pytrees. Containers include lists, tuples, and dicts (JAX can be extended to consider other container types as pytrees; see the Extending pytrees section of JAX’s pytree documentation). A leaf element is anything that’s not a pytree, e.g. an array. In other words, a pytree is just a possibly-nested standard or user-registered Python container. If nested, note that the container types do not need to match. A single “leaf”, i.e. a non-container object, is also considered a pytree.

[1, "a", object()] # 3 leaves: 1, "a" and object()

(1, (2, 3), ()) # 3 leaves: 1, 2 and 3

[1, {"k1": 2, "k2": (3, 4)}, 5] # 5 leaves: 1, 2, 3, 4, 5
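These leaf counts can be checked with tree_util.tree_leaves, which flattens a pytree into the list of its leaves. Note that the empty tuple in the second example contributes no leaves.

```python
from jax import tree_util

examples = [
    [1, "a", object()],
    (1, (2, 3), ()),
    [1, {"k1": 2, "k2": (3, 4)}, 5],
]
for tree in examples:
    leaves = tree_util.tree_leaves(tree)
    print(len(leaves), leaves)
```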

JAX provides a few utilities to work with pytrees that live in the tree_util package.

from jax import tree_util

t = [1, {"k1": 2, "k2": (3, 4)}, 5]
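Two other tree_util functions worth knowing are tree_flatten, which splits a pytree into its leaves and a structure description (treedef), and tree_unflatten, which rebuilds a tree of the same structure from a new list of leaves. A short sketch on the same tree:

```python
from jax import tree_util

t = [1, {"k1": 2, "k2": (3, 4)}, 5]
leaves, treedef = tree_util.tree_flatten(t)
print(leaves)  # [1, 2, 3, 4, 5]

# Rebuild a tree with the same structure from a new list of leaves.
doubled = tree_util.tree_unflatten(treedef, [2 * leaf for leaf in leaves])
```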

You will often come across the tree_map function, which applies a function f to each leaf of a tree and returns a tree with the same structure.

tree_util.tree_map(lambda x: x*x, t)
[1, {'k1': 4, 'k2': (9, 16)}, 25]

Instead of applying a standalone function to each of the tree leaves, you can also pass additional trees with the same structure as the input tree; their corresponding leaves are provided as extra per-leaf arguments to the function.

t2 = tree_util.tree_map(lambda x: x*x, t)
tree_util.tree_map(lambda x,y: x+y, t, t2)
[2, {'k1': 6, 'k2': (12, 20)}, 30]

Linear regression with Pytrees

While our previous example works perfectly fine, when things get more complicated (as they will with neural networks), it becomes harder to manage model parameters the way we did.

Here we show an alternative based on pytrees, using the same data from the previous example. Now, params is a pytree containing both the W and b entries.

# Linear feed-forward that takes a params pytree.
def predict_pytree(params, x):
  return jnp.dot(x, params['W']) + params['b']

# Loss function: Mean squared error.
def mse_pytree(params, x_batched,y_batched):
  # Define the squared loss for a single pair (x,y)
  def squared_error(x,y):
    y_pred = predict_pytree(params, x)
    return jnp.inner(y-y_pred, y-y_pred) / 2.0
  # We vectorize the previous to compute the average of the loss on all samples.
  return jnp.mean(jax.vmap(squared_error)(x_batched, y_batched), axis=0)

# Initialize estimated W and b with zeros. Store in a pytree.
params = {'W': jnp.zeros_like(W), 'b': jnp.zeros_like(b)}

The great thing is that JAX is able to handle differentiation with respect to pytree parameters:


jax.grad(mse_pytree)(params, x_samples, y_samples)
{'W': DeviceArray([[-1.9287349e+00,  4.2963755e-01,  7.1613449e-01, ...],
                   ...], dtype=float32),
 'b': DeviceArray([ 1.0923628,  1.3121076, -2.9304824, -0.6492362,  1.1531248],            dtype=float32)}
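The gradient returned by jax.grad is itself a pytree with exactly the same structure as the parameters, which is what makes per-leaf updates straightforward. A small sketch with a hypothetical toy loss over a dict of parameters:

```python
import jax
import jax.numpy as jnp
from jax import tree_util

def loss(params, x):
    # Hypothetical toy scalar loss over a dict of parameters.
    return jnp.sum((jnp.dot(x, params["W"]) + params["b"]) ** 2)

params = {"W": jnp.ones((3, 2)), "b": jnp.zeros((2,))}
x = jnp.arange(3.0)
grads = jax.grad(loss)(params, x)

# The gradient pytree mirrors the parameter pytree: same keys, same leaf shapes.
same_structure = tree_util.tree_structure(grads) == tree_util.tree_structure(params)
shapes = tree_util.tree_map(lambda g: g.shape, grads)
```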

Now using our tree of params, we can write the gradient descent in a simpler way using jax.tree_util.tree_map:

# Always remember to jit!
@jax.jit
def update_params_pytree(params, learning_rate, x_samples, y_samples):
  params = jax.tree_util.tree_map(
        lambda p, g: p - learning_rate * g, params,
        jax.grad(mse_pytree)(params, x_samples, y_samples))
  return params

learning_rate = 0.3  # Gradient step size.
print('Loss for "true" W,b: ', mse_pytree({'W': W, 'b': b}, x_samples, y_samples))
for i in range(101):
  # Perform one gradient update.
  params = update_params_pytree(params, learning_rate, x_samples, y_samples)
  if (i % 5 == 0):
    print(f"Loss step {i}: ", mse_pytree(params, x_samples, y_samples))
Loss for "true" W,b:  0.023639774
Loss step 0:  11.096583
Loss step 5:  1.1743388
Loss step 10:  0.32879353
Loss step 15:  0.1398177
Loss step 20:  0.07359565
Loss step 25:  0.04415301
Loss step 30:  0.029408678
Loss step 35:  0.021554656
Loss step 40:  0.017227933
Loss step 45:  0.014798875
Loss step 50:  0.013420242
Loss step 55:  0.0126327025
Loss step 60:  0.0121810865
Loss step 65:  0.011921468
Loss step 70:  0.011771992
Loss step 75:  0.011685831
Loss step 80:  0.011636148
Loss step 85:  0.011607475
Loss step 90:  0.011590928
Loss step 95:  0.011581394
Loss step 100:  0.011575883

Besides jax.grad(), another useful function is jax.value_and_grad(), which returns both the value of the input function and its gradient.
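A minimal sketch on \(f(x) = \frac{1}{2}x^Tx\): a single call returns the pair \((f(x), \nabla f(x))\), which avoids evaluating f separately when you need both.

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.dot(x, x) / 2.0

x = jnp.array([1.0, 2.0, 3.0])

# One call returns both f(x) and its gradient.
value, grad = jax.value_and_grad(f)(x)
print(value)  # 7.0
```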

To switch from jax.grad() to jax.value_and_grad(), replace the training loop above with the following:

# Using jax.value_and_grad instead:
loss_grad_fn = jax.value_and_grad(mse_pytree)
for i in range(101):
  # Note that here the loss is computed before the param update.
  loss_val, grads = loss_grad_fn(params, x_samples, y_samples)
  params = jax.tree_util.tree_map(
      lambda p, g: p - learning_rate * g, params, grads)
  if (i % 5 == 0):
    print(f"Loss step {i}: ", loss_val)

That’s all you need to know to get started with Flax! To dive deeper, we strongly recommend checking the JAX docs.