Example: Using pretrained Gemma for inference with Flax NNX#
This example shows how to use Flax NNX to load the Gemma open model files and run sampling/inference to generate text. You will use Flax NNX gemma modules written with Flax and JAX for model parameter configuration and inference.
It is recommended to run the code in Google Colab with access to A100 GPU acceleration.
Installation#
Install the necessary dependencies, including kagglehub.
! pip install --no-deps -U flax
! pip install jaxtyping kagglehub treescope
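Optionally, you can verify the installation and check which accelerator JAX can see (a quick sanity check; the versions you see will differ):
import flax
import jax

print('flax', flax.__version__)
print('jax', jax.__version__)
print(jax.devices())  # Should list your GPU/TPU device(s).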
Download the model#
To use the Gemma model, you’ll need a Kaggle account and API key:

1. To create an account, visit Kaggle and click on ‘Register’.
2. Once you have an account, sign in, go to your ‘Settings’, and under ‘API’ click on ‘Create New Token’ to generate and download your Kaggle API key.
3. In Google Colab, under ‘Secrets’ add your Kaggle username and API key, storing the username as KAGGLE_USERNAME and the key as KAGGLE_KEY. If you are using a Kaggle Notebook for free TPU or other hardware acceleration, it has a key storage feature under ‘Add-ons’ > ‘Secrets’, along with instructions for accessing stored keys.
Then run the cell below.
import kagglehub
kagglehub.login()
If everything went well, it should say: Kaggle credentials set. Kaggle credentials successfully validated.
Note: In Google Colab, you can instead authenticate with Kaggle using the code below after completing the optional step 3 above.
import os
from google.colab import userdata # `userdata` is a Colab API.
os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')
Now, load the Gemma model you want to try. The code in the next cell uses kagglehub.model_download to download the model files.

Note: For larger models, such as gemma 7b and gemma 7b-it (instruct), you may need a hardware accelerator with plenty of memory, such as the NVIDIA A100.
from IPython.display import clear_output
VARIANT = '2b-it' # @param ['2b', '2b-it', '7b', '7b-it'] {type:"string"}
weights_dir = kagglehub.model_download(f'google/gemma/Flax/{VARIANT}')
ckpt_path = f'{weights_dir}/{VARIANT}'
vocab_path = f'{weights_dir}/tokenizer.model'
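You can optionally confirm that the download produced the expected files before moving on (the exact directory listing depends on the variant you picked):
import os

print(os.listdir(weights_dir))  # Should include the checkpoint folder and tokenizer.model.
assert os.path.exists(ckpt_path), ckpt_path
assert os.path.exists(vocab_path), vocab_path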
Python imports#
from flax import nnx
import sentencepiece as spm
To interact with the Gemma model, you will use the Flax NNX gemma code from the google/flax examples on GitHub. Since it is not exposed as a package, you need the following workaround to import the modules from the examples/gemma directory of the Flax repository on GitHub.
import sys
import tempfile
with tempfile.TemporaryDirectory() as tmp:
  # Create a temporary directory and clone the `flax` repo.
  # Then, append the `examples/gemma` folder to the path for loading the `gemma` modules.
  ! git clone https://github.com/google/flax.git {tmp}/flax
  sys.path.append(f"{tmp}/flax/examples/gemma")
  import params as params_lib
  import sampler as sampler_lib
  import transformer as transformer_lib
  sys.path.pop();
Cloning into '/tmp/tmp_68d13pv/flax'...
remote: Enumerating objects: 31912, done.
remote: Counting objects: 100% (605/605), done.
remote: Compressing objects: 100% (250/250), done.
remote: Total 31912 (delta 406), reused 503 (delta 352), pack-reused 31307 (from 1)
Receiving objects: 100% (31912/31912), 23.92 MiB | 18.17 MiB/s, done.
Resolving deltas: 100% (23869/23869), done.
Load and prepare the Gemma model#
First, load the Gemma model parameters for use with Flax.
params = params_lib.load_and_format_params(ckpt_path)
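If you want a quick look at what was loaded, you can map over the nested parameter dictionary and print each array's shape (an optional inspection step):
import jax

# Print the shape of every array in the parameter tree without copying any data.
print(jax.tree_util.tree_map(lambda x: x.shape, params))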
Next, load the tokenizer file constructed using the SentencePiece library.
vocab = spm.SentencePieceProcessor()
vocab.Load(vocab_path)
True
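As an optional sanity check, you can round-trip a string through the tokenizer using standard SentencePiece calls:
ids = vocab.EncodeAsIds('Hello, Gemma!')
print(ids)                    # Token IDs for the input string.
print(vocab.DecodeIds(ids))   # Should print the original string back.
print(vocab.GetPieceSize())   # Size of the tokenizer vocabulary.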
Then, construct the model with Transformer.from_params, which uses the Flax NNX gemma.transformer.TransformerConfig.from_params function to automatically load the correct configuration from the checkpoint.
Note: The vocabulary size is smaller than the number of input embeddings due to unused tokens in this release.
transformer = transformer_lib.Transformer.from_params(params)
nnx.display(transformer)
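With the transformer and tokenizer in hand, you can already sketch how sampling will come together using the sampler module imported earlier. The constructor and call signatures below follow the examples/gemma code at the time of writing and may change, so treat this as a sketch rather than a stable API:
sampler = sampler_lib.Sampler(
    transformer=transformer,
    vocab=vocab,
)

# Generate a short completion for a single prompt.
out = sampler(
    input_strings=['What is JAX in three sentences?'],
    total_generation_steps=128,  # Number of tokens to generate.
)
print(out.text)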