{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "C1QVJFlVsxcZ" }, "source": [ "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/flax/blob/main/docs/notebooks/linen_intro.ipynb)\n", "[![Open On GitHub](https://img.shields.io/badge/Open-on%20GitHub-blue?logo=GitHub)](https://github.com/google/flax/blob/main/docs/notebooks/linen_intro.ipynb)\n", "\n", "# Preface\n", "\n", "
\n", "
CAVEAT PROGRAMMER
\n", "\n", "The below is an alpha API preview and things might break. The surface syntax of the features of the API are not fixed in stone, and we welcome feedback on any points." ] }, { "cell_type": "markdown", "metadata": { "id": "23zkGDayszYI" }, "source": [ "## Useful links\n", "\n", "⟶ [Slides](https://docs.google.com/presentation/d/1ngKWUwsSqAwPRvATG8sAxMzu9ujv4N__cKsUofdNno0/edit?usp=sharing) for the core ideas of the new Functional Core and Linen\n", "\n", "⟶ \"Design tests\" guided our design process. Many are available for [functional core](https://github.com/google/flax/tree/main/examples/core_design_test) and some for the [proposed Module abstraction](https://github.com/google/flax/tree/main/examples/linen_design_test/)\n", "\n", "⟶ Ported examples: [ImageNet](https://github.com/google/flax/tree/main/examples/imagenet) and [WMT](https://github.com/google/flax/tree/main/examples/wmt) (to the proposed Module abstraction). TODO: Port to functional core.\n", "\n", "⟶ Our new [discussion forums](https://github.com/google/flax/discussions/)" ] }, { "cell_type": "markdown", "metadata": { "id": "vGtC_5W4mQnY" }, "source": [ "# Install and Import" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "HgRZ_G8wGcoB", "tags": [ "skip-execution" ] }, "outputs": [], "source": [ "# Install the newest JAXlib version.\n", "!pip install --upgrade -q pip jax jaxlib\n", "# Install Flax at head:\n", "!pip install --upgrade -q git+https://github.com/google/flax.git" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "id": "Kvx7GmavHZbD" }, "outputs": [], "source": [ "import functools\n", "from typing import Any, Callable, Sequence, Optional\n", "import jax\n", "from jax import lax, random, numpy as jnp\n", "from flax.core import freeze, unfreeze\n", "from flax import linen as nn" ] }, { "cell_type": "markdown", "metadata": { "id": "u86fYsrEfYow" }, "source": [ "# Invoking Modules" ] }, { "cell_type": "markdown", "metadata": { "id": "nrVbFrh1ffve" }, "source": [ "Let's instantiate a `Dense` layer.\n", " - Modules are actually objects in this API, so we provide _contructor arguments_ when initializing the Module. In this case, we only have to provide the output `features` dimension." ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "id": "EcDH20Uufc-v" }, "outputs": [], "source": [ "model = nn.Dense(features=3)" ] }, { "cell_type": "markdown", "metadata": { "id": "hL4NgtBwgI0S" }, "source": [ "We need to initialize the Module variables, these include the parameters of the Module as well as any other state variables.\n", "\n", "We call the `init` method on the instantiated Module. If the Module `__call__` method has args `(self, *args, **kwargs)` then we call `init` with `(rngs, *args, **kwargs)` so in this case, just `(rng, input)`:" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "Vjx0HWNcfa8h", "outputId": "3adfaeaf-977e-4e82-8adf-d254fae6eb91" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "WARNING:absl:No GPU/TPU found, falling back to CPU. 
(Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n" ] }, { "data": { "text/plain": [ "FrozenDict({\n", " params: {\n", " kernel: DeviceArray([[ 0.6503669 , 0.8678979 , 0.46042678],\n", " [ 0.05673932, 0.9909285 , -0.63536596],\n", " [ 0.76134115, -0.3250529 , -0.6522163 ],\n", " [-0.8243032 , 0.4150194 , 0.19405058]], dtype=float32),\n", " bias: DeviceArray([0., 0., 0.], dtype=float32),\n", " },\n", "})" ] }, "execution_count": 4, "metadata": { "tags": [] }, "output_type": "execute_result" } ], "source": [ "# Make RNG Keys and a fake input.\n", "key1, key2 = random.split(random.PRNGKey(0), 2)\n", "x = random.uniform(key1, (4,4))\n", "\n", "# provide key and fake input to get initialized variables\n", "init_variables = model.init(key2, x)\n", "\n", "init_variables" ] }, { "cell_type": "markdown", "metadata": { "id": "ubFTzroGhErh" }, "source": [ "We call the `apply` method on the instantiated Module. If the Module `__call__` method has args `(self, *args, **kwargs)` then we call `apply` with `(variables, *args, rngs=, mutable=, **kwargs)` where\n", " - `` are the optional _call time_ RNGs for things like dropout. For simple Modules this is just a single key, but if your module has multiple __kinds__ of data, it's a dictionary of rng-keys per-kind, e.g. `{'params': key0, 'dropout': key1}` for a Module with dropout layers.\n", " - `` is an optional list of names of __kinds__ that are expected to be mutated during the call. e.g. `['batch_stats']` for a layer updating batchnorm statistics.\n", "\n", "So in this case, just `(variables, input)`:" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "R9QZ6EOBg5X8", "outputId": "e8c389a6-29f3-4f93-97ea-703e85a8b811" }, "outputs": [ { "data": { "text/plain": [ "DeviceArray([[ 0.5035518 , 1.8548559 , -0.4270196 ],\n", " [ 0.0279097 , 0.5589246 , -0.43061775],\n", " [ 0.35471284, 1.5741 , -0.3286552 ],\n", " [ 0.5264864 , 1.2928858 , 0.10089308]], dtype=float32)" ] }, "execution_count": 5, "metadata": { "tags": [] }, "output_type": "execute_result" } ], "source": [ "y = model.apply(init_variables, x)\n", "y" ] }, { "cell_type": "markdown", "metadata": { "id": "lNH06qc1hPrd" }, "source": [ "Additional points:\n", " - If you want to `init` or `apply` a Module using a method other than call, you need to provide the `method=` kwarg to `init` and `apply` to use it instead of the default `__call__`, e.g. `method='encode'`, `method='decode'` to apply the encode/decode methods of an autoencoder." ] }, { "cell_type": "markdown", "metadata": { "id": "jjsyiBjIYcAB" }, "source": [ "# Defining Basic Modules" ] }, { "cell_type": "markdown", "metadata": { "id": "UvU7416Ti_lR" }, "source": [ "## Composing submodules" ] }, { "cell_type": "markdown", "metadata": { "id": "LkTy0hmJdE5G" }, "source": [ "We support declaring modules in `setup()` that can still benefit from shape inference by using __Lazy Initialization__ that sets up variables the first time the Module is called." 
] }, { "cell_type": "code", "execution_count": 6, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "qB6l-9EabOwH", "outputId": "1a6c6a17-0b95-42c2-b5bf-b9ad80fd7758", "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "initialized parameter shapes:\n", " {'params': {'layers_0': {'bias': (3,), 'kernel': (4, 3)}, 'layers_1': {'bias': (4,), 'kernel': (3, 4)}, 'layers_2': {'bias': (5,), 'kernel': (4, 5)}}}\n", "output:\n", " [[ 4.2292815e-02 -4.3807115e-02 2.9323792e-02 6.5492536e-03\n", " -1.7147182e-02]\n", " [ 1.2967804e-01 -1.4551792e-01 9.4432175e-02 1.2521386e-02\n", " -4.5417294e-02]\n", " [ 0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00\n", " 0.0000000e+00]\n", " [ 9.3024090e-04 2.7864411e-05 2.4478839e-04 8.1344356e-04\n", " -1.0110775e-03]]\n" ] } ], "source": [ "class ExplicitMLP(nn.Module):\n", " features: Sequence[int]\n", "\n", " def setup(self):\n", " # we automatically know what to do with lists, dicts of submodules\n", " self.layers = [nn.Dense(feat) for feat in self.features]\n", " # for single submodules, we would just write:\n", " # self.layer1 = nn.Dense(feat1)\n", "\n", " def __call__(self, inputs):\n", " x = inputs\n", " for i, lyr in enumerate(self.layers):\n", " x = lyr(x)\n", " if i != len(self.layers) - 1:\n", " x = nn.relu(x)\n", " return x\n", "\n", "key1, key2 = random.split(random.PRNGKey(0), 2)\n", "x = random.uniform(key1, (4,4))\n", "\n", "model = ExplicitMLP(features=[3,4,5])\n", "init_variables = model.init(key2, x)\n", "y = model.apply(init_variables, x)\n", "\n", "print('initialized parameter shapes:\\n', jax.tree_util.tree_map(jnp.shape, unfreeze(init_variables)))\n", "print('output:\\n', y)" ] }, { "cell_type": "markdown", "metadata": { "id": "slwE6ULqc_t_" }, "source": [ "Here we show the equivalent compact form of the MLP that declares the submodules inline using the `@compact` decorator." 
] }, { "cell_type": "code", "execution_count": 7, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "UPNGIr6wcGaw", "outputId": "b3709789-e66e-4e20-f6b2-04022f8a62bb", "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "initialized parameter shapes:\n", " {'params': {'layers_0': {'bias': (3,), 'kernel': (4, 3)}, 'layers_1': {'bias': (4,), 'kernel': (3, 4)}, 'layers_2': {'bias': (5,), 'kernel': (4, 5)}}}\n", "output:\n", " [[ 4.2292815e-02 -4.3807115e-02 2.9323792e-02 6.5492536e-03\n", " -1.7147182e-02]\n", " [ 1.2967804e-01 -1.4551792e-01 9.4432175e-02 1.2521386e-02\n", " -4.5417294e-02]\n", " [ 0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00\n", " 0.0000000e+00]\n", " [ 9.3024090e-04 2.7864411e-05 2.4478839e-04 8.1344356e-04\n", " -1.0110775e-03]]\n" ] } ], "source": [ "class SimpleMLP(nn.Module):\n", " features: Sequence[int]\n", "\n", " @nn.compact\n", " def __call__(self, inputs):\n", " x = inputs\n", " for i, feat in enumerate(self.features):\n", " x = nn.Dense(feat, name=f'layers_{i}')(x)\n", " if i != len(self.features) - 1:\n", " x = nn.relu(x)\n", " # providing a name is optional though!\n", " # the default autonames would be \"Dense_0\", \"Dense_1\", ...\n", " # x = nn.Dense(feat)(x)\n", " return x\n", "\n", "key1, key2 = random.split(random.PRNGKey(0), 2)\n", "x = random.uniform(key1, (4,4))\n", "\n", "model = SimpleMLP(features=[3,4,5])\n", "init_variables = model.init(key2, x)\n", "y = model.apply(init_variables, x)\n", "\n", "print('initialized parameter shapes:\\n', jax.tree_util.tree_map(jnp.shape, unfreeze(init_variables)))\n", "print('output:\\n', y)" ] }, { "cell_type": "markdown", "metadata": { "id": "b2OzKXYyjFSf" }, "source": [ "## Declaring and using variables" ] }, { "cell_type": "markdown", "metadata": { "id": "uYwS5KbcmYIp" }, "source": [ "Flax uses lazy initialization, which allows declared variables to be initialized only at the first site of their use, using whatever shape information is available a the local call site for shape inference. Once a variable has been initialized, a reference to the data is kept for use in subsequent calls.\n", "\n", "For declaring parameters that aren't mutated inside the model, but rather by gradient descent, we use the syntax:\n", "\n", " `self.param(parameter_name, parameter_init_fn, *init_args)`\n", "\n", "with arguments:\n", " - `parameter_name` just the name, a string\n", " - `parameter_init_fn` a function taking an RNG key and a variable number of other arguments, i.e. `fn(rng, *args)`. typically those in `nn.initializers` take an `rng` and a `shape` argument.\n", " - the remaining arguments to feed to the init function when initializing.\n", "\n", "Again, we'll demonstrate declaring things inline as we typically do using the `@compact` decorator." 
] }, { "cell_type": "code", "execution_count": 8, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "7OACbTFHjMvl", "outputId": "bc5cb1f2-c5e9-4159-d131-73247009e32f", "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "initialized parameters:\n", " FrozenDict({\n", " params: {\n", " kernel: DeviceArray([[ 0.6503669 , 0.8678979 , 0.46042678],\n", " [ 0.05673932, 0.9909285 , -0.63536596],\n", " [ 0.76134115, -0.3250529 , -0.6522163 ],\n", " [-0.8243032 , 0.4150194 , 0.19405058]], dtype=float32),\n", " bias: DeviceArray([0., 0., 0.], dtype=float32),\n", " },\n", "})\n", "output:\n", " [[ 0.5035518 1.8548559 -0.4270196 ]\n", " [ 0.0279097 0.5589246 -0.43061775]\n", " [ 0.35471284 1.5741 -0.3286552 ]\n", " [ 0.5264864 1.2928858 0.10089308]]\n" ] } ], "source": [ "class SimpleDense(nn.Module):\n", " features: int\n", " kernel_init: Callable = nn.initializers.lecun_normal()\n", " bias_init: Callable = nn.initializers.zeros_init()\n", "\n", " @nn.compact\n", " def __call__(self, inputs):\n", " kernel = self.param('kernel',\n", " self.kernel_init, # RNG passed implicitly.\n", " (inputs.shape[-1], self.features)) # shape info.\n", " y = lax.dot_general(inputs, kernel,\n", " (((inputs.ndim - 1,), (0,)), ((), ())),)\n", " bias = self.param('bias', self.bias_init, (self.features,))\n", " y = y + bias\n", " return y\n", "\n", "key1, key2 = random.split(random.PRNGKey(0), 2)\n", "x = random.uniform(key1, (4,4))\n", "\n", "model = SimpleDense(features=3)\n", "init_variables = model.init(key2, x)\n", "y = model.apply(init_variables, x)\n", "\n", "print('initialized parameters:\\n', init_variables)\n", "print('output:\\n', y)" ] }, { "cell_type": "markdown", "metadata": { "id": "KgEwkrkfdlt8" }, "source": [ "We can also declare variables in setup, though in doing so you can't take advantage of shape inference and have to provide explicit shape information at initialization. The syntax is a little repetitive in this case right now, but we do force agreement of the assigned names." 
] }, { "cell_type": "code", "execution_count": 9, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "CE0CTLVvZ8Yn", "outputId": "1e822bd8-7a08-4e80-e0e6-a86637c46772", "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "initialized parameters:\n", " FrozenDict({\n", " params: {\n", " kernel: DeviceArray([[ 0.6503669 , 0.8678979 , 0.46042678],\n", " [ 0.05673932, 0.9909285 , -0.63536596],\n", " [ 0.76134115, -0.3250529 , -0.6522163 ],\n", " [-0.8243032 , 0.4150194 , 0.19405058]], dtype=float32),\n", " bias: DeviceArray([0., 0., 0.], dtype=float32),\n", " },\n", "})\n", "output:\n", " [[ 0.5035518 1.8548559 -0.4270196 ]\n", " [ 0.0279097 0.5589246 -0.43061775]\n", " [ 0.35471284 1.5741 -0.3286552 ]\n", " [ 0.5264864 1.2928858 0.10089308]]\n" ] } ], "source": [ "class ExplicitDense(nn.Module):\n", " features_in: int # <-- explicit input shape\n", " features: int\n", " kernel_init: Callable = nn.initializers.lecun_normal()\n", " bias_init: Callable = nn.initializers.zeros_init()\n", "\n", " def setup(self):\n", " self.kernel = self.param('kernel',\n", " self.kernel_init,\n", " (self.features_in, self.features))\n", " self.bias = self.param('bias', self.bias_init, (self.features,))\n", "\n", " def __call__(self, inputs):\n", " y = lax.dot_general(inputs, self.kernel,\n", " (((inputs.ndim - 1,), (0,)), ((), ())),)\n", " y = y + self.bias\n", " return y\n", "\n", "key1, key2 = random.split(random.PRNGKey(0), 2)\n", "x = random.uniform(key1, (4,4))\n", "\n", "model = ExplicitDense(features_in=4, features=3)\n", "init_variables = model.init(key2, x)\n", "y = model.apply(init_variables, x)\n", "\n", "print('initialized parameters:\\n', init_variables)\n", "print('output:\\n', y)" ] }, { "cell_type": "markdown", "metadata": { "id": "t4MVj1RBmxsZ" }, "source": [ "## General Variables" ] }, { "cell_type": "markdown", "metadata": { "id": "CJatarOTpByQ" }, "source": [ "For declaring generally mutable _variables_ that may be mutated inside the model we use the call:\n", "\n", " `self.variable(variable_kind, variable_name, variable_init_fn, *init_args)`\n", "\n", "with arguments:\n", " - `variable_kind` the \"kind\" of state this variable is, i.e. the name of the nested-dict collection that this will be stored in inside the top Modules variables. e.g. `batch_stats` for the moving statistics for a batch norm layer or `cache` for autoregressive cache data. Note that parameters also have a kind, but they're set to the default `param` kind.\n", " - `variable_name` just the name, a string\n", " - `variable_init_fn` a function taking a variable number of other arguments, i.e. `fn(*args)`. Note that we __don't__ assume the need for an RNG, if you _do_ want an RNG, provide it via a `self.make_rng(variable_kind)` call in the provided arguments.\n", " - the remaining arguments to feed to the init function when initializing.\n", "\n", "⚠️ Unlike parameters, we expect these to be mutated, so `self.variable` returns not a constant, but a _reference_ to the variable. To __get__ the raw value, you'd write `myvariable.value` and to __set__ it `myvariable.value = new_value`." 
] }, { "cell_type": "code", "execution_count": 10, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "u6_fbrW2XT5t", "outputId": "2a8f5453-81b1-44dc-a431-d14b372c5710", "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "initialized variables:\n", " FrozenDict({\n", " counter: {\n", " count: DeviceArray(0, dtype=int32),\n", " },\n", "})\n", "mutated variables:\n", " FrozenDict({\n", " counter: {\n", " count: DeviceArray(1, dtype=int32),\n", " },\n", "})\n", "output:\n", " 1\n" ] } ], "source": [ "class Counter(nn.Module):\n", " @nn.compact\n", " def __call__(self):\n", " # easy pattern to detect if we're initializing\n", " is_initialized = self.has_variable('counter', 'count')\n", " counter = self.variable('counter', 'count', lambda: jnp.zeros((), jnp.int32))\n", " if is_initialized:\n", " counter.value += 1\n", " return counter.value\n", "\n", "\n", "key1 = random.PRNGKey(0)\n", "\n", "model = Counter()\n", "init_variables = model.init(key1)\n", "print('initialized variables:\\n', init_variables)\n", "\n", "y, mutated_variables = model.apply(init_variables, mutable=['counter'])\n", "\n", "print('mutated variables:\\n', mutated_variables)\n", "print('output:\\n', y)" ] }, { "cell_type": "markdown", "metadata": { "id": "VLxwg2aMxUmy" }, "source": [ "## Another Mutability and RNGs Example" ] }, { "cell_type": "markdown", "metadata": { "id": "NOARPIowyeXS" }, "source": [ "Let's make an artificial, goofy example that mixes differentiable parameters, stochastic layers, and mutable variables:" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "BBrbcEdCnQ4o", "outputId": "8f299a5c-74c8-476c-93fa-e5543901ec45", "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "updated variables:\n", " FrozenDict({\n", " params: {\n", " Dense_0: {\n", " kernel: DeviceArray([[ 0.6498898 , -0.5000124 , 0.78573596],\n", " [-0.25609785, -0.7132329 , 0.2500864 ],\n", " [-0.64630085, 0.39321756, -1.0203307 ],\n", " [ 0.38721725, 0.86828285, 0.10860055]], dtype=float32),\n", " bias: DeviceArray([0., 0., 0.], dtype=float32),\n", " },\n", " BatchNorm_0: {\n", " scale: DeviceArray([1., 1., 1.], dtype=float32),\n", " bias: DeviceArray([0., 0., 0.], dtype=float32),\n", " },\n", " },\n", " batch_stats: {\n", " BatchNorm_0: {\n", " mean: DeviceArray([ 0.00059601, -0.00103457, 0.00166948], dtype=float32),\n", " var: DeviceArray([0.9907686, 0.9923046, 0.992195 ], dtype=float32),\n", " },\n", " },\n", "})\n", "initialized variable shapes:\n", " FrozenDict({\n", " batch_stats: {\n", " BatchNorm_0: {\n", " mean: (3,),\n", " var: (3,),\n", " },\n", " },\n", " params: {\n", " BatchNorm_0: {\n", " bias: (3,),\n", " scale: (3,),\n", " },\n", " Dense_0: {\n", " bias: (3,),\n", " kernel: (4, 3),\n", " },\n", " },\n", "})\n", "output:\n", " [[[-0.21496922 0.21550177 -0.35633382]\n", " [-0.21496922 -2.0458 1.3015485 ]\n", " [-0.21496922 -0.925116 -0.35633382]\n", " [-0.6595459 0.21550177 0.3749205 ]]\n", "\n", " [[-0.21496922 1.642865 -0.35633382]\n", " [-0.21496922 1.3094063 -0.88034123]\n", " [ 2.5726683 0.21550177 0.34353197]\n", " [-0.21496922 0.21550177 1.6778195 ]]\n", "\n", " [[-1.6060593 0.21550177 -1.9460517 ]\n", " [ 1.4126908 -1.4898677 1.2790381 ]\n", " [-0.21496922 0.21550177 -0.35633382]\n", " [-0.21496922 0.21550177 -0.7251308 ]]]\n", "eval output:\n", " [[[ 3.2246590e-01 2.6108384e-02 4.4821960e-01]\n", " [ 8.5726947e-02 -5.4385906e-01 3.8821870e-01]\n", " 
[-2.3933809e-01 -2.7381191e-01 -1.7526165e-01]\n", " [-6.2515378e-02 -5.2414006e-01 1.7029770e-01]]\n", "\n", " [[ 1.5014435e-01 3.4498507e-01 -1.3554120e-01]\n", " [-3.6971044e-04 2.6463276e-01 -1.2491019e-01]\n", " [ 3.8763803e-01 2.9023719e-01 1.6291586e-01]\n", " [ 4.1320035e-01 4.1468274e-02 4.7670874e-01]]\n", "\n", " [[-1.9433719e-01 5.2831882e-01 -3.7554008e-01]\n", " [ 2.2608691e-01 -4.0989807e-01 3.8292480e-01]\n", " [-2.4945706e-01 1.6170470e-01 -2.5247774e-01]\n", " [-7.2220474e-02 1.2077977e-01 -8.8408351e-02]]]\n" ] } ], "source": [ "class Block(nn.Module):\n", " features: int\n", " training: bool\n", " @nn.compact\n", " def __call__(self, inputs):\n", " x = nn.Dense(self.features)(inputs)\n", " x = nn.Dropout(rate=0.5)(x, deterministic=not self.training)\n", " x = nn.BatchNorm(use_running_average=not self.training)(x)\n", " return x\n", "\n", "key1, key2, key3, key4 = random.split(random.PRNGKey(0), 4)\n", "x = random.uniform(key1, (3,4,4))\n", "\n", "model = Block(features=3, training=True)\n", "\n", "init_variables = model.init({'params': key2, 'dropout': key3}, x)\n", "_, init_params = init_variables.pop('params')\n", "\n", "# When calling `apply` with mutable kinds, returns a pair of output,\n", "# mutated_variables.\n", "y, mutated_variables = model.apply(\n", " init_variables, x, rngs={'dropout': key4}, mutable=['batch_stats'])\n", "\n", "# Now we reassemble the full variables from the updates (in a real training\n", "# loop, with the updated params from an optimizer).\n", "updated_variables = freeze(dict(params=init_params,\n", " **mutated_variables))\n", "\n", "print('updated variables:\\n', updated_variables)\n", "print('initialized variable shapes:\\n',\n", " jax.tree_util.tree_map(jnp.shape, init_variables))\n", "print('output:\\n', y)\n", "\n", "# Let's run these model variables during \"evaluation\":\n", "eval_model = Block(features=3, training=False)\n", "y = eval_model.apply(updated_variables, x) # Nothing mutable; single return value.\n", "print('eval output:\\n', y)" ] }, { "cell_type": "markdown", "metadata": { "id": "Lcp28h72810L" }, "source": [ "# JAX transformations inside modules" ] }, { "cell_type": "markdown", "metadata": { "id": "WEpbn8si0ATT" }, "source": [ "## JIT" ] }, { "cell_type": "markdown", "metadata": { "id": "-k-5gXTJ0EpD" }, "source": [ "It's not immediately clear what use this has, but you can compile specific submodules if there's a reason to.\n", "\n", "_Known Gotcha_: at the moment, the decorator changes the RNG stream slightly, so comparing jitted an unjitted initializations will look different." 
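] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The example below applies `nn.jit` to submodules inline inside `__call__`; the same wrapping also works at top level on any existing Module class. A minimal sketch (the names here are just illustrative):" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"# Sketch: wrap an existing Module class with nn.jit and use the result\n", "# like any other Module.\n", "JittedDense = nn.jit(nn.Dense)\n", "\n", "jit_key1, jit_key2 = random.split(random.PRNGKey(0), 2)\n", "jit_x = random.uniform(jit_key1, (4, 4))\n", "\n", "jit_model = JittedDense(features=3)\n", "jit_variables = jit_model.init(jit_key2, jit_x)\n", "print('output:\\n', jit_model.apply(jit_variables, jit_x))"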
] }, { "cell_type": "code", "execution_count": 12, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "UEUTO8bf0Kf2", "outputId": "3f324d0f-259f-40f0-8273-103f7fc281c5", "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "initialized parameter shapes:\n", " {'params': {'layers_0': {'bias': (3,), 'kernel': (4, 3)}, 'layers_1': {'bias': (4,), 'kernel': (3, 4)}, 'layers_2': {'bias': (5,), 'kernel': (4, 5)}}}\n", "output:\n", " [[ 0.2524199 0.11621253 0.5246693 0.19144788 0.2096542 ]\n", " [ 0.08557513 -0.04126885 0.2502836 0.03910369 0.16575359]\n", " [ 0.2804383 0.27751124 0.44969672 0.26016283 0.05875347]\n", " [ 0.2440843 0.17069656 0.45499086 0.20377949 0.13428023]]\n" ] } ], "source": [ "class MLP(nn.Module):\n", " features: Sequence[int]\n", "\n", " @nn.compact\n", " def __call__(self, inputs):\n", " x = inputs\n", " for i, feat in enumerate(self.features):\n", " # JIT the Module (it's __call__ fn by default.)\n", " x = nn.jit(nn.Dense)(feat, name=f'layers_{i}')(x)\n", " if i != len(self.features) - 1:\n", " x = nn.relu(x)\n", " return x\n", "\n", "key1, key2 = random.split(random.PRNGKey(3), 2)\n", "x = random.uniform(key1, (4,4))\n", "\n", "model = MLP(features=[3,4,5])\n", "init_variables = model.init(key2, x)\n", "y = model.apply(init_variables, x)\n", "\n", "print('initialized parameter shapes:\\n', jax.tree_util.tree_map(jnp.shape, unfreeze(init_variables)))\n", "print('output:\\n', y)" ] }, { "cell_type": "markdown", "metadata": { "id": "D1tfTdRjyJYK" }, "source": [ "## Remat" ] }, { "cell_type": "markdown", "metadata": { "id": "goiHMi4qyLiZ" }, "source": [ "For memory-expensive computations, we can `remat` our method to recompute a Module's output during a backwards pass.\n", "\n", "_Known Gotcha_: at the moment, the decorator changes the RNG stream slightly, so comparing remat'd and undecorated initializations will look different." ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "sogMxDQpyMZE", "outputId": "7fe8e13b-7dd6-4e55-ee50-ce334e8ed178", "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "initialized parameter shapes:\n", " {'params': {'layers_0': {'bias': (3,), 'kernel': (4, 3)}, 'layers_1': {'bias': (4,), 'kernel': (3, 4)}, 'layers_2': {'bias': (5,), 'kernel': (4, 5)}}}\n", "output:\n", " [[-0.14814317 0.06889858 -0.19695625 0.12019286 0.02068037]\n", " [-0.04439102 -0.06698258 -0.11579747 -0.19906905 -0.04342325]\n", " [-0.08875751 -0.13392815 -0.23153095 -0.39802808 -0.0868225 ]\n", " [-0.01606487 -0.02424064 -0.04190649 -0.07204203 -0.01571464]]\n" ] } ], "source": [ "class RematMLP(nn.Module):\n", " features: Sequence[int]\n", " # For all transforms, we can annotate a method, or wrap an existing\n", " # Module class. 
Here we annotate the method.\n", " @nn.remat\n", " @nn.compact\n", " def __call__(self, inputs):\n", " x = inputs\n", " for i, feat in enumerate(self.features):\n", " x = nn.Dense(feat, name=f'layers_{i}')(x)\n", " if i != len(self.features) - 1:\n", " x = nn.relu(x)\n", " return x\n", "\n", "key1, key2 = random.split(random.PRNGKey(3), 2)\n", "x = random.uniform(key1, (4,4))\n", "\n", "model = RematMLP(features=[3,4,5])\n", "init_variables = model.init(key2, x)\n", "y = model.apply(init_variables, x)\n", "\n", "print('initialized parameter shapes:\\n', jax.tree_util.tree_map(jnp.shape, unfreeze(init_variables)))\n", "print('output:\\n', y)" ] }, { "cell_type": "markdown", "metadata": { "id": "l0pJtxVwyCgp" }, "source": [ "## Vmap" ] }, { "cell_type": "markdown", "metadata": { "id": "TqVbjhOkyEaj" }, "source": [ "You can now `vmap` Modules inside. The transform has a lot of arguments, they have the usual jax vmap args:\n", " - `in_axes` - an integer or `None` for each input argument\n", " - `out_axes` - an integer or `None` for each output argument\n", " - `axis_size` - the axis size if you need to give it explicitly\n", "\n", "In addition, we provide for each __kind__ of variable it's axis rules:\n", "\n", " - `variable_in_axes` - a dict from kinds to a single integer or `None` specifying the input axes to map\n", " - `variable_out_axes` - a dict from kinds to a single integer or `None` specifying the output axes to map\n", " - `split_rngs` - a dict from RNG-kinds to a bool, specifying whether to split the rng along the axis.\n", "\n", "\n", "Below we show an example defining a batched, multiheaded attention module from a single-headed unbatched attention implementation." ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "PIGiriD0yFXo", "outputId": "223d880e-c7b2-4210-ebb5-dbfcdd9aed09", "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "initialized parameter shapes:\n", " {'params': {'attention': {'key': {'kernel': (2, 64, 32)}, 'out': {'bias': (2, 64), 'kernel': (2, 32, 64)}, 'query': {'kernel': (2, 64, 32)}, 'value': {'kernel': (2, 64, 32)}}}}\n", "output:\n", " (3, 13, 2)\n" ] } ], "source": [ "class RawDotProductAttention(nn.Module):\n", " attn_dropout_rate: float = 0.1\n", " train: bool = False\n", "\n", " @nn.compact\n", " def __call__(self, query, key, value, bias=None, dtype=jnp.float32):\n", " assert key.ndim == query.ndim\n", " assert key.ndim == value.ndim\n", "\n", " n = query.ndim\n", " attn_weights = lax.dot_general(\n", " query, key,\n", " (((n-1,), (n - 1,)), ((), ())))\n", " if bias is not None:\n", " attn_weights += bias\n", " norm_dims = tuple(range(attn_weights.ndim // 2, attn_weights.ndim))\n", " attn_weights = jax.nn.softmax(attn_weights, axis=norm_dims)\n", " attn_weights = nn.Dropout(self.attn_dropout_rate)(attn_weights,\n", " deterministic=not self.train)\n", " attn_weights = attn_weights.astype(dtype)\n", "\n", " contract_dims = (\n", " tuple(range(n - 1, attn_weights.ndim)),\n", " tuple(range(0, n - 1)))\n", " y = lax.dot_general(\n", " attn_weights, value,\n", " (contract_dims, ((), ())))\n", " return y\n", "\n", "class DotProductAttention(nn.Module):\n", " qkv_features: Optional[int] = None\n", " out_features: Optional[int] = None\n", " train: bool = False\n", "\n", " @nn.compact\n", " def __call__(self, inputs_q, inputs_kv, bias=None, dtype=jnp.float32):\n", " qkv_features = self.qkv_features or inputs_q.shape[-1]\n", " out_features = self.out_features or 
inputs_q.shape[-1]\n", "\n", " QKVDense = functools.partial(\n", " nn.Dense, features=qkv_features, use_bias=False, dtype=dtype)\n", " query = QKVDense(name='query')(inputs_q)\n", " key = QKVDense(name='key')(inputs_kv)\n", " value = QKVDense(name='value')(inputs_kv)\n", "\n", " y = RawDotProductAttention(train=self.train)(\n", " query, key, value, bias=bias, dtype=dtype)\n", "\n", " y = nn.Dense(features=out_features, dtype=dtype, name='out')(y)\n", " return y\n", "\n", "class MultiHeadDotProductAttention(nn.Module):\n", " qkv_features: Optional[int] = None\n", " out_features: Optional[int] = None\n", " batch_axes: Sequence[int] = (0,)\n", " num_heads: int = 1\n", " broadcast_dropout: bool = False\n", " train: bool = False\n", " @nn.compact\n", " def __call__(self, inputs_q, inputs_kv, bias=None, dtype=jnp.float32):\n", " qkv_features = self.qkv_features or inputs_q.shape[-1]\n", " out_features = self.out_features or inputs_q.shape[-1]\n", "\n", " # Make multiheaded attention from single-headed dimension.\n", " Attn = nn.vmap(DotProductAttention,\n", " in_axes=(None, None, None),\n", " out_axes=2,\n", " axis_size=self.num_heads,\n", " variable_axes={'params': 0},\n", " split_rngs={'params': True,\n", " 'dropout': not self.broadcast_dropout})\n", "\n", " # Vmap across batch dimensions.\n", " for axis in reversed(sorted(self.batch_axes)):\n", " Attn = nn.vmap(Attn,\n", " in_axes=(axis, axis, axis),\n", " out_axes=axis,\n", " variable_axes={'params': None},\n", " split_rngs={'params': False, 'dropout': False})\n", "\n", " # Run the vmap'd class on inputs.\n", " y = Attn(qkv_features=qkv_features // self.num_heads,\n", " out_features=out_features,\n", " train=self.train,\n", " name='attention')(inputs_q, inputs_kv, bias)\n", "\n", " return y.mean(axis=-2)\n", "\n", "\n", "key1, key2, key3, key4 = random.split(random.PRNGKey(0), 4)\n", "x = random.uniform(key1, (3, 13, 64))\n", "\n", "model = functools.partial(\n", " MultiHeadDotProductAttention,\n", " broadcast_dropout=False,\n", " num_heads=2,\n", " batch_axes=(0,))\n", "\n", "init_variables = model(train=False).init({'params': key2}, x, x)\n", "print('initialized parameter shapes:\\n', jax.tree_util.tree_map(jnp.shape, unfreeze(init_variables)))\n", "\n", "y = model(train=True).apply(init_variables, x, x, rngs={'dropout': key4})\n", "print('output:\\n', y.shape)" ] }, { "cell_type": "markdown", "metadata": { "id": "U-bDSQElvM09" }, "source": [ "## Scan" ] }, { "cell_type": "markdown", "metadata": { "id": "8oiRXIC6xQ--" }, "source": [ "Scan allows us to apply `lax.scan` to Modules, including their parameters and mutable variables. To use it we have to specify how we want each \"kind\" of variable to be transformed. For scanned variables we specify similar to vmap via in `variable_in_axes`, `variable_out_axes`:\n", " - `nn.broadcast` broadcast the variable kind across the scan steps as a constant\n", " - `` scan along `axis` for e.g. unique parameters at each step\n", "\n", "OR we specify that the variable kind is to be treated like a \"carry\" by passing to the `variable_carry` argument.\n", "\n", "Further, for `scan`'d variable kinds, we further specify whether or not to split the rng at each step." 
] }, { "cell_type": "code", "execution_count": 15, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "oxA_lWm7tH2B", "outputId": "7d9ebed3-64de-4ca8-9dce-4b09ba9e31a1", "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "initialized parameter shapes:\n", " {'params': {'lstm_cell': {'hf': {'bias': (2,), 'kernel': (2, 2)}, 'hg': {'bias': (2,), 'kernel': (2, 2)}, 'hi': {'bias': (2,), 'kernel': (2, 2)}, 'ho': {'bias': (2,), 'kernel': (2, 2)}, 'if': {'kernel': (2, 2)}, 'ig': {'kernel': (2, 2)}, 'ii': {'kernel': (2, 2)}, 'io': {'kernel': (2, 2)}}}}\n", "output:\n", " ((DeviceArray([[-0.562219 , 0.92847174]], dtype=float32), DeviceArray([[-0.31570646, 0.2885693 ]], dtype=float32)), DeviceArray([[[-0.08265854, 0.01302483],\n", " [-0.10249066, 0.21991298],\n", " [-0.26609066, 0.22519003],\n", " [-0.27982554, 0.28393182],\n", " [-0.31570646, 0.2885693 ]]], dtype=float32))\n" ] } ], "source": [ "class SimpleScan(nn.Module):\n", " @nn.compact\n", " def __call__(self, xs):\n", " dummy_rng = random.PRNGKey(0)\n", " init_carry = nn.LSTMCell.initialize_carry(dummy_rng,\n", " xs.shape[:1],\n", " xs.shape[-1])\n", " LSTM = nn.scan(nn.LSTMCell,\n", " in_axes=1, out_axes=1,\n", " variable_broadcast='params',\n", " split_rngs={'params': False})\n", " return LSTM(name=\"lstm_cell\")(init_carry, xs)\n", "\n", "key1, key2 = random.split(random.PRNGKey(0), 2)\n", "xs = random.uniform(key1, (1, 5, 2))\n", "\n", "model = SimpleScan()\n", "init_variables = model.init(key2, xs)\n", "\n", "print('initialized parameter shapes:\\n', jax.tree_util.tree_map(jnp.shape, unfreeze(init_variables)))\n", "\n", "y = model.apply(init_variables, xs)\n", "print('output:\\n', y)" ] } ], "metadata": { "colab": { "collapsed_sections": [], "name": "Flax 2 (\"Linen\")", "provenance": [], "toc_visible": true }, "interpreter": { "hash": "50100d07a2a27af6847cdedde67ae5e97c9798a9e1a9ae21eed9ecf69a9f619c" }, "jupytext": { "formats": "ipynb,md:myst" }, "kernelspec": { "display_name": "Python 3.8.11 ('.venv': venv)", "language": "python", "name": "python3" }, "language_info": { "name": "python", "version": "3.8.11" } }, "nbformat": 4, "nbformat_minor": 0 }