{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# 🔪 Flax - The Sharp Bits 🔪\n", "\n", "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/flax/blob/main/docs/notebooks/flax_sharp_bits.ipynb)\n", "\n", "Flax exposes the full power of JAX. And just like when using JAX, there are certain _[\"sharp bits\"](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html)_ you may experience when working with Flax. This evolving document is designed to assist you with them.\n", "\n", "First, install and/or update Flax:" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "!pip install -qq flax" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 🔪 `flax.linen.Dropout` layer and randomness\n", "\n", "### TL;DR\n", "\n", "When working on a model with dropout (subclassed from [Flax `Module`](https://flax.readthedocs.io/en/latest/guides/flax_basics.html#module-basics)), add the `'dropout'` PRNG key only during the forward pass.\n", "\n", "1. Start with [`jax.random.split()`](https://jax.readthedocs.io/en/latest/_autosummary/jax.random.split.html#jax-random-split) to explicitly create PRNG keys for `'params'` and `'dropout'`.\n", "2. Add the [`flax.linen.Dropout`](https://flax.readthedocs.io/en/latest/api_reference/_autosummary/flax.linen.Dropout.html#flax.linen.Dropout) layer(s) to your model (subclassed from Flax [`Module`](https://flax.readthedocs.io/en/latest/guides/flax_basics.html#module-basics)).\n", "3. When initializing the model ([`flax.linen.init()`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen.html#init-apply)), there's no need to pass in an extra `'dropout'` PRNG key, just the `'params'` key, as in a \"simpler\" model.\n", "4. 
During the forward pass with [`flax.linen.apply()`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen.html#init-apply), pass in `rngs={'dropout': dropout_key}`.\n", "\n", "Check out a full example below.\n", "\n", "### Why this works\n", "\n", "- Internally, `flax.linen.Dropout` makes use of [`flax.linen.Module.make_rng`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen.html#flax.linen.Module.make_rng) to create a key for dropout (check out the [source code](https://github.com/google/flax/blob/5714e57a0dc8146eb58a7a06ed768ed3a17672f9/flax/linen/stochastic.py#L72)).\n", "- Every time `make_rng` is called (in this case, it's done implicitly in `Dropout`), you get a new PRNG key split from the main/root PRNG key.\n", "- `make_rng` still _guarantees full reproducibility_.\n", "\n", "### Background\n", "\n", "The [dropout](https://jmlr.org/papers/volume15/srivastava14a/srivastava14a.pdf) stochastic regularization technique randomly removes hidden and visible units in a network. Dropout is a random operation that requires a PRNG state, and Flax (like JAX) uses the splittable [Threefry](https://github.com/google/jax/blob/main/docs/jep/263-prng.md) PRNG.\n", "\n", "> Note: Recall that JAX has an explicit way of giving you PRNG keys: you can fork the main PRNG state (such as `key = jax.random.PRNGKey(seed=0)`) into multiple new PRNG keys with `key, subkey = jax.random.split(key)`. Refresh your memory in [🔪 JAX - The Sharp Bits 🔪 Randomness and PRNG keys](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#random-numbers).\n", "\n", "Flax provides an _implicit_ way of handling PRNG key streams via [Flax `Module`](https://flax.readthedocs.io/en/latest/guides/flax_basics.html#module-basics)'s [`flax.linen.Module.make_rng`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen.html#flax.linen.Module.make_rng) helper function. It allows the code in Flax `Module`s (or its sub-`Module`s) to \"pull PRNG keys\". 
`make_rng` guarantees to provide a unique key each time you call it.\n", "\n", "> Note: Recall that [`flax.linen.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen.html#module) is the base class for all neural network modules. All layers and models are subclassed from it.\n", "\n", "### Example\n", "\n", "Remember that each of the Flax PRNG streams has a name. The example below uses the `'params'` stream for initializing parameters, as well as the `'dropout'` stream. The PRNG key provided to [`flax.linen.init()`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen.html#init-apply) is the one that seeds the `'params'` PRNG key stream. To draw PRNG keys during the forward pass (with dropout), provide a PRNG key to seed that stream (`'dropout'`) when you call `Module.apply()`." ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "# Setup.\n", "import jax\n", "import jax.numpy as jnp\n", "import flax.linen as nn" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "# Randomness.\n", "seed = 0\n", "root_key = jax.random.PRNGKey(seed=seed)\n", "main_key, params_key, dropout_key = jax.random.split(key=root_key, num=3)\n", "\n", "# A simple network.\n", "class MyModel(nn.Module):\n", "  num_neurons: int\n", "  training: bool\n", "  @nn.compact\n", "  def __call__(self, x):\n", "    x = nn.Dense(self.num_neurons)(x)\n", "    # Set the dropout layer with a rate of 50%.\n", "    # When the `deterministic` flag is `True`, dropout is turned off.\n", "    x = nn.Dropout(rate=0.5, deterministic=not self.training)(x)\n", "    return x\n", "\n", "# Instantiate `MyModel` with `training=False`, which turns dropout off.\n", "# Initialization needs only the `'params'` key, so dropout can stay off here.\n", "my_model = MyModel(num_neurons=3, training=False)\n", "\n", "x = jax.random.uniform(key=main_key, shape=(3, 4, 4))\n", "\n", "# Initialize with `flax.linen.init()`.\n", "# The `params_key` is equivalent to a dictionary of 
PRNGs.\n", "# (Here, you are providing only one PRNG key.) \n", "variables = my_model.init(params_key, x)\n", "\n", "# Perform the forward pass with `flax.linen.apply()`.\n", "y = my_model.apply(variables, x, rngs={'dropout': dropout_key})" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Real-life examples:\n", "\n", "* Applying word dropout to a batch of input IDs (in a [text classification](https://github.com/google/flax/blob/main/examples/sst2/models.py) context).\n", "* Defining a prediction token in a decoder of a [sequence-to-sequence model](https://github.com/google/flax/blob/main/examples/seq2seq/models.py)." ] } ], "metadata": { "accelerator": "GPU", "colab": { "collapsed_sections": [], "provenance": [] }, "jupytext": { "formats": "ipynb,md:myst" }, "kernelspec": { "display_name": "Python 3.10.4 64-bit", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.4" }, "vscode": { "interpreter": { "hash": "3ad933181bd8a04b432d3370b9dc3b0662ad032c4dfaa4e4f1596c548f763858" } } }, "nbformat": 4, "nbformat_minor": 0 }