Example: Using pretrained Gemma for inference with Flax NNX#

This example shows how to use Flax NNX to load the Gemma open model files and run sampling/inference to generate text. You will use the Flax NNX gemma modules, written with Flax and JAX, for model parameter configuration and inference.

Gemma is a family of lightweight, state-of-the-art open models based on Google DeepMind’s Gemini. Read more about Gemma and Gemma 2.

It is recommended to run the code in Google Colab with access to A100 GPU acceleration.

Installation#

Install the necessary dependencies, including kagglehub.

! pip install --no-deps -U flax
! pip install jaxtyping kagglehub treescope
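
After installing, you can optionally confirm that the runtime has an accelerator attached by listing the devices JAX sees (the exact output depends on your runtime):

import jax
jax.devices()  # For example, [CudaDevice(id=0)] on a GPU runtime.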

Download the model#

To use the Gemma model, you’ll need a Kaggle account and API key:

  1. To create an account, visit Kaggle and click on ‘Register’.

  2. Once you have an account, sign in, go to your ‘Settings’, and under ‘API’ click on ‘Create New Token’ to generate and download your Kaggle API key.

  3. In Google Colab, under ‘Secrets’ add your Kaggle username and API key, storing the username as KAGGLE_USERNAME and the key as KAGGLE_KEY. If you are using a Kaggle Notebook for free TPU or other hardware acceleration, it has a key storage feature under ‘Add-ons’ > ‘Secrets’, along with instructions for accessing stored keys.

Then run the cell below.

import kagglehub
kagglehub.login()

If everything goes well, you should see the message Kaggle credentials set. Kaggle credentials successfully validated.

Note: In Google Colab, you can instead authenticate with Kaggle using the code below, after completing the optional step 3 above.

import os
from google.colab import userdata # `userdata` is a Colab API.

os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')

Now, load the Gemma model you want to try. The code in the next cell uses kagglehub.model_download to download the model files.

Note: For larger models, such as Gemma 7B and Gemma 7B-IT (instruction-tuned), you may need a hardware accelerator with plenty of memory, such as an NVIDIA A100.

from IPython.display import clear_output

VARIANT = '2b-it' # @param ['2b', '2b-it', '7b', '7b-it'] {type:"string"}
weights_dir = kagglehub.model_download(f'google/gemma/Flax/{VARIANT}')
ckpt_path = f'{weights_dir}/{VARIANT}'
vocab_path = f'{weights_dir}/tokenizer.model'
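
As an optional sanity check, you can list the downloaded files. The exact filenames depend on the Kaggle release, but you should see the checkpoint directory for your variant alongside tokenizer.model:

import os
print(os.listdir(weights_dir))  # Downloaded checkpoint and tokenizer files.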

Python imports#

from flax import nnx
import sentencepiece as spm

To interact with the Gemma model, you will use the Flax NNX gemma code from the google/flax examples on GitHub. Since it is not exposed as a package, you need the following workaround to import the modules from examples/gemma.

import sys
import tempfile
with tempfile.TemporaryDirectory() as tmp:
  # Create a temporary directory and clone the `flax` repo.
  # Then, append the `examples/gemma` folder to the path for loading the `gemma` modules.
  ! git clone https://github.com/google/flax.git {tmp}/flax
  sys.path.append(f"{tmp}/flax/examples/gemma")
  import params as params_lib
  import sampler as sampler_lib
  import transformer as transformer_lib
  sys.path.pop();
Cloning into '/tmp/tmp_68d13pv/flax'...
remote: Enumerating objects: 31912, done.
remote: Counting objects: 100% (605/605), done.
remote: Compressing objects: 100% (250/250), done.
remote: Total 31912 (delta 406), reused 503 (delta 352), pack-reused 31307 (from 1)
Receiving objects: 100% (31912/31912), 23.92 MiB | 18.17 MiB/s, done.
Resolving deltas: 100% (23869/23869), done.

Load and prepare the Gemma model#

First, load the Gemma model parameters for use with Flax.

params = params_lib.load_and_format_params(ckpt_path)
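
The returned params is a nested dictionary of JAX arrays. To inspect its structure without printing the full weights, you can map each leaf to its shape (a small sketch using jax.tree_util):

import jax
jax.tree_util.tree_map(lambda x: x.shape, params)  # Same nesting, shapes instead of arrays.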

Next, load the tokenizer file constructed using the SentencePiece library.

vocab = spm.SentencePieceProcessor()
vocab.Load(vocab_path)
True
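
You can verify the tokenizer with a quick round trip. EncodeAsIds and DecodeIds are standard SentencePiece methods; the exact token IDs depend on the Gemma vocabulary:

ids = vocab.EncodeAsIds('Hello, Gemma!')
print(ids)                    # Token IDs for the input string.
print(vocab.DecodeIds(ids))   # Should reconstruct the original text.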

Then, use the Flax NNX gemma.transformer.Transformer.from_params function to build the transformer and automatically load the correct configuration from the checkpoint parameters.

Note: The vocabulary size is smaller than the number of input embeddings due to unused tokens in this release.

transformer = transformer_lib.Transformer.from_params(params)
nnx.display(transformer)
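
With the transformer built, you can generate text. The sketch below assumes the Sampler API from the flax/examples/gemma code imported above (a transformer and vocab passed to the constructor; input_strings and total_generation_steps passed when called), which may change between repository versions:

sampler = sampler_lib.Sampler(
    transformer=transformer,
    vocab=vocab,
)

out_data = sampler(
    input_strings=['What is JAX in three sentences?'],  # Hypothetical example prompt.
    total_generation_steps=128,  # Upper bound on the number of generated tokens.
)
print(out_data.text[0])  # The generated continuation for the first prompt.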