You will need Python 3.7 or later.
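To check that the active interpreter meets this requirement before installing anything, a small version check can help. This is a sketch; `meets_minimum_python` and `MIN_PYTHON` are hypothetical names for illustration and are not part of Flax.

```python
import sys

# Minimum interpreter version stated in the requirement above.
MIN_PYTHON = (3, 7)


def meets_minimum_python(version_info=sys.version_info, minimum=MIN_PYTHON):
    """Return True if the given (default: running) Python version is at least `minimum`."""
    # Compare only (major, minor) so that micro releases do not matter.
    return tuple(version_info[:2]) >= minimum


if __name__ == "__main__":
    if not meets_minimum_python():
        sys.exit(f"Python {MIN_PYTHON[0]}.{MIN_PYTHON[1]}+ required, found {sys.version.split()[0]}")
    print("Python version OK:", sys.version.split()[0])
```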

For GPU support, first install jaxlib by following the instructions in the JAX README. If the CUDA and cuDNN runtimes are not already installed, you will need to install them as well.
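After installing jaxlib, one way to see whether JAX actually picked up the GPU is to ask for its default backend. This is a sketch assuming a recent JAX release that provides `jax.default_backend()`; the helper name `jax_backend` is hypothetical.

```python
def jax_backend():
    """Return JAX's default backend name, or None if JAX cannot be imported."""
    try:
        # Importing jax also loads jaxlib; either may be missing or mismatched.
        import jax
    except (ImportError, RuntimeError):
        return None
    # Typically "cpu", "gpu" or "tpu". A CUDA-enabled jaxlib with working
    # CUDA/cuDNN runtimes should report a GPU backend rather than "cpu".
    return jax.default_backend()


if __name__ == "__main__":
    backend = jax_backend()
    if backend is None:
        print("JAX is not importable; install jax and jaxlib first.")
    else:
        print("JAX default backend:", backend)
```

If this prints `cpu` on a machine with a GPU, the installed jaxlib is likely a CPU-only build or the CUDA/cuDNN runtimes were not found.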

Then install Flax from PyPI:

> pip install flax
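To confirm the installation (or an upgrade) took effect, you can print the versions Python actually imports. This sketch relies on `jax` and `flax` exposing a `__version__` attribute; the helper name `installed_version` is hypothetical.

```python
import importlib


def installed_version(package):
    """Return `package.__version__` if the package imports cleanly, else None."""
    try:
        module = importlib.import_module(package)
    except (ImportError, RuntimeError):
        # Not installed, or installed with an incompatible dependency (e.g. jaxlib).
        return None
    return getattr(module, "__version__", None)


if __name__ == "__main__":
    for name in ("jax", "flax"):
        print(f"{name}: {installed_version(name) or 'not installed'}")
```

Checking through `import` rather than `pip list` shows what the current interpreter will really load, which helps when several Python environments are present.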

To upgrade to the latest version of JAX and Flax, you can use:

> pip install --upgrade pip jax jaxlib
> pip install --upgrade git+