Flax

Neural networks with JAX


Flax delivers an end-to-end, flexible user experience for researchers who use JAX with neural networks. Flax exposes the full power of JAX. It is made up of loosely coupled libraries, which are showcased in end-to-end integrated guides and examples.

Features

Safety

Flax is designed for correctness and safety. Thanks to its immutable Modules and Functional API, Flax helps mitigate bugs that arise when handling state in JAX.

Control

Flax grants more fine-grained control and expressivity than most neural network frameworks via its Variable Collections, RNG Collections and Mutability conditions.

Functional API

Flax’s functional API radically redefines what Modules can do via lifted transformations such as vmap and scan, while also enabling seamless integration with other JAX libraries like Optax and Chex.

Terse Code

Flax’s compact Modules enable submodules to be defined directly at their call sites, leading to code that is easier to read and avoids repetition.


Installation

pip install flax

Flax installs the vanilla CPU version of JAX. If you need a custom version (for example, one with GPU or TPU support), please check out JAX’s installation page.
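A quick way to check which backend the installed JAX picked up, using `jax.default_backend()` and `jax.devices()`. With the plain `pip install flax` above, the backend is expected to be `cpu`; a GPU or TPU build of JAX should report a different platform.

```python
import jax

# Platform name of the default backend, e.g. "cpu" for the vanilla install.
backend = jax.default_backend()
# Devices JAX can see on that backend.
devices = jax.devices()

print(backend)
print(devices)
```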

Basic usage

import jax.numpy as jnp
import flax.linen as nn
from jax.random import PRNGKey

class MLP(nn.Module):
  @nn.compact
  def __call__(self, x):
    x = nn.Dense(16)(x)                # inline submodules
    x = nn.relu(x)
    x = nn.Dense(16)(x)                # inline submodules
    return x

model = MLP()                          # create model

x = jnp.ones((4, 16))                  # get some data
variables = model.init(PRNGKey(42), x) # initialize weights
y = model.apply(variables, x)          # make forward pass

Learn more

Getting Started
Guides
Advanced Topics
Examples
API Reference

Ecosystem

Flax is used by hundreds of projects (and growing), both in the open source community and within Google. Notable examples include:

NLP and Computer Vision models

Model for Text-to-Image generation

540-billion-parameter model for text generation

Text-to-Image Diffusion Models

Large scale Computer Vision models

Large Language Models

On-device differentiable RL environments