{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "SwtfSYdoHsc_" }, "source": [ "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/flax/blob/main/docs/guides/jax_for_the_impatient.ipynb)\n", "[![Open On GitHub](https://img.shields.io/badge/Open-on%20GitHub-blue?logo=GitHub)](https://github.com/google/flax/blob/main/docs/guides/jax_for_the_impatient.ipynb)\n", "\n", "# JAX for the Impatient\n", "**JAX is NumPy on the CPU, GPU, and TPU, with great automatic differentiation for high-performance machine learning research.**\n", "\n", "Here we will cover the basics of JAX so that you can get started with Flax, however we very much recommend that you go through JAX's documentation [here](https://jax.readthedocs.io/en/latest/notebooks/quickstart.html) after going over the basics here." ] }, { "cell_type": "markdown", "metadata": { "id": "gF2oOT78zOIr" }, "source": [ "## NumPy API\n", "\n", "Let's start by exploring the NumPy API coming from JAX and the main differences you should be aware of." ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "id": "5csM8DZYEqk6" }, "outputs": [], "source": [ "import jax\n", "from jax import numpy as jnp, random\n", "\n", "import numpy as np # We import the standard NumPy library" ] }, { "cell_type": "markdown", "metadata": { "id": "Z5BLL6v_JUSI" }, "source": [ "`jax.numpy` is the NumPy-like API that needs to be imported, and we will also use `jax.random` to generate some data to work on.\n", "\n", "Let's start by generating some matrices, and then try matrix multiplication." ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "L2HKiLTNJ4Eh", "outputId": "c4297a1a-4e4b-4bdc-ca5d-3d33aca92b3b" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n" ] }, { "data": { "text/plain": [ "DeviceArray([[1., 1., 1., 1.],\n", " [1., 1., 1., 1.],\n", " [1., 1., 1., 1.],\n", " [1., 1., 1., 1.]], dtype=float32)" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "m = jnp.ones((4,4)) # We're generating one 4 by 4 matrix filled with ones.\n", "n = jnp.array([[1.0, 2.0, 3.0, 4.0],\n", " [5.0, 6.0, 7.0, 8.0]]) # An explicit 2 by 4 array\n", "m" ] }, { "cell_type": "markdown", "metadata": { "id": "NKFtn4d_Nu07" }, "source": [ "Arrays in JAX are represented as DeviceArray instances and are agnostic to the place where the array lives (CPU, GPU, or TPU). This is why we're getting the warning that no GPU/TPU was found and JAX is falling back to a CPU (unless you're running it in an environment that has a GPU/TPU available).\n", "\n", "We can obviously multiply matrices like we would do in NumPy." 
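, "\n", "\n", "If you want to check which backend JAX actually picked, `jax.devices()` lists the devices it can see (a quick sanity check; the exact output depends on your machine):\n", "\n", "```python\n", "import jax\n", "\n", "print(jax.devices())              # e.g. a single CPU device when no accelerator is found\n", "print(jax.devices()[0].platform)  # 'cpu', 'gpu' or 'tpu'\n", "```\n", "\n", "With that checked, let's try the matrix product:"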
] }, { "cell_type": "code", "execution_count": 3, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "9do-ZRGaRThn", "outputId": "9c4feb4d-3bd1-4921-97ce-c8087b37496f" }, "outputs": [ { "data": { "text/plain": [ "DeviceArray([[10., 10., 10., 10.],\n", " [26., 26., 26., 26.]], dtype=float32)" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "jnp.dot(n, m).block_until_ready() # Note: yields the same result as np.dot(n, m)" ] },
{ "cell_type": "markdown", "metadata": { "id": "Jkyt5xXpRidn" }, "source": [ "DeviceArray instances are actually futures ([more here](https://jax.readthedocs.io/en/latest/async_dispatch.html)) due to the **default asynchronous execution** in JAX. For that reason, the Python call might return before the computation actually ends, which is why we use the `block_until_ready()` method to wait for and return the final result.\n", "\n", "JAX is fully compatible with NumPy and can transparently process arrays from one library to the other." ] },
{ "cell_type": "code", "execution_count": 4, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "hFthGlHoRZ59", "outputId": "15892d6a-c06c-4f98-a7d4-ad432bdd1f57" }, "outputs": [ { "data": { "text/plain": [ "DeviceArray([[-0.8318497, -0.8318497, -0.8318497, -0.8318497],\n", " [ 2.4768949, 2.4768949, 2.4768949, 2.4768949],\n", " [-1.0424521, -1.0424521, -1.0424521, -1.0424521],\n", " [-3.4560933, -3.4560933, -3.4560933, -3.4560933]], dtype=float32)" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "x = np.random.normal(size=(4,4)) # Create a standard NumPy array\n", "jnp.dot(x,m)" ] },
{ "cell_type": "markdown", "metadata": { "id": "AoaA-FS2XpsC" }, "source": [ "If you're using accelerators, using NumPy arrays directly will result in multiple transfers from CPU to GPU/TPU memory. You can save that transfer bandwidth either by directly creating a DeviceArray or by using `jax.device_put` on the NumPy array. With DeviceArrays, the computation is done on device so no additional data transfer is required, e.g. `jnp.dot(long_vector, long_vector)` will only transfer a single scalar (the result of the computation) back from device to host." ] },
{ "cell_type": "code", "execution_count": 5, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "-VABtdIwTFfN", "outputId": "08965869-bdd7-44c8-ae46-207061b5112c" }, "outputs": [ { "data": { "text/plain": [ "DeviceArray([[ 0.08149499, 0.07987174, 1.1451471 , -0.59535813],\n", " [ 0.86550283, 0.6078417 , 0.7539637 , 1.5923587 ],\n", " [ 0.8374219 , -0.07827665, 1.4592382 , 1.4161737 ],\n", " [ 0.37525675, -0.8032943 , 2.062778 , -0.15352985]], dtype=float32)" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "x = np.random.normal(size=(4,4))\n", "x = jax.device_put(x)\n", "x" ] },
{ "cell_type": "markdown", "metadata": { "id": "y_2QavY1tR8j" }, "source": [ "Conversely, if you want to get a NumPy array back from a JAX array, you can simply use it with the NumPy API."
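, "\n", "\n", "If you prefer to make the transfer explicit, `jax.device_get` (the counterpart of `jax.device_put`) copies a JAX array back to host memory as a NumPy array. A minimal sketch:\n", "\n", "```python\n", "import jax\n", "from jax import numpy as jnp\n", "\n", "y = jnp.ones((2, 2))\n", "y_host = jax.device_get(y)  # copy back to host memory\n", "print(type(y_host))         # <class 'numpy.ndarray'>\n", "```\n", "\n", "Or simply pass the JAX array to the NumPy API:"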
] }, { "cell_type": "code", "execution_count": 6, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "vEJ1mSvStjEC", "outputId": "00a8cc38-59a2-4cf9-ed23-eb5fbb708495" }, "outputs": [ { "data": { "text/plain": [ "array([[1., 2., 3., 4.],\n", " [5., 6., 7., 8.]], dtype=float32)" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "x = jnp.array([[1.0, 2.0, 3.0, 4.0],\n", " [5.0, 6.0, 7.0, 8.0]])\n", "np.array(x)" ] }, { "cell_type": "markdown", "metadata": { "id": "CBHVd3GTpLKD" }, "source": [ "## (Im)mutability\n", "JAX is functional by essence, one practical consequence being that JAX arrays are immutable. This means no in-place ops and sliced assignments. More generally, functions should not take input or produce output using a global state." ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "-erZrgZXawFW", "outputId": "c3c03081-6235-482f-a88c-cc180f661954" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "x: \n", " [[1. 2. 3. 4.]\n", " [5. 6. 7. 8.]]\n", "updated: \n", " [[3. 2. 3. 4.]\n", " [5. 6. 7. 8.]]\n" ] } ], "source": [ "x = jnp.array([[1.0, 2.0, 3.0, 4.0],\n", " [5.0, 6.0, 7.0, 8.0]])\n", "updated = x.at[0, 0].set(3.0) # whereas x[0,0] = 3.0 would fail\n", "print(\"x: \\n\", x) # Note that x didn't change, no in-place mutation.\n", "print(\"updated: \\n\", updated)" ] }, { "cell_type": "markdown", "metadata": { "id": "Sz_9b-XUTjjl" }, "source": [ "All jax ops are available with this syntax, including: `set`, `add`, `mul`, `min`, `max`." ] }, { "cell_type": "markdown", "metadata": { "id": "o8QGdusyzbmP" }, "source": [ "## Managing randomness\n", "In JAX, randomness is managed in a very specific way, and you can read more on JAX's docs [here](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#random-numbers) and [here](https://jax.readthedocs.io/en/latest/jax-101/05-random-numbers.html) (we borrow content from there!). As the JAX team puts it:\n", "\n", "*JAX implements an explicit PRNG where entropy production and consumption are handled by explicitly passing and iterating a PRNG state. JAX uses a modern Threefry counter-based PRNG that’s splittable. That is, its design allows us to fork the PRNG state into new PRNGs for use with parallel stochastic generation.*\n", "\n", "In short, you need to explicitly manage the PRNGs (pseudo random number generators) and their states. In JAX's PRNGs, the state is represented as a pair of two unsigned-int32s that is called a key (there is no special meaning to the two unsigned int32s -- it's just a way of representing a uint64)." ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "8iz9KGF4s7nN", "outputId": "c5bb1581-090b-42ed-cc42-08436154bc14" }, "outputs": [ { "data": { "text/plain": [ "DeviceArray([0, 0], dtype=uint32)" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "key = random.PRNGKey(0)\n", "key" ] }, { "cell_type": "markdown", "metadata": { "id": "1y622foIaYjL" }, "source": [ "If you use this key multiple times, you'll get the same \"random\" output each time. To generate further entries in the sequence, you'll need to split the PRNG and thus generate a new pair of keys." 
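, "\n", "\n", "As an aside, every sampler in `jax.random` follows the same convention: the key is the first argument, and the same key always produces the same values. A small sketch with a couple of other distributions (the variable names and numbers below are just for illustration):\n", "\n", "```python\n", "from jax import random\n", "\n", "demo_key = random.PRNGKey(42)\n", "k1, k2 = random.split(demo_key)\n", "u = random.uniform(k1, shape=(3,))                       # uniform samples in [0, 1)\n", "r = random.randint(k2, shape=(3,), minval=0, maxval=10)  # integers in [0, 10)\n", "print(u, r)\n", "```\n", "\n", "Let's first see what happens when we reuse the same key:"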
] }, { "cell_type": "code", "execution_count": 11, "metadata": { "vscode": { "languageId": "python" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Printing the random number using key: [0 0] gives: [-0.20584235]\n", "Printing the random number using key: [0 0] gives: [-0.20584235]\n", "Printing the random number using key: [0 0] gives: [-0.20584235]\n" ] } ], "source": [ "for i in range(3):\n", " print(\"Printing the random number using key: \", key, \" gives: \", random.normal(key,shape=(1,))) # Boringly not that random since we use the same key" ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "lOBv5CaB3dMa", "outputId": "ac89afdc-a73e-4c31-d005-7e1e6ad551cd" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "old key [0 0] --> normal [-0.20584235]\n", " \\---SPLIT --> new key [4146024105 967050713] --> normal [0.14389044]\n", " \\--> new subkey [2718843009 1272950319] --> normal [-1.2515389]\n" ] } ], "source": [ "print(\"old key\", key, \"--> normal\", random.normal(key, shape=(1,)))\n", "key, subkey = random.split(key)\n", "print(\" \\---SPLIT --> new key \", key, \"--> normal\", random.normal(key, shape=(1,)) )\n", "print(\" \\--> new subkey\", subkey, \"--> normal\", random.normal(subkey, shape=(1,)) )" ] }, { "cell_type": "markdown", "metadata": { "id": "QgCCZtyQ4EqA" }, "source": [ "You can also generate multiple subkeys at once if needed:" ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "G3zRojMs4Cce", "outputId": "e48e1ed0-4f16-49cb-dc2b-cb51d3ec56b5" }, "outputs": [ { "data": { "text/plain": [ "(array([3306097435, 3899823266], dtype=uint32),\n", " [array([147607341, 367236428], dtype=uint32),\n", " array([2280136339, 1907318301], dtype=uint32),\n", " array([ 781391491, 1939998335], dtype=uint32)])" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "key, *subkeys = random.split(key, 4)\n", "key, subkeys" ] }, { "cell_type": "markdown", "metadata": { "id": "20lC7np5YKDq" }, "source": [ "You can think about those PRNGs as trees of keys that match the structure of your models, which is important for reproducibility and soundness of the random behavior that you expect." ] }, { "cell_type": "markdown", "metadata": { "id": "GC6-1gq1YsgZ" }, "source": [ "## Gradients and autodiff\n", "\n", "For a full overview of JAX's automatic differentiation system, you can check the [Autodiff Cookbook](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html).\n", "\n", "Even though, theoretically, a VJP (Vector-Jacobian product - reverse autodiff) and a JVP (Jacobian-Vector product - forward-mode autodiff) are similar—they compute a product of a Jacobian and a vector—they differ by the computational complexity of the operation. In short, when you have a large number of parameters (hence a wide matrix), a JVP is less efficient computationally than a VJP, and, conversely, a JVP is more efficient when the Jacobian matrix is a tall matrix. You can read more in the JAX [cookbook](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html#jacobian-vector-products-jvps-aka-forward-mode-autodiff) [notebook](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html#vector-jacobian-products-vjps-aka-reverse-mode-autodiff) mentioned above." 
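, "\n", "\n", "In practice you rarely build full Jacobians by hand: `jax.jacfwd` (built on JVPs) and `jax.jacrev` (built on VJPs) compute the same Jacobian matrix with the complexity trade-off described above. A minimal sketch (the function `g` below is just an illustration):\n", "\n", "```python\n", "import jax\n", "from jax import numpy as jnp\n", "\n", "def g(x):\n", "    # A simple map from R^3 to R^2, purely for illustration.\n", "    return jnp.array([x[0] * x[1], jnp.sin(x[2])])\n", "\n", "x = jnp.array([1.0, 2.0, 3.0])\n", "J_fwd = jax.jacfwd(g)(x)  # forward mode, efficient for tall Jacobians\n", "J_rev = jax.jacrev(g)(x)  # reverse mode, efficient for wide Jacobians\n", "print(jnp.allclose(J_fwd, J_rev))  # same matrix, different cost profile\n", "```"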
] }, { "cell_type": "markdown", "metadata": { "id": "CUFwVnn4011l" }, "source": [ "### Gradients\n", "\n", "JAX provides first-class support for gradients and automatic differentiation in functions. This is also where the functional paradigm shines, since gradients on functions are essentially stateless operations. If we consider a simple function $f:\\mathbb{R}^n\\rightarrow\\mathbb{R}$\n", "\n", "$$f(x) = \\frac{1}{2} x^T x$$\n", "\n", "with the (known) gradient:\n", "\n", "$$\\nabla f(x) = x$$" ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "zDOydrLMcIzp", "outputId": "580c14ed-d1a3-4f92-c9b9-78d58c87bc76" }, "outputs": [ { "data": { "text/plain": [ "DeviceArray(2., dtype=float32)" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "key = random.PRNGKey(0)\n", "def f(x):\n", " return jnp.dot(x.T,x)/2.0\n", "\n", "v = jnp.ones((4,))\n", "f(v)" ] }, { "cell_type": "markdown", "metadata": { "id": "zVaiZplShoBK" }, "source": [ "JAX computes the gradient as an operator acting on functions with `jax.grad`. Note that this only works for scalar valued functions.\n", "\n", "Let's take the gradient of f and make sure it matches the identity map." ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "ael3pVHmhhTs", "outputId": "4d0c5122-1ead-4a94-9153-7eb3b399dae2" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Original v:\n", "[ 1.8160859 -0.7548852 0.33988902 -0.5348355 ]\n", "Gradient of f taken at point v\n", "[ 1.8160859 -0.7548852 0.33988902 -0.5348355 ]\n" ] } ], "source": [ "v = random.normal(key,(4,))\n", "print(\"Original v:\")\n", "print(v)\n", "print(\"Gradient of f taken at point v\")\n", "print(jax.grad(f)(v)) # should be equal to v !" ] }, { "cell_type": "markdown", "metadata": { "id": "UHIMfchIiQMR" }, "source": [ "As previously mentioned, `jax.grad` only works for scalar-valued functions. JAX can also handle general vector valued functions. The most useful primitives are a Jacobian-Vector product - `jax.jvp` - and a Vector-Jacobian product - `jax.vjp`.\n", "\n", "### Jacobian-Vector product\n", "\n", "Let's consider a map $f:\\mathbb{R}^n\\rightarrow\\mathbb{R}^m$. As a reminder, the differential of f is the map $df:\\mathbb{R}^n \\rightarrow \\mathcal{L}(\\mathbb{R}^n,\\mathbb{R}^m)$ where $\\mathcal{L}(\\mathbb{R}^n,\\mathbb{R}^m)$ is the space of linear maps from $\\mathbb{R}^n$ to $\\mathbb{R}^m$ (hence $df(x)$ is often represented as a Jacobian matrix). The linear approximation of f at point $x$ reads:\n", "\n", "$$f(x+v) = f(x) + df(x)\\bullet v + o(v)$$\n", "\n", "The $\\bullet$ operator means you are applying the linear map $df(x)$ to the vector v.\n", "\n", "Even though you are rarely interested in computing the full Jacobian matrix representing the linear map $df(x)$ in a standard basis, you are often interested in the quantity $df(x)\\bullet v$. This is exactly what `jax.jvp` is for, and `jax.jvp(f, (x,), (v,))` returns the tuple:\n", "\n", "$$(f(x), df(x)\\bullet v)$$" ] }, { "cell_type": "markdown", "metadata": { "id": "F5nI_gbeqj2y" }, "source": [ "Let's use a simple function as an example: $f(x) = \\frac{1}{2}({x_1}^2, {x_2}^2, \\ldots, {x_n}^2)$ where we know that $df(x)\\bullet h = (x_1h_1, x_2h_2,\\ldots,x_nh_n)$. Hence using `jax.jvp` with $h= (1,1,\\ldots,1)$ should return $x$ as an output." 
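, "\n", "\n", "More generally, you can check the identity $df(x)\\bullet h = (x_1h_1,\\ldots,x_nh_n)$ for an arbitrary tangent vector with a quick `allclose` test (a small, self-contained sketch; `h` and `tangent_out` are illustrative names):\n", "\n", "```python\n", "import jax\n", "from jax import numpy as jnp, random\n", "\n", "def f(x):\n", "    return jnp.multiply(x, x) / 2.0\n", "\n", "k1, k2 = random.split(random.PRNGKey(1))\n", "x = random.normal(k1, (5,))\n", "h = random.normal(k2, (5,))\n", "primal_out, tangent_out = jax.jvp(f, (x,), (h,))\n", "print(jnp.allclose(tangent_out, x * h))  # True: df(x) applied to h is the elementwise product\n", "```"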
] }, { "cell_type": "code", "execution_count": 16, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "Q2ntaHBeh-5u", "outputId": "93591ad3-832f-4928-c1f8-073cc3b7aae7" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(x,f(x))\n", "(DeviceArray([ 0.18784378, -1.2833427 , -0.27109176, 1.2490592 ,\n", " 0.24446994], dtype=float32), DeviceArray([0.01764264, 0.82348424, 0.03674537, 0.7800744 , 0.02988278], dtype=float32))\n", "jax.jvp(f, (x,),(v,))\n", "(DeviceArray([0.01764264, 0.82348424, 0.03674537, 0.7800744 , 0.02988278], dtype=float32), DeviceArray([ 0.18784378, -1.2833427 , -0.27109176, 1.2490592 ,\n", " 0.24446994], dtype=float32))\n" ] } ], "source": [ "def f(x):\n", " return jnp.multiply(x,x)/2.0\n", "\n", "x = random.normal(key, (5,))\n", "v = jnp.ones(5)\n", "print(\"(x,f(x))\")\n", "print((x,f(x)))\n", "print(\"jax.jvp(f, (x,),(v,))\")\n", "print(jax.jvp(f, (x,),(v,)))" ] }, { "cell_type": "markdown", "metadata": { "id": "gdm_TTDLal_X" }, "source": [ "### Vector-Jacobian product\n", "Keeping our $f:\\mathbb{R}^n\\rightarrow\\mathbb{R}^m$ it's often the case (for example, when you are working with a scalar loss function) that you are interested in the composition $x\\rightarrow\\phi\\circ f(x)$ where $\\phi :\\mathbb{R}^m\\rightarrow\\mathbb{R}$. In that case, the gradient reads:\n", "\n", "$$\\nabla(\\phi\\circ f)(x) = J_f(x)^T\\nabla\\phi(f(x))$$\n", "\n", "Where $J_f(x)$ is the Jacobian matrix of f evaluated at x, meaning that $df(x)\\bullet v = J_f(x)v$.\n", "\n", "`jax.vjp(f,x)` returns the tuple:\n", "\n", "$$(f(x),v\\rightarrow v^TJ_f(x))$$\n", "\n", "Keeping the same example as previously, using $v=(1,\\ldots,1)$, applying the VJP function returned by JAX should return the $x$ value:" ] }, { "cell_type": "code", "execution_count": 17, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "_1VTl9zXqsFl", "outputId": "f3f143a9-b1f1-4a4d-e4b1-c24a0fa114b8" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "x = [ 0.18784378 -1.2833427 -0.27109176 1.2490592 0.24446994]\n", "v^T Jf(x) = [ 0.18784378 -1.2833427 -0.27109176 1.2490592 0.24446994]\n" ] } ], "source": [ "(val, jvp_fun) = jax.vjp(f,x)\n", "print(\"x = \", x)\n", "print(\"v^T Jf(x) = \", jvp_fun(jnp.ones((5,)))[0])" ] }, { "cell_type": "markdown", "metadata": { "id": "2v1Uq_XlzRZS" }, "source": [ "## Accelerating code with jit & ops vectorization\n", "We borrow the following example from the [JAX quickstart](https://jax.readthedocs.io/en/latest/notebooks/quickstart.html).\n", "\n", "---" ] }, { "cell_type": "markdown", "metadata": { "id": "kF04t9L71dhH" }, "source": [ "### Jit\n", "\n", "JAX uses the XLA compiler under the hood, and enables you to jit compile your code to make it faster and more efficient. This is the purpose of the @jit annotation." ] }, { "cell_type": "code", "execution_count": 18, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "D6p_wQ9xeIiu", "outputId": "af7ea5af-5ee1-4aa5-d8d7-8f6a20da2b0e" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "1.96 ms ± 86.5 µs per loop (mean ± std. dev. 
of 7 runs, 100 loops each)\n" ] } ], "source": [ "def selu(x, alpha=1.67, lmbda=1.05):\n", " return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)\n", "\n", "v = random.normal(key, (1000000,))\n", "%timeit selu(v).block_until_ready()" ] }, { "cell_type": "markdown", "metadata": { "id": "Nk9LVX580j6M" }, "source": [ "Now using the jit annotation (or function here) to speed things up:" ] }, { "cell_type": "code", "execution_count": 19, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "us5pWySG0jWL", "outputId": "e8ff3b7b-3917-40fc-8f29-eb9e6df262e5" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "405 µs ± 32.3 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)\n" ] } ], "source": [ "selu_jit = jax.jit(selu)\n", "%timeit selu_jit(v).block_until_ready()" ] }, { "cell_type": "markdown", "metadata": { "id": "6kQyCgo407oF" }, "source": [ "jit compilation can be used along with autodiff in the code transparently.\n", "\n", "---\n", "### Vectorization\n", "\n", "Finally, JAX enables you to write code that applies to a single example, and then vectorize it to manage transparently batching dimensions." ] }, { "cell_type": "code", "execution_count": 20, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "j-E6MsKF0tmZ", "outputId": "bfa377e8-92ee-4473-abd4-8d52338e2cc5" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Single apply shape: (15,)\n", "Batched example shape: (5, 15)\n" ] } ], "source": [ "mat = random.normal(key, (15, 10))\n", "batched_x = random.normal(key, (5, 10)) # Batch size on axis 0\n", "single = random.normal(key, (10,))\n", "\n", "def apply_matrix(v):\n", " return jnp.dot(mat, v)\n", "\n", "print(\"Single apply shape: \", apply_matrix(single).shape)\n", "print(\"Batched example shape: \", jax.vmap(apply_matrix)(batched_x).shape)" ] }, { "cell_type": "markdown", "metadata": { "id": "S2BcA8wm2_FW" }, "source": [ "## Full example: linear regression\n", "\n", "Let's implement one of the simplest models using everything we have seen so far: a linear regression. From a set of data points $\\{(x_i,y_i), i\\in \\{1,\\ldots, k\\}, x_i\\in\\mathbb{R}^n,y_i\\in\\mathbb{R}^m\\}$, we try to find a set of parameters $W\\in \\mathcal{M}_{m,n}(\\mathbb{R}), b\\in\\mathbb{R}^m$ such that the function $f_{W,b}(x)=Wx+b$ minimizes the mean squared error:\n", "\n", "$$\\mathcal{L}(W,b)\\rightarrow\\frac{1}{k}\\sum_{i=1}^{k} \\frac{1}{2}\\|y_i-f_{W,b}(x_i)\\|^2_2$$\n", "\n", "(Note: depending on how you cast the regression problem you might end up with different setups. Theoretically we should be minimizing the expectation of the loss wrt to the data distribution, however for the sake of simplicity here we consider only the sampled loss)." 
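, "\n", "\n", "As a point of reference, this particular problem also has an exact closed-form solution (the normal equations). A self-contained sketch, with its own toy data so it does not interfere with the samples generated below:\n", "\n", "```python\n", "import jax\n", "from jax import numpy as jnp, random\n", "\n", "k1, k2, k3 = random.split(random.PRNGKey(42), 3)\n", "W_true = random.normal(k1, (3, 2))\n", "b_true = random.normal(k2, (2,))\n", "xs = random.normal(k3, (50, 3))\n", "ys = xs @ W_true + b_true  # noiseless targets for this illustration\n", "\n", "# Append a column of ones so the bias is learned as an extra row of W.\n", "xs_aug = jnp.concatenate([xs, jnp.ones((xs.shape[0], 1))], axis=1)\n", "theta = jnp.linalg.solve(xs_aug.T @ xs_aug, xs_aug.T @ ys)\n", "W_ls, b_ls = theta[:-1], theta[-1]\n", "print(jnp.allclose(W_ls, W_true, atol=1e-3), jnp.allclose(b_ls, b_true, atol=1e-3))\n", "```\n", "\n", "Now let's solve it the gradient-descent way, using the tools introduced above."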
] }, { "cell_type": "code", "execution_count": 21, "metadata": { "id": "5W9p_zVe2Cj-" }, "outputs": [], "source": [ "# Linear feed-forward.\n", "def predict(W, b, x):\n", "  return jnp.dot(x, W) + b\n", "\n", "# Loss function: Mean squared error.\n", "def mse(W, b, x_batched, y_batched):\n", "  # Define the squared loss for a single pair (x,y)\n", "  def squared_error(x, y):\n", "    y_pred = predict(W, b, x)\n", "    return jnp.inner(y-y_pred, y-y_pred) / 2.0\n", "  # Vectorize the single-pair loss and average it over all samples.\n", "  return jnp.mean(jax.vmap(squared_error)(x_batched, y_batched), axis=0)" ] },
{ "cell_type": "code", "execution_count": 22, "metadata": { "id": "qMkIxjjsduPY" }, "outputs": [], "source": [ "# Set problem dimensions.\n", "n_samples = 20\n", "x_dim = 10\n", "y_dim = 5\n", "\n", "# Generate random ground truth W and b.\n", "key = random.PRNGKey(0)\n", "k1, k2 = random.split(key)\n", "W = random.normal(k1, (x_dim, y_dim))\n", "b = random.normal(k2, (y_dim,))\n", "\n", "# Generate samples with additional noise.\n", "key_sample, key_noise = random.split(k1)\n", "x_samples = random.normal(key_sample, (n_samples, x_dim))\n", "y_samples = predict(W, b, x_samples) + 0.1 * random.normal(key_noise,(n_samples, y_dim))\n", "print('x shape:', x_samples.shape, '; y shape:', y_samples.shape)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "5L2np6wve_xp", "outputId": "9db5c834-d7da-4291-d1ec-d4c39008d5ed" }, "outputs": [], "source": [ "# Initialize estimated W and b with zeros.\n", "W_hat = jnp.zeros_like(W)\n", "b_hat = jnp.zeros_like(b)\n", "\n", "# Ensure we jit the largest-possible jittable block.\n", "@jax.jit\n", "def update_params(W, b, x, y, lr):\n", "  W, b = W - lr * jax.grad(mse, 0)(W, b, x, y), b - lr * jax.grad(mse, 1)(W, b, x, y)\n", "  return W, b\n", "\n", "learning_rate = 0.3 # Gradient step size.\n", "print('Loss for \"true\" W,b: ', mse(W, b, x_samples, y_samples))\n", "for i in range(101):\n", "  # Perform one gradient update.\n", "  W_hat, b_hat = update_params(W_hat, b_hat, x_samples, y_samples, learning_rate)\n", "  if (i % 5 == 0):\n", "    print(f\"Loss step {i}: \", mse(W_hat, b_hat, x_samples, y_samples))" ] },
{ "cell_type": "markdown", "metadata": { "id": "bJGKunxNzrxa" }, "source": [ "This is an iterative, approximate solution to the linear regression problem (which also admits an exact closed-form solution), but it shows all the tools you would need if you wanted to do it the proper way." ] },
{ "cell_type": "markdown", "metadata": { "id": "bQXmL86aUS9x" }, "source": [ "## Refining a bit with pytrees\n", "\n", "Here we're going to elaborate on our previous example using the JAX pytree data structure." ] },
{ "cell_type": "markdown", "metadata": { "id": "zZMUvyCgUzby" }, "source": [ "### Pytrees basics\n", "\n", "The JAX ecosystem uses pytrees everywhere, and Flax does too (for example, Flax's `FrozenDict` parameter containers are pytrees; we'll come back to this). 
For a complete overview, we suggest that you take a look at the [pytree page](https://jax.readthedocs.io/en/latest/pytrees.html) from JAX's doc:\n", "\n", "*In JAX, a pytree is a container of leaf elements and/or more pytrees. Containers include lists, tuples, and dicts (JAX can be extended to consider other container types as pytrees, see Extending pytrees below). A leaf element is anything that’s not a pytree, e.g. an array. In other words, a pytree is just a possibly-nested standard or user-registered Python container. If nested, note that the container types do not need to match. A single “leaf”, i.e. a non-container object, is also considered a pytree.*\n", "\n", "```python\n", "[1, \"a\", object()] # 3 leaves: 1, \"a\" and object()\n", "\n", "(1, (2, 3), ()) # 3 leaves: 1, 2 and 3\n", "\n", "[1, {\"k1\": 2, \"k2\": (3, 4)}, 5] # 5 leaves: 1, 2, 3, 4, 5\n", "```\n", "\n", "JAX provides a few utilities to work with pytrees; they live in the `tree_util` package." ] },
{ "cell_type": "code", "execution_count": 24, "metadata": { "id": "9SNY5eA1UdkJ" }, "outputs": [], "source": [ "from jax import tree_util\n", "\n", "t = [1, {\"k1\": 2, \"k2\": (3, 4)}, 5]" ] },
{ "cell_type": "markdown", "metadata": { "id": "LujWjwVQUeea" }, "source": [ "You will often come across the `tree_map` function, which maps a function `f` over the leaves of a tree; it is especially handy for inspecting or updating all of a model's parameters at once." ] },
{ "cell_type": "code", "execution_count": 25, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "szDhssVBUjTa", "outputId": "9ae4ebf1-a3c4-4ecb-b3df-67c8450310f8" }, "outputs": [ { "data": { "text/plain": [ "[1, {'k1': 4, 'k2': (9, 16)}, 25]" ] }, "execution_count": 25, "metadata": {}, "output_type": "execute_result" } ], "source": [ "tree_util.tree_map(lambda x: x*x, t)" ] },
{ "cell_type": "markdown", "metadata": { "id": "3s167WGKUlZ9" }, "source": [ "Instead of applying a function to the leaves of a single tree, you can also pass additional trees with the same structure as the input tree; their leaves are then supplied as extra per-leaf arguments to the function." ] },
{ "cell_type": "code", "execution_count": 26, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "bNOYK_E7UnOh", "outputId": "d211bf85-5993-488c-9fec-aeaf375df007" }, "outputs": [ { "data": { "text/plain": [ "[2, {'k1': 6, 'k2': (12, 20)}, 30]" ] }, "execution_count": 26, "metadata": {}, "output_type": "execute_result" } ], "source": [ "t2 = tree_util.tree_map(lambda x: x*x, t)\n", "tree_util.tree_map(lambda x,y: x+y, t, t2)" ] },
{ "cell_type": "markdown", "metadata": { "id": "HnE75pvlVDO5" }, "source": [ "### Linear regression with Pytrees\n", "\n", "Our previous example works, but as models get more complicated (as they will with neural networks), keeping track of each parameter variable by hand quickly becomes cumbersome.\n", "\n", "Here we show an alternative based on pytrees, using the same data as in the previous example.\nNow, our `params` is a pytree containing both the `W` and `b` entries."
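, "\n", "\n", "Pytrees are what make it convenient to treat a whole set of parameters as a single object. For instance, you can inspect the shape of every leaf of a nested parameter tree in one call (a small sketch; the nested layer names are made up for illustration):\n", "\n", "```python\n", "import jax\n", "from jax import numpy as jnp\n", "\n", "toy_params = {\n", "    'dense1': {'W': jnp.zeros((10, 5)), 'b': jnp.zeros((5,))},\n", "    'dense2': {'W': jnp.zeros((5, 1)), 'b': jnp.zeros((1,))},\n", "}\n", "print(jax.tree_util.tree_map(lambda p: p.shape, toy_params))\n", "print(len(jax.tree_util.tree_leaves(toy_params)))  # 4 leaf arrays in total\n", "```"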
] }, { "cell_type": "code", "execution_count": 27, "metadata": { "id": "8v8gNkvUVZnl" }, "outputs": [], "source": [ "# Linear feed-forward that takes a params pytree.\n", "def predict_pytree(params, x):\n", " return jnp.dot(x, params['W']) + params['b']\n", "\n", "# Loss function: Mean squared error.\n", "def mse_pytree(params, x_batched,y_batched):\n", " # Define the squared loss for a single pair (x,y)\n", " def squared_error(x,y):\n", " y_pred = predict_pytree(params, x)\n", " return jnp.inner(y-y_pred, y-y_pred) / 2.0\n", " # We vectorize the previous to compute the average of the loss on all samples.\n", " return jnp.mean(jax.vmap(squared_error)(x_batched, y_batched), axis=0)\n", "\n", "# Initialize estimated W and b with zeros. Store in a pytree.\n", "params = {'W': jnp.zeros_like(W), 'b': jnp.zeros_like(b)}" ] }, { "cell_type": "markdown", "metadata": { "id": "rKP0X8rnWAiA" }, "source": [ "The great thing is that JAX is able to handle differentiation with respect to pytree parameters:" ] }, { "cell_type": "code", "execution_count": 28, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "8zc7cMaiWSny", "outputId": "a69605cb-1eed-4f81-fc2e-93646c9694dd" }, "outputs": [ { "data": { "text/plain": [ "{'W': DeviceArray([[-1.9287349e+00, 4.2963755e-01, 7.1613449e-01,\n", " 2.1056123e+00, 5.0405121e-01, -2.4983375e+00,\n", " -6.3854176e-01, -2.2620213e+00, -1.3365206e+00,\n", " -2.0426039e-01],\n", " [ 1.1999468e+00, -9.4563609e-01, -1.0878400e+00,\n", " -7.0340711e-01, 3.3224609e-01, 1.7538791e+00,\n", " -7.1916544e-01, 1.0927428e+00, -1.4491037e+00,\n", " 5.9715635e-01],\n", " [-1.4826509e+00, -7.6116532e-01, 2.2319858e-01,\n", " -3.0391946e-01, 3.0397055e+00, -3.8419428e-01,\n", " -1.8290073e+00, -2.3353369e+00, -1.1087127e+00,\n", " -7.7453995e-01],\n", " [ 8.2374442e-01, -9.9650609e-01, -7.6030111e-01,\n", " 6.3919222e-01, -6.0864899e-02, -1.0859716e+00,\n", " 1.2923398e+00, -4.9342898e-01, -1.4711156e-03,\n", " 1.2977618e+00],\n", " [-4.5656446e-01, -1.3063025e-01, -3.9179009e-01,\n", " 2.1743817e+00, -5.3948693e-02, 4.5653123e-01,\n", " -8.5279423e-01, 1.1709594e+00, 9.6438813e-01,\n", " -2.3813749e-02]], dtype=float32),\n", " 'b': DeviceArray([ 1.0923628, 1.3121076, -2.9304824, -0.6492362, 1.1531248], dtype=float32)}" ] }, "execution_count": 28, "metadata": {}, "output_type": "execute_result" } ], "source": [ "jax.grad(mse_pytree)(params, x_samples, y_samples)" ] }, { "cell_type": "markdown", "metadata": { "id": "nW1IKnjqXFdN" }, "source": [ "Now using our tree of params, we can write the gradient descent in a simpler way using `jax.tree_util.tree_map`:" ] }, { "cell_type": "code", "execution_count": 29, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "jEntdcDBXBCj", "outputId": "f309aff7-2aad-453f-ad88-019d967d4289" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Loss for \"true\" W,b: 0.023639774\n", "Loss step 0: 11.096583\n", "Loss step 5: 1.1743388\n", "Loss step 10: 0.32879353\n", "Loss step 15: 0.1398177\n", "Loss step 20: 0.07359565\n", "Loss step 25: 0.04415301\n", "Loss step 30: 0.029408678\n", "Loss step 35: 0.021554656\n", "Loss step 40: 0.017227933\n", "Loss step 45: 0.014798875\n", "Loss step 50: 0.013420242\n", "Loss step 55: 0.0126327025\n", "Loss step 60: 0.0121810865\n", "Loss step 65: 0.011921468\n", "Loss step 70: 0.011771992\n", "Loss step 75: 0.011685831\n", "Loss step 80: 0.011636148\n", "Loss step 85: 0.011607475\n", "Loss step 90: 0.011590928\n", "Loss step 95: 0.011581394\n", "Loss step 
100: 0.011575883\n" ] } ], "source": [ "# Always remember to jit!\n", "@jax.jit\n", "def update_params_pytree(params, learning_rate, x_samples, y_samples):\n", " params = jax.tree_util.tree_map(\n", " lambda p, g: p - learning_rate * g, params,\n", " jax.grad(mse_pytree)(params, x_samples, y_samples))\n", " return params\n", "\n", "learning_rate = 0.3 # Gradient step size.\n", "print('Loss for \"true\" W,b: ', mse_pytree({'W': W, 'b': b}, x_samples, y_samples))\n", "for i in range(101):\n", " # Perform one gradient update.\n", " params = update_params_pytree(params, learning_rate, x_samples, y_samples)\n", " if (i % 5 == 0):\n", " print(f\"Loss step {i}: \", mse_pytree(params, x_samples, y_samples))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Besides `jax.grad()`, another useful function is `jax.value_and_grad()`, which returns the value of the input function and of its gradient.\n", "\n", "To switch from `jax.grad()` to `jax.value_and_grad()`, replace the training loop above with the following:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "# Using jax.value_and_grad instead:\n", "loss_grad_fn = jax.value_and_grad(mse_pytree)\n", "for i in range(101):\n", " # Note that here the loss is computed before the param update.\n", " loss_val, grads = loss_grad_fn(params, x_samples, y_samples)\n", " params = jax.tree_util.tree_map(\n", " lambda p, g: p - learning_rate * g, params, grads)\n", " if (i % 5 == 0):\n", " print(f\"Loss step {i}: \", loss_val)" ] }, { "cell_type": "markdown", "metadata": { "id": "Xh-oo8jFUPNQ" }, "source": [ "That's all you needed to know to get started with Flax! To dive deeper, we very much recommend checking the JAX [docs](https://jax.readthedocs.io/en/latest/index.html)." ] } ], "metadata": { "colab": { "collapsed_sections": [], "name": "JAX for the impatient.ipynb", "provenance": [], "toc_visible": true }, "jupytext": { "formats": "ipynb,md:myst", "main_language": "python" }, "kernelspec": { "display_name": "Python 3", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 0 }