JAX for the Impatient

JAX is NumPy on the CPU, GPU, and TPU, with great automatic differentiation for high-performance machine learning research.

Here we cover the basics of JAX so that you can get started with Flax. However, we strongly recommend that you go through JAX’s documentation after reading this introduction.

NumPy API

Let’s start by exploring the NumPy API coming from JAX and the main differences you should be aware of.

import jax
from jax import numpy as jnp, random

import numpy as np # We import the standard NumPy library 

jax.numpy is the NumPy-like API that needs to be imported, and we will also use jax.random to generate some data to work on.

Let’s start by generating some matrices, and then try matrix multiplication.

m = jnp.ones((4,4)) # We're generating one 4 by 4 matrix filled with ones.
n = jnp.array([[1.0, 2.0, 3.0, 4.0],
               [5.0, 6.0, 7.0, 8.0]]) # An explicit 2 by 4 array
m
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
DeviceArray([[1., 1., 1., 1.],
             [1., 1., 1., 1.],
             [1., 1., 1., 1.],
             [1., 1., 1., 1.]], dtype=float32)

Arrays in JAX are represented as DeviceArray instances and are agnostic to where the array lives (CPU, GPU, or TPU). When no accelerator is available, JAX falls back to the CPU, which is why we get the warning that no GPU/TPU was found (unless you’re running in an environment that has a GPU/TPU available).

We can multiply matrices just as we would in NumPy.

jnp.dot(n, m).block_until_ready() # Note: yields the same result as np.dot(n, m)
DeviceArray([[10., 10., 10., 10.],
             [26., 26., 26., 26.]], dtype=float32)

Because JAX executes operations asynchronously by default, DeviceArray instances are actually futures: the Python call may return before the computation has finished. This is why we call the block_until_ready() method, which waits until the final result has been computed.

JAX interoperates with NumPy: jax.numpy functions accept standard NumPy arrays as inputs and transparently convert them to JAX arrays.

x = np.random.normal(size=(4,4)) # Creating one standard NumPy array instance
jnp.dot(x,m)
DeviceArray([[-0.7259693, -0.7259693, -0.7259693, -0.7259693],
             [-1.3710139, -1.3710139, -1.3710139, -1.3710139],
             [-4.3132505, -4.3132505, -4.3132505, -4.3132505],
             [ 2.3098469,  2.3098469,  2.3098469,  2.3098469]],            dtype=float32)

If you’re using accelerators, passing NumPy arrays directly results in repeated transfers from CPU to GPU/TPU memory. You can save that transfer bandwidth either by creating a DeviceArray directly or by calling jax.device_put on the NumPy array. With DeviceArrays, computation happens on the device, so no additional data transfer is required; for example, jnp.dot(long_vector, long_vector) only transfers a single scalar (the result of the computation) back from device to host.

x = np.random.normal(size=(4,4))
x = jax.device_put(x)
x
DeviceArray([[-0.6048319 , -0.6316572 ,  1.2123754 , -0.11620954],
             [ 1.5050535 , -0.42365703, -1.1069435 ,  0.7033215 ],
             [-1.376763  ,  0.10704198,  0.43705946,  0.411347  ],
             [-1.4015176 , -0.7026075 , -1.0267633 , -0.532106  ]],            dtype=float32)

Conversely, to get a NumPy array back from a JAX array, you can simply pass it to the NumPy API, for example with np.array.

x = jnp.array([[1.0, 2.0, 3.0, 4.0],
               [5.0, 6.0, 7.0, 8.0]])
np.array(x)
array([[1., 2., 3., 4.],
       [5., 6., 7., 8.]], dtype=float32)
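As a small sketch of the host/device round trip described above (assuming a recent JAX release; on a CPU-only machine the "device" is simply the CPU), jax.device_put moves a NumPy array to the default device once, the computation then runs there, and jax.device_get brings the result back to the host as a NumPy array:

```python
import jax
import jax.numpy as jnp
import numpy as np

x_np = np.arange(6, dtype=np.float32).reshape(2, 3)  # standard NumPy array on the host
x_dev = jax.device_put(x_np)      # transfer once to the default device
y_dev = jnp.dot(x_dev, x_dev.T)   # computed on the device, no extra transfer
y_host = jax.device_get(y_dev)    # transfer the (small) result back to the host
print(type(y_host))
print(y_host)
```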

(Im)mutability

JAX is functional by design, and one practical consequence is that JAX arrays are immutable: there are no in-place operations or sliced assignments. More generally, functions should not read their inputs from, or write their outputs to, global state.

x = jnp.array([[1.0, 2.0, 3.0, 4.0],
               [5.0, 6.0, 7.0, 8.0]])
updated = x.at[0, 0].set(3.0) # whereas x[0,0] = 3.0 would fail
print("x: \n", x) # Note that x didn't change, no in-place mutation.
print("updated: \n", updated)
x: 
 [[1. 2. 3. 4.]
 [5. 6. 7. 8.]]
updated: 
 [[3. 2. 3. 4.]
 [5. 6. 7. 8.]]

Besides set, the .at property offers other functional update methods with the same syntax, including add, mul (named multiply in recent JAX releases), min, and max.
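A short sketch chaining a few of these updates on slices (the values are arbitrary). Each call returns a new array, so x itself is never modified:

```python
import jax.numpy as jnp

x = jnp.zeros((2, 3))
y = x.at[0, :].add(1.0)    # add 1 to every element of the first row
z = y.at[:, 1].max(5.0)    # elementwise max with 5 on the second column
w = z.at[1, 2].min(-2.0)   # elementwise min with -2 at a single index
print(x)  # still all zeros
print(w)
```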

Managing randomness

In JAX, randomness is managed in a very specific way, and you can read more about it in JAX’s documentation (we borrow content from there!). As the JAX team puts it:

JAX implements an explicit PRNG where entropy production and consumption are handled by explicitly passing and iterating a PRNG state. JAX uses a modern Threefry counter-based PRNG that’s splittable. That is, its design allows us to fork the PRNG state into new PRNGs for use with parallel stochastic generation.

In short, you need to explicitly manage the PRNGs (pseudo-random number generators) and their states. In JAX’s PRNGs, the state is represented as a pair of unsigned 32-bit integers called a key (the two uint32 values have no special meaning; they are just a way of representing a uint64).

key = random.PRNGKey(0)
key
DeviceArray([0, 0], dtype=uint32)

If you use this key multiple times, you’ll get the same “random” output each time. To generate further entries in the sequence, you’ll need to split the PRNG and thus generate a new pair of keys.

for i in range(3):
    print("Printing the random number using key: ", key, " gives: ", random.normal(key,shape=(1,))) # Boringly not that random since we use the same key
Printing the random number using key:  [0 0]  gives:  [-0.20584226]
Printing the random number using key:  [0 0]  gives:  [-0.20584226]
Printing the random number using key:  [0 0]  gives:  [-0.20584226]
print("old key", key, "--> normal", random.normal(key, shape=(1,)))
key, subkey = random.split(key)
# Raw strings keep the backslashes literal (avoids invalid escape sequence warnings).
print(r"    \---SPLIT --> new key   ", key, "--> normal", random.normal(key, shape=(1,)) )
print(r"             \--> new subkey", subkey, "--> normal", random.normal(subkey, shape=(1,)) )
old key [0 0] --> normal [-0.20584226]
    \---SPLIT --> new key    [4146024105  967050713] --> normal [0.14389051]
             \--> new subkey [2718843009 1272950319] --> normal [-1.2515389]

You can also generate multiple subkeys at once if needed:

key, *subkeys = random.split(key, 4)
key, subkeys
(DeviceArray([3306097435, 3899823266], dtype=uint32),
 [DeviceArray([147607341, 367236428], dtype=uint32),
  DeviceArray([2280136339, 1907318301], dtype=uint32),
  DeviceArray([ 781391491, 1939998335], dtype=uint32)])

You can think about those PRNGs as trees of keys that match the structure of your models, which is important for reproducibility and soundness of the random behavior that you expect.
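To illustrate the idea, here is a minimal sketch assuming a toy model made of dense layers; init_layer and init_model are hypothetical helpers, not Flax APIs. The root key is split once per layer, and each layer splits its key again for its own parameters, so the tree of keys mirrors the tree of parameters and the same root key always reproduces the same parameters:

```python
import jax
import jax.numpy as jnp
from jax import random

def init_layer(key, in_dim, out_dim):
    # Hypothetical helper: split the layer key into one key per parameter.
    w_key, b_key = random.split(key)
    return {"W": random.normal(w_key, (in_dim, out_dim)),
            "b": random.normal(b_key, (out_dim,))}

def init_model(key, sizes):
    # Hypothetical helper: split the root key into one subkey per layer.
    layer_keys = random.split(key, len(sizes) - 1)
    return [init_layer(k, m, n)
            for k, m, n in zip(layer_keys, sizes[:-1], sizes[1:])]

params_a = init_model(random.PRNGKey(0), [3, 4, 2])
params_b = init_model(random.PRNGKey(0), [3, 4, 2])  # same root key

# Same root key -> identical parameter tree (reproducibility).
same = jax.tree_util.tree_all(
    jax.tree_util.tree_map(jnp.array_equal, params_a, params_b))
print(same)
```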

Gradients and autodiff

For a full overview of JAX’s automatic differentiation system, you can check the Autodiff Cookbook.

Although a VJP (vector-Jacobian product, used in reverse-mode autodiff) and a JVP (Jacobian-vector product, used in forward-mode autodiff) are theoretically similar, since both compute a product of a Jacobian and a vector, they differ in computational cost. In short, when a function has many inputs (such as a large number of parameters) and few outputs, its Jacobian is a wide matrix and a VJP is computationally cheaper than a JVP; conversely, a JVP is more efficient when the Jacobian is a tall matrix. You can read more in the Autodiff Cookbook mentioned above.
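To make the wide/tall distinction concrete, here is a sketch using jax.jacrev (which builds a Jacobian from VJPs) and jax.jacfwd (which builds it from JVPs) on an arbitrary function with 100 inputs and 2 outputs. Both return the same (2, 100) matrix; conceptually, reverse mode needs one VJP per output row (2 of them), while forward mode needs one JVP per input column (100 of them):

```python
import jax
import jax.numpy as jnp

def wide(x):
    # R^100 -> R^2, so the Jacobian has shape (2, 100): a wide matrix.
    return jnp.array([jnp.sum(x ** 2), jnp.sum(jnp.sin(x))])

x = jnp.arange(100.0) / 100.0
J_rev = jax.jacrev(wide)(x)  # reverse mode: one VJP per output
J_fwd = jax.jacfwd(wide)(x)  # forward mode: one JVP per input
print(J_rev.shape)
```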

Gradients

JAX provides first-class support for gradients and automatic differentiation of functions. This is also where the functional paradigm shines, since taking the gradient of a function is essentially a stateless operation. Consider a simple function \(f:\mathbb{R}^n\rightarrow\mathbb{R}\)

\[f(x) = \frac{1}{2} x^T x\]

with the (known) gradient:

\[\nabla f(x) = x\]

key = random.PRNGKey(0)
def f(x):
  return jnp.dot(x.T,x)/2.0

v = jnp.ones((4,))
f(v)
DeviceArray(2., dtype=float32)

JAX computes the gradient as an operator acting on functions with jax.grad. Note that this only works for scalar-valued functions.

Let’s take the gradient of f and make sure it matches the identity map.

v = random.normal(key,(4,))
print("Original v:")
print(v)
print("Gradient of f taken at point v")
print(jax.grad(f)(v)) # should be equal to v !
Original v:
[ 1.8160863  -0.75488514  0.33988908 -0.53483534]
Gradient of f taken at point v
[ 1.8160863  -0.75488514  0.33988908 -0.53483534]
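jax.grad differentiates with respect to the first argument by default. As a small sketch with an arbitrary two-argument function g (not from the original text), the argnums parameter selects other arguments, which is also how the linear regression example later on picks out W and b:

```python
import jax
import jax.numpy as jnp

def g(w, x):
    # Scalar-valued function of two arguments: sum((w * x)^2) / 2.
    return jnp.sum((w * x) ** 2) / 2.0

w = jnp.array([1.0, 2.0])
x = jnp.array([3.0, -1.0])

dg_dw = jax.grad(g)(w, x)                          # w.r.t. w only: w * x**2
dg_dw2, dg_dx = jax.grad(g, argnums=(0, 1))(w, x)  # (w * x**2, x * w**2)
print(dg_dw)
print(dg_dx)
```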

As previously mentioned, jax.grad only works for scalar-valued functions. JAX can also differentiate general vector-valued functions; the most useful primitives for this are the Jacobian-vector product, jax.jvp, and the vector-Jacobian product, jax.vjp.

Jacobian-Vector product

Let’s consider a map \(f:\mathbb{R}^n\rightarrow\mathbb{R}^m\). As a reminder, the differential of f is the map \(df:\mathbb{R}^n \rightarrow \mathcal{L}(\mathbb{R}^n,\mathbb{R}^m)\) where \(\mathcal{L}(\mathbb{R}^n,\mathbb{R}^m)\) is the space of linear maps from \(\mathbb{R}^n\) to \(\mathbb{R}^m\) (hence \(df(x)\) is often represented as a Jacobian matrix). The linear approximation of f at point \(x\) reads:

\[f(x+v) = f(x) + df(x)\bullet v + o(v)\]

The \(\bullet\) operator means you are applying the linear map \(df(x)\) to the vector v.

Even though you are rarely interested in computing the full Jacobian matrix representing the linear map \(df(x)\) in a standard basis, you are often interested in the quantity \(df(x)\bullet v\). This is exactly what jax.jvp is for: jax.jvp(f, (x,), (v,)) returns the tuple

\[(f(x), df(x)\bullet v)\]

Let’s use a simple function as an example: \(f(x) = \frac{1}{2}({x_1}^2, {x_2}^2, \ldots, {x_n}^2)\), for which \(df(x)\bullet h = (x_1h_1, x_2h_2,\ldots,x_nh_n)\). Hence calling jax.jvp with \(h= (1,1,\ldots,1)\) should return \(x\) as the second element of the output tuple.

def f(x):
  return jnp.multiply(x,x)/2.0

x = random.normal(key, (5,))
v = jnp.ones(5)
print("(x,f(x))")
print((x,f(x)))
print("jax.jvp(f, (x,),(v,))")
print(jax.jvp(f, (x,),(v,)))
(x,f(x))
(DeviceArray([ 0.18784384, -1.2833426 , -0.2710917 ,  1.2490594 ,
              0.24447003], dtype=float32), DeviceArray([0.01764265, 0.8234841 , 0.03674535, 0.7800747 , 0.0298828 ],            dtype=float32))
jax.jvp(f, (x,),(v,))
(DeviceArray([0.01764265, 0.8234841 , 0.03674535, 0.7800747 , 0.0298828 ],            dtype=float32), DeviceArray([ 0.18784384, -1.2833426 , -0.2710917 ,  1.2490594 ,
              0.24447003], dtype=float32))
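Applying a JVP to a standard basis vector picks out one column of the Jacobian; this is essentially how jax.jacfwd assembles the full matrix (vectorized over all basis vectors). A sketch with the same f and an arbitrary x, where the Jacobian is diag(x):

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.multiply(x, x) / 2.0

x = jnp.array([1.0, -2.0, 3.0])
e1 = jnp.array([0.0, 1.0, 0.0])   # second standard basis vector
_, col = jax.jvp(f, (x,), (e1,))  # df(x) . e1 = second column of J_f(x)
J = jax.jacfwd(f)(x)              # full Jacobian, here diag(x)
print(col)
print(J[:, 1])
```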

Vector-Jacobian product

Keeping our \(f:\mathbb{R}^n\rightarrow\mathbb{R}^m\), it’s often the case (for example, when you are working with a scalar loss function) that you are interested in the composition \(x\rightarrow\phi\circ f(x)\) where \(\phi :\mathbb{R}^m\rightarrow\mathbb{R}\). In that case, the gradient reads:

\[\nabla(\phi\circ f)(x) = J_f(x)^T\nabla\phi(f(x))\]

where \(J_f(x)\) is the Jacobian matrix of f evaluated at x, meaning that \(df(x)\bullet v = J_f(x)v\).

jax.vjp(f, x) returns the tuple

\[(f(x),v\rightarrow v^TJ_f(x))\]

Keeping the same example as before with \(v=(1,\ldots,1)\), applying the VJP function returned by JAX should return \(x\). The VJP function returns a tuple with one entry per primal input, hence the [0]:

(val, vjp_fun) = jax.vjp(f,x)
print("x = ", x)
print("v^T Jf(x) = ", vjp_fun(jnp.ones((5,)))[0])
x =  [ 0.18784384 -1.2833426  -0.2710917   1.2490594   0.24447003]
v^T Jf(x) =  [ 0.18784384 -1.2833426  -0.2710917   1.2490594   0.24447003]
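The VJP is also what connects vector-valued functions back to jax.grad: following the formula above, the gradient of \(\phi\circ f\) is the VJP of f applied to \(\nabla\phi(f(x))\). A sketch, assuming for illustration \(\phi(y)=\sum_j y_j\), so that \(\nabla\phi\) is a vector of ones, matching the \(v\) used above:

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.multiply(x, x) / 2.0

def phi(y):
    # Illustrative scalar "loss" on the outputs of f.
    return jnp.sum(y)

x = jnp.array([0.5, -1.0, 2.0])
y, vjp_fun = jax.vjp(f, x)
(g_vjp,) = vjp_fun(jax.grad(phi)(y))         # J_f(x)^T grad(phi)(f(x))
g_direct = jax.grad(lambda z: phi(f(z)))(x)  # gradient of the composition
print(g_vjp, g_direct)
```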

Accelerating code with jit & ops vectorization

We borrow the following example from the JAX quickstart.

Jit

JAX uses the XLA compiler under the hood and lets you just-in-time (JIT) compile your code to make it faster and more efficient. This is the purpose of jax.jit, which can be applied as a function transformation or used as the @jax.jit decorator.

def selu(x, alpha=1.67, lmbda=1.05):
  return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

v = random.normal(key, (1000000,))
%timeit selu(v).block_until_ready()
3.23 ms ± 16 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

Now let’s apply jax.jit (called here as a function rather than as a decorator) to speed things up:

selu_jit = jax.jit(selu)
%timeit selu_jit(v).block_until_ready()
844 µs ± 8.07 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

JIT compilation composes transparently with autodiff: you can jit a function that calls jax.grad, or differentiate a jitted function.
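A sketch reusing the selu function from above: since jax.grad needs a scalar output, we wrap it in an illustrative sum (total_selu is our own helper name) and then compose the two transformations in both orders. Both give the same derivative, which is lmbda for positive inputs and lmbda * alpha * exp(x) otherwise:

```python
import jax
import jax.numpy as jnp

def selu(x, alpha=1.67, lmbda=1.05):
    return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

def total_selu(x):
    # Scalar-valued wrapper so that jax.grad applies.
    return jnp.sum(selu(x))

grad_then_jit = jax.jit(jax.grad(total_selu))  # compile the gradient computation
jit_then_grad = jax.grad(jax.jit(total_selu))  # differentiate through a jitted function

x = jnp.array([-1.0, 0.5, 2.0])
print(grad_then_jit(x))
print(jit_then_grad(x))
```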


Vectorization

Finally, JAX lets you write code that applies to a single example and then vectorize it with jax.vmap, which handles the batch dimension transparently.

mat = random.normal(key, (15, 10))
batched_x = random.normal(key, (5, 10)) # Batch size on axis 0
single = random.normal(key, (10,))

def apply_matrix(v):
  return jnp.dot(mat, v)

print("Single apply shape: ", apply_matrix(single).shape)
print("Batched example shape: ", jax.vmap(apply_matrix)(batched_x).shape)
Single apply shape:  (15,)
Batched example shape:  (5, 15)
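jax.vmap also lets you choose which axes are batched. As a sketch (the shapes are arbitrary), in_axes=None leaves an argument unbatched and out_axes controls where the batch dimension ends up. Here the batch of vectors is stored along axis 1, so the result matches a plain matrix product:

```python
import jax
import jax.numpy as jnp
from jax import random

key = random.PRNGKey(0)
mat = random.normal(key, (15, 10))
batched_cols = random.normal(key, (10, 5))  # batch of 5 vectors stored along axis 1

def apply_matrix(m, v):
    return jnp.dot(m, v)

# Don't map over m (None); map over axis 1 of v; put the batch on axis 1 of the output.
out = jax.vmap(apply_matrix, in_axes=(None, 1), out_axes=1)(mat, batched_cols)
print(out.shape)
```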

Full example: linear regression

Let’s implement one of the simplest models using everything we have seen so far: a linear regression. From a set of data points \(\{(x_i,y_i), i\in \{1,\ldots, k\}, x_i\in\mathbb{R}^n,y_i\in\mathbb{R}^m\}\), we try to find a set of parameters \(W\in \mathcal{M}_{m,n}(\mathbb{R}), b\in\mathbb{R}^m\) such that the function \(f_{W,b}(x)=Wx+b\) minimizes the mean squared error:

\[\mathcal{L}(W,b)=\frac{1}{k}\sum_{i=1}^{k} \frac{1}{2}\|y_i-f_{W,b}(x_i)\|^2_2\]

In the code below, W is stored with shape (n, m) and applied as jnp.dot(x, W), which computes the same map with W transposed.

(Note: depending on how you cast the regression problem, you might end up with different setups. Theoretically we should be minimizing the expectation of the loss with respect to the data distribution; however, for simplicity, here we only consider the sampled loss.)

# Linear feed-forward.
def predict(W, b, x):
  return jnp.dot(x, W) + b

# Loss function: Mean squared error.
def mse(W, b, x_batched, y_batched):
  # Define the squared loss for a single pair (x,y)
  def squared_error(x, y):
    y_pred = predict(W, b, x)
    return jnp.inner(y-y_pred, y-y_pred) / 2.0
  # We vectorize the previous to compute the average of the loss on all samples.
  return jnp.mean(jax.vmap(squared_error)(x_batched, y_batched), axis=0)
# Set problem dimensions.
n_samples = 20
x_dim = 10
y_dim = 5

# Generate random ground truth W and b.
key = random.PRNGKey(0)
k1, k2 = random.split(key)
W = random.normal(k1, (x_dim, y_dim))
b = random.normal(k2, (y_dim,))

# Generate samples with additional noise.
key_sample, key_noise = random.split(k1)
x_samples = random.normal(key_sample, (n_samples, x_dim))
y_samples = predict(W, b, x_samples) + 0.1 * random.normal(key_noise,(n_samples, y_dim))
print('x shape:', x_samples.shape, '; y shape:', y_samples.shape)
x shape: (20, 10) ; y shape: (20, 5)
# Initialize estimated W and b with zeros.
W_hat = jnp.zeros_like(W)
b_hat = jnp.zeros_like(b)

# Ensure we jit the largest-possible jittable block.
@jax.jit
def update_params(W, b, x, y, lr):
  # Differentiate mse w.r.t. both W (argument 0) and b (argument 1) in a single call.
  grad_W, grad_b = jax.grad(mse, argnums=(0, 1))(W, b, x, y)
  return W - lr * grad_W, b - lr * grad_b

learning_rate = 0.3  # Gradient step size.
print('Loss for "true" W,b: ', mse(W, b, x_samples, y_samples))
for i in range(101):
  # Perform one gradient update.
  W_hat, b_hat = update_params(W_hat, b_hat, x_samples, y_samples, learning_rate)
  if (i % 5 == 0):
    print(f"Loss step {i}: ", mse(W_hat, b_hat, x_samples, y_samples))
Loss for "true" W,b:  0.023639796
Loss step 0:  10.971408
Loss step 5:  1.0798326
Loss step 10:  0.37958243
Loss step 15:  0.1785529
Loss step 20:  0.094415195
Loss step 25:  0.054522213
Loss step 30:  0.03448923
Loss step 35:  0.024058014
Loss step 40:  0.018480862
Loss step 45:  0.015438671
Loss step 50:  0.0137539385
Loss step 55:  0.012810304
Loss step 60:  0.0122773
Loss step 65:  0.011974386
Loss step 70:  0.011801454
Loss step 75:  0.011702402
Loss step 80:  0.011645537
Loss step 85:  0.011612833
Loss step 90:  0.011594019
Loss step 95:  0.011583149
Loss step 100:  0.011576906

Gradient descent only gives an approximate solution to the linear regression problem (the exact least-squares solution can be computed in closed form), but this example already uses all the core tools covered above: jax.numpy, jax.random, jax.grad, jax.vmap, and jax.jit.
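For comparison, here is a minimal sketch of the exact least-squares solution using jnp.linalg.lstsq. It regenerates data in the same way as above so that it runs on its own, and appends a column of ones to the inputs so the bias is estimated as an extra row of the weight matrix (a standard trick, not something this tutorial does):

```python
import jax.numpy as jnp
from jax import random

# Regenerate data the same way as in the example above.
n_samples, x_dim, y_dim = 20, 10, 5
k1, k2 = random.split(random.PRNGKey(0))
W = random.normal(k1, (x_dim, y_dim))
b = random.normal(k2, (y_dim,))
key_sample, key_noise = random.split(k1)
x = random.normal(key_sample, (n_samples, x_dim))
y = jnp.dot(x, W) + b + 0.1 * random.normal(key_noise, (n_samples, y_dim))

# Append a column of ones so that W and b are solved for jointly.
x_aug = jnp.concatenate([x, jnp.ones((n_samples, 1))], axis=1)  # (20, 11)
sol, _, _, _ = jnp.linalg.lstsq(x_aug, y)                        # (11, 5)
W_ls, b_ls = sol[:-1], sol[-1]

# Same loss as mse(): mean over samples of half the squared error.
residual = y - (jnp.dot(x, W_ls) + b_ls)
print(jnp.mean(jnp.sum(residual ** 2, axis=1) / 2.0))
```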

Refining a bit with pytrees

Here we elaborate on our previous example using JAX’s pytree data structures.

Pytrees basics

The JAX ecosystem uses pytrees everywhere, and so does Flax (Flax’s FrozenDict, for example, is a pytree). For a complete overview, we suggest that you take a look at the pytree page in JAX’s documentation:

In JAX, a pytree is a container of leaf elements and/or more pytrees. Containers include lists, tuples, and dicts (JAX can be extended to consider other container types as pytrees, see Extending pytrees below). A leaf element is anything that’s not a pytree, e.g. an array. In other words, a pytree is just a possibly-nested standard or user-registered Python container. If nested, note that the container types do not need to match. A single “leaf”, i.e. a non-container object, is also considered a pytree.

[1, "a", object()] # 3 leaves: 1, "a" and object()

(1, (2, 3), ()) # 3 leaves: 1, 2 and 3

[1, {"k1": 2, "k2": (3, 4)}, 5] # 5 leaves: 1, 2, 3, 4, 5

JAX provides a few utilities to work with pytrees that live in the tree_util package.

from jax import tree_util

t = [1, {"k1": 2, "k2": (3, 4)}, 5]

You will often come across the tree_map function, which applies a function f to every leaf of a tree and returns a new tree with the same structure.

tree_util.tree_map(lambda x: x*x, t)
[1, {'k1': 4, 'k2': (9, 16)}, 25]

Instead of applying a function to the leaves of a single tree, you can also pass additional trees with the same structure as the input tree; their corresponding leaves are then passed to the function as extra arguments.

t2 = tree_util.tree_map(lambda x: x*x, t)
tree_util.tree_map(lambda x,y: x+y, t, t2)
[2, {'k1': 6, 'k2': (12, 20)}, 30]
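Two other tree_util functions that come up often are tree_flatten, which returns the list of leaves together with a description of the tree structure, and tree_unflatten, which rebuilds a tree from that structure and a new list of leaves. A short sketch on the same t:

```python
from jax import tree_util

t = [1, {"k1": 2, "k2": (3, 4)}, 5]
leaves, treedef = tree_util.tree_flatten(t)  # leaves in a deterministic order
print(leaves)   # [1, 2, 3, 4, 5]

# Rebuild a tree with the same structure from transformed leaves.
rebuilt = tree_util.tree_unflatten(treedef, [leaf * 10 for leaf in leaves])
print(rebuilt)  # [10, {'k1': 20, 'k2': (30, 40)}, 50]
```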

Linear regression with Pytrees

Our previous example works fine, but as things get more complicated (as they will with neural networks), managing each model parameter as a separate argument, as we did, quickly becomes cumbersome.

Here we show an alternative based on pytrees, using the same data as in the previous example. Now, params is a single pytree (a dict) containing both the W and b entries.

# Linear feed-forward that takes a params pytree.
def predict_pytree(params, x):
  return jnp.dot(x, params['W']) + params['b']

# Loss function: Mean squared error.
def mse_pytree(params, x_batched,y_batched):
  # Define the squared loss for a single pair (x,y)
  def squared_error(x,y):
    y_pred = predict_pytree(params, x)
    return jnp.inner(y-y_pred, y-y_pred) / 2.0
  # We vectorize the previous to compute the average of the loss on all samples.
  return jnp.mean(jax.vmap(squared_error)(x_batched, y_batched), axis=0)

# Initialize estimated W and b with zeros. Store in a pytree.
params = {'W': jnp.zeros_like(W), 'b': jnp.zeros_like(b)}

The great thing is that JAX is able to handle differentiation with respect to pytree parameters:

jax.grad(mse_pytree)(params, x_samples, y_samples)
{'W': DeviceArray([[-1.1769425 ,  0.39893118, -0.89776236,  1.2179708 ,
               -0.22526088],
              [-1.2351222 , -0.51718676, -1.2093343 , -0.4701616 ,
               -0.14098746],
              [ 1.8270822 ,  0.44839874, -2.1687849 , -0.5517593 ,
                1.203793  ],
              [ 2.47367   ,  1.1133709 , -0.5650934 , -1.0506653 ,
                1.2003354 ],
              [ 0.09737853, -1.3749819 ,  2.8142433 , -0.15796328,
                1.1032437 ],
              [-1.2241135 ,  0.63864106,  0.56382763, -0.2598635 ,
               -0.62159854],
              [ 0.13155793, -0.36742553, -2.3217206 ,  1.4004409 ,
               -0.86137396],
              [-1.2007939 ,  2.207632  , -1.8339298 , -0.3254988 ,
                0.39060313],
              [-0.81191576,  0.5818832 , -1.5981957 ,  1.7225316 ,
                0.07298003],
              [ 0.7984328 ,  0.17107074,  0.79715335,  0.5209309 ,
               -0.23226593]], dtype=float32),
 'b': DeviceArray([ 1.864018  ,  2.69724   , -3.1244898 , -1.4062692 ,
               0.66766924], dtype=float32)}

Now, using our tree of params, we can write gradient descent more simply with jax.tree_util.tree_map:

# Always remember to jit!
@jax.jit
def update_params_pytree(params, learning_rate, x_samples, y_samples):
  # Apply the step p - learning_rate * g leaf by leaf (here: to 'W' and to 'b').
  params = jax.tree_util.tree_map(
        lambda p, g: p - learning_rate * g, params,
        jax.grad(mse_pytree)(params, x_samples, y_samples))
  return params

learning_rate = 0.3  # Gradient step size.
print('Loss for "true" W,b: ', mse_pytree({'W': W, 'b': b}, x_samples, y_samples))
for i in range(101):
  # Perform one gradient update.
  params = update_params_pytree(params, learning_rate, x_samples, y_samples)
  if (i % 5 == 0):
    print(f"Loss step {i}: ", mse_pytree(params, x_samples, y_samples))
Loss for "true" W,b:  0.023639796
Loss step 0:  10.971408
Loss step 5:  1.0798326
Loss step 10:  0.37958243
Loss step 15:  0.1785529
Loss step 20:  0.094415195
Loss step 25:  0.054522213
Loss step 30:  0.03448923
Loss step 35:  0.024058014
Loss step 40:  0.018480862
Loss step 45:  0.015438671
Loss step 50:  0.0137539385
Loss step 55:  0.012810304
Loss step 60:  0.0122773
Loss step 65:  0.011974386
Loss step 70:  0.011801454
Loss step 75:  0.011702402
Loss step 80:  0.011645537
Loss step 85:  0.011612833
Loss step 90:  0.011594019
Loss step 95:  0.011583149
Loss step 100:  0.011576906

Besides jax.grad(), another useful function is jax.value_and_grad(), which returns both the value of the input function and its gradient.

To switch from jax.grad() to jax.value_and_grad(), replace the training loop above with the following:

# Using jax.value_and_grad instead:
loss_grad_fn = jax.value_and_grad(mse_pytree)
for i in range(101):
  # Note that here the loss is computed before the param update.
  loss_val, grads = loss_grad_fn(params, x_samples, y_samples)
  params = jax.tree_util.tree_map(
      lambda p, g: p - learning_rate * g, params, grads)
  if (i % 5 == 0):
    print(f"Loss step {i}: ", loss_val)
Loss step 0:  0.011576906
Loss step 5:  0.011573283
Loss step 10:  0.011571205
Loss step 15:  0.011570008
Loss step 20:  0.011569307
Loss step 25:  0.011568909
Loss step 30:  0.01156868
Loss step 35:  0.011568547
Loss step 40:  0.011568466
Loss step 45:  0.011568422
Loss step 50:  0.011568392
Loss step 55:  0.011568381
Loss step 60:  0.011568381
Loss step 65:  0.011568381
Loss step 70:  0.011568359
Loss step 75:  0.011568369
Loss step 80:  0.011568364
Loss step 85:  0.01156837
Loss step 90:  0.011568367
Loss step 95:  0.011568364
Loss step 100:  0.011568364

That’s all you need to know to get started with Flax! To dive deeper, we strongly recommend checking the JAX docs.