{ "cells": [ { "cell_type": "markdown", "id": "6e9134fa", "metadata": {}, "source": [ "# Save and load checkpoints\n", "\n", "In this guide, you will learn about saving and loading checkpoints with Flax and [Orbax](https://github.com/google/orbax). With Flax, you can save and load model parameters, metadata, and a variety of Python data using Orbax. \n", "\n", "Orbax provides a customizable and flexible API for various array types and storage formats. In addition, Flax provides basic features for versioning, automatic bookkeeping of past checkpoints, and asynchronous saving to reduce training wait time.\n", "\n", "> **_Ongoing migration:_** In the foreseeable future, Flax's checkpointing functionality will gradually be migrated to Orbax from `flax.training.checkpoints`. All existing features in the Flax API will continue to be supported, but the API will change. You are encouraged to try out the new API by creating an [`orbax.Checkpointer`](https://github.com/google/orbax/blob/main/orbax/checkpoint/checkpointer.py) and pass it in your Flax API calls as an argument `orbax_checkpointer`, as demonstrated later in this guide. This guide provides the most up-to-date code examples for using Orbax and Flax for checkpointing.\n", "\n", "This guide covers the following:\n", "\n", "* Basic saving and loading of checkpoints with [`orbax.Checkpointer`](https://github.com/google/orbax/blob/main/orbax/checkpoint/checkpointer.py) and [`flax.training.checkpoints.save_checkpoint`](https://flax.readthedocs.io/en/latest/api_reference/flax.training.html#flax.training.checkpoints.save_checkpoint).\n", "* More flexible and sustainable ways to load checkpoints ([`flax.training.checkpoints.restore_checkpoint`](https://flax.readthedocs.io/en/latest/api_reference/flax.training.html#flax.training.checkpoints.restore_checkpoint)).\n", "* How to save and load checkpoints when you run in multi-host scenarios with\n", "[`flax.training.checkpoints.save_checkpoint_multiprocess`](https://flax.readthedocs.io/en/latest/api_reference/flax.training.html#flax.training.checkpoints.save_checkpoint_multiprocess)." ] }, { "cell_type": "markdown", "id": "5a2f6aae", "metadata": {}, "source": [ "## Setup\n", "\n", "Install/upgrade Flax and [Orbax](https://github.com/google/orbax). For JAX installation with GPU/TPU support, visit [this section on GitHub](https://github.com/google/jax#installation)." ] }, { "cell_type": "code", "execution_count": 1, "id": "e80f8743", "metadata": {}, "outputs": [], "source": [ "!pip install -U -qq flax orbax\n", "\n", "# Orbax needs to enable asyncio in a Colab environment.\n", "!pip install -qq nest_asyncio" ] }, { "cell_type": "markdown", "id": "-icO30rwmKYj", "metadata": {}, "source": [ "Note: Before running `import jax`, create eight fake devices to mimic [multi-host environment](https://jax.readthedocs.io/en/latest/jax-101/06-parallelism.html?#aside-hosts-and-devices-in-jax) this notebook. Note that the order of imports is important here. The `os.environ[\"XLA_FLAGS\"] = '--xla_force_host_platform_device_count=8'` command works only with the CPU backend. This means it won't work with GPU/TPU acceleration on if you're running this notebook in Google Colab. If you are already running the code on multiple devices (for example, in a 4x2 TPU environment), you can skip running the next cell." 
] }, { "cell_type": "code", "execution_count": 2, "id": "ArKLnsyGRxGv", "metadata": {}, "outputs": [], "source": [ "import os\n", "os.environ[\"XLA_FLAGS\"] = '--xla_force_host_platform_device_count=8'" ] }, { "cell_type": "code", "execution_count": 3, "id": "SJT9DTxTytjn", "metadata": {}, "outputs": [], "source": [ "from typing import Optional, Any\n", "import shutil\n", "\n", "import numpy as np\n", "import jax\n", "from jax import random, numpy as jnp\n", "\n", "import flax\n", "from flax import linen as nn\n", "from flax.training import checkpoints, train_state\n", "from flax import struct, serialization\n", "import orbax.checkpoint as orbax\n", "\n", "import optax\n", "import nest_asyncio\n", "nest_asyncio.apply()" ] }, { "cell_type": "markdown", "id": "40d434cd", "metadata": {}, "source": [ "## Save checkpoints\n", "\n", "In Flax, you save and load any given JAX [pytree](https://jax.readthedocs.io/en/latest/pytrees.html) using the `flax.training.checkpoints` package. This includes not only typical Python and NumPy containers, but also customized classes extended from [`flax.struct.dataclass`](https://flax.readthedocs.io/en/latest/api_reference/flax.struct.html#flax.struct.dataclass). That means you can store almost any data generated—not only your model parameters, but any arrays/dictionaries, metadata/configs, and so on.\n", "\n", "Create a pytree with many data structures and containers, and play with it:" ] }, { "cell_type": "code", "execution_count": 4, "id": "56dec3f6", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'model': TrainState(step=1, apply_fn=, params=FrozenDict({\n", " bias: DeviceArray([-0.001, -0.001, -0.001], dtype=float32),\n", " kernel: DeviceArray([[ 0.26048955, -0.61399287, -0.23458514],\n", " [ 0.11050402, -0.8765793 , 0.9800635 ],\n", " [ 0.36260957, 0.18276349, -0.6856061 ],\n", " [-0.8519373 , -0.6416717 , -0.4818122 ],\n", " [-0.6886102 , -0.33987316, -0.05898903]], dtype=float32),\n", " }), tx=GradientTransformation(init=.init_fn at 0x14d093b80>, update=.update_fn at 0x14d0938b0>), opt_state=(EmptyState(), EmptyState())),\n", " 'config': {'dimensions': array([5, 3]), 'name': 'dense'},\n", " 'data': [DeviceArray([0.59902626, 0.2172144 , 2.4202902 , 0.03266738, 1.2164948 ], dtype=float32)]}" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# A simple model with one linear layer.\n", "key1, key2 = random.split(random.PRNGKey(0))\n", "x1 = random.normal(key1, (5,)) # A simple JAX array.\n", "model = nn.Dense(features=3)\n", "variables = model.init(key2, x1)\n", "\n", "# Flax's TrainState is a pytree dataclass and is supported in checkpointing.\n", "# Define your class with `@flax.struct.dataclass` decorator to make it compatible.\n", "tx = optax.sgd(learning_rate=0.001) # An Optax SGD optimizer.\n", "state = train_state.TrainState.create(\n", " apply_fn=model.apply,\n", " params=variables['params'],\n", " tx=tx)\n", "# Perform a simple gradient update similar to the one during a normal training workflow.\n", "state = state.apply_gradients(grads=jax.tree_map(jnp.ones_like, state.params))\n", "\n", "# Some arbitrary nested pytree with a dictionary, a string, and a NumPy array.\n", "config = {'dimensions': np.array([5, 3]), 'name': 'dense'}\n", "\n", "# Bundle everything together.\n", "ckpt = {'model': state, 'config': config, 'data': [x1]}\n", "ckpt" ] }, { "cell_type": "markdown", "id": "6fc59dfa", "metadata": {}, "source": [ "Now save the checkpoint with Flax and Orbax. 
You can add annotations like step number, prefix, and so on to your checkpoint.\n", "\n", "When saving a checkpoint, Flax keeps track of existing checkpoints based on your arguments. For example, by setting `overwrite=False` in [`flax.training.checkpoints.save_checkpoint`](https://flax.readthedocs.io/en/latest/api_reference/flax.training.html#flax.training.checkpoints.save_checkpoint), Flax will not save your checkpoint if the checkpoint directory already contains a step equal to or newer than the current one. By setting `keep=2`, Flax will keep a maximum of 2 checkpoints in the directory. Learn more in the [API reference](https://flax.readthedocs.io/en/latest/api_reference/flax.training.html#module-flax.training.checkpoints).\n", "\n", "You can start to use Orbax to handle the underlying save by creating an [`orbax.Checkpointer`](https://github.com/google/orbax/blob/main/orbax/checkpoint/checkpointer.py) and passing it into the `flax.training.checkpoints.save_checkpoint` call." ] }, { "cell_type": "code", "execution_count": null, "id": "4cdb35ef", "metadata": { "scrolled": true }, "outputs": [], "source": [ "# Import Flax Checkpoints.\n", "from flax.training import checkpoints\n", "\n", "ckpt_dir = 'tmp/flax-checkpointing'\n", "\n", "if os.path.exists(ckpt_dir):\n", "    shutil.rmtree(ckpt_dir)  # Remove any existing checkpoints from the last notebook run.\n", "\n", "orbax_checkpointer = orbax.Checkpointer(orbax.PyTreeCheckpointHandler())\n", "checkpoints.save_checkpoint(ckpt_dir=ckpt_dir,\n", "                            target=ckpt,\n", "                            step=0,\n", "                            overwrite=False,\n", "                            keep=2,\n", "                            orbax_checkpointer=orbax_checkpointer)" ] }, { "cell_type": "markdown", "id": "6b658bd1", "metadata": {}, "source": [ "## Restore checkpoints\n", "\n", "To restore a checkpoint, use [`flax.training.checkpoints.restore_checkpoint`](https://flax.readthedocs.io/en/latest/api_reference/flax.training.html#flax.training.checkpoints.restore_checkpoint) and pass in the checkpoint directory. Flax will automatically select the latest checkpoint in the directory. You can also choose to specify a step number or the path of the checkpoint file.\n", "\n", "With the migration to Orbax in progress, `restore_checkpoint` can automatically identify whether a checkpoint is saved in the legacy (Flax) or Orbax version, and restore the pytree correctly.\n", "\n", "You can always restore a pytree out of your checkpoints by setting `target=None`."
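, "\n", "\n", "For example, you can pin the restore to a specific step, or look up the path of the newest checkpoint first. A minimal sketch (`latest_checkpoint` is part of `flax.training.checkpoints`):\n", "\n", "```python\n", "# Restore the checkpoint saved at step 0 rather than the latest one.\n", "restored = checkpoints.restore_checkpoint(ckpt_dir=ckpt_dir, target=None, step=0)\n", "\n", "# Query the path of the newest checkpoint in the directory.\n", "latest_path = checkpoints.latest_checkpoint(ckpt_dir)\n", "```"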
] }, { "cell_type": "code", "execution_count": 6, "id": "150b20a0", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'config': {'dimensions': array([5, 3]), 'name': 'dense'},\n", " 'data': {'0': array([0.59902626, 0.2172144 , 2.4202902 , 0.03266738, 1.2164948 ],\n", " dtype=float32)},\n", " 'model': {'opt_state': {'0': {}, '1': {}},\n", " 'params': {'bias': array([-0.001, -0.001, -0.001], dtype=float32),\n", " 'kernel': array([[ 0.26048955, -0.61399287, -0.23458514],\n", " [ 0.11050402, -0.8765793 , 0.9800635 ],\n", " [ 0.36260957, 0.18276349, -0.6856061 ],\n", " [-0.8519373 , -0.6416717 , -0.4818122 ],\n", " [-0.6886102 , -0.33987316, -0.05898903]], dtype=float32)},\n", " 'step': 1}}" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "raw_restored = checkpoints.restore_checkpoint(ckpt_dir=ckpt_dir, target=None)\n", "raw_restored" ] }, { "cell_type": "markdown", "id": "987b981f", "metadata": {}, "source": [ "However, when using `target=None`, the restored `raw_restored` will be different from the original `ckpt` in the following ways:\n", "\n", "1. There is no TrainState now; only some raw weights and Optax state numbers remain.\n", "1. `data` was saved as a list, but it was restored as a dictionary with integer-like string keys (`'0'`).\n", "1. Previously, `data[0]` was a JAX NumPy array (`jnp.array`); now it's a NumPy array (`numpy.array`).\n", "\n", "While (3) would not affect future work because JAX will [automatically convert](https://jax.readthedocs.io/en/latest/jax-101/01-jax-basics.html) NumPy arrays to JAX arrays once the computation starts, (1) and (2) may lead to confusion.\n", "\n", "To resolve this, you should pass an example `target` to `flax.training.checkpoints.restore_checkpoint` to let Flax know exactly what structure it should restore to. The `target` should introduce any custom Flax dataclasses explicitly, and have the same structure as the saved checkpoint.\n", "\n", "It's often recommended to refactor out the process of initializing a checkpoint's structure (for example, a [`TrainState`](https://flax.readthedocs.io/en/latest/flip/1009-optimizer-api.html?#train-state)), so that saving/loading is easier and less error-prone. This is because complicated objects like `apply_fn` and `tx` (optimizer) are not stored in the checkpoint file and must be recreated in code."
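, "\n", "\n", "A minimal sketch of that pattern is below; the `create_train_state` helper name is illustrative, not part of the Flax API:\n", "\n", "```python\n", "def create_train_state(rng, learning_rate):\n", "    # Build a fresh TrainState; reuse this both at training start and as a restore target.\n", "    model = nn.Dense(features=3)\n", "    variables = model.init(rng, jnp.ones((5,)))\n", "    tx = optax.sgd(learning_rate=learning_rate)\n", "    return train_state.TrainState.create(\n", "        apply_fn=model.apply, params=variables['params'], tx=tx)\n", "\n", "# The same helper then yields a valid `target` for restoration:\n", "target_state = create_train_state(random.PRNGKey(0), learning_rate=0.001)\n", "```"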
] }, { "cell_type": "code", "execution_count": 7, "id": "58f42513", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'config': {'dimensions': array([5, 3]), 'name': 'dense'},\n", " 'data': [array([0.59902626, 0.2172144 , 2.4202902 , 0.03266738, 1.2164948 ],\n", " dtype=float32)],\n", " 'model': TrainState(step=1, apply_fn=<bound method Module.apply of Dense(...)>, params=FrozenDict({\n", " bias: array([-0.001, -0.001, -0.001], dtype=float32),\n", " kernel: array([[ 0.26048955, -0.61399287, -0.23458514],\n", " [ 0.11050402, -0.8765793 , 0.9800635 ],\n", " [ 0.36260957, 0.18276349, -0.6856061 ],\n", " [-0.8519373 , -0.6416717 , -0.4818122 ],\n", " [-0.6886102 , -0.33987316, -0.05898903]], dtype=float32),\n", " }), tx=GradientTransformation(init=<function chain.<locals>.init_fn at 0x14d093b80>, update=<function chain.<locals>.update_fn at 0x14d0938b0>), opt_state=(EmptyState(), EmptyState()))}" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "empty_state = train_state.TrainState.create(\n", "    apply_fn=model.apply,\n", "    params=jax.tree_map(np.zeros_like, variables['params']),  # Leaf values don't matter here; only the tree structure does.\n", "    tx=tx,\n", ")\n", "empty_config = {'dimensions': np.array([0, 0]), 'name': ''}\n", "target = {'model': empty_state, 'config': empty_config, 'data': [jnp.zeros_like(x1)]}\n", "state_restored = checkpoints.restore_checkpoint(ckpt_dir, target=target, step=0)\n", "state_restored" ] }, { "cell_type": "markdown", "id": "136a300a", "metadata": {}, "source": [ "### Backward/forward dataclass compatibility\n", "\n", "The flexibility of using *Flax dataclasses*—[`flax.struct.dataclass`](https://flax.readthedocs.io/en/latest/api_reference/flax.struct.html#flax.struct.dataclass)—means that changes in Flax dataclass fields may break your existing checkpoints. For example, if you decide to add a field `batch_stats` to your `TrainState` (like when using [batch normalization](https://flax.readthedocs.io/en/latest/guides/batch_norm.html)), old checkpoints without this field may not be restored successfully. The same goes for removing a field from your dataclass.\n", "\n", "Note: Flax supports [`flax.struct.dataclass`](https://flax.readthedocs.io/en/latest/api_reference/flax.struct.html#flax.struct.dataclass), not Python's built-in `dataclasses.dataclass`."
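, "\n", "\n", "For reference, this is how a custom pytree dataclass is declared. A minimal sketch; the `Metadata` class is purely illustrative:\n", "\n", "```python\n", "@struct.dataclass\n", "class Metadata:\n", "    version: int\n", "    name: str = struct.field(pytree_node=False)  # A static (non-pytree) field.\n", "\n", "meta = Metadata(version=1, name='dense')  # Works as a node in a checkpointed pytree.\n", "```"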
] }, { "cell_type": "code", "execution_count": 8, "id": "be65d4af", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "ValueError when target state has an unmentioned field:\n", "Missing field batch_stats in state dict while restoring an instance of CustomTrainState, at path ./model\n", "\n", "ValueError when target state misses a recorded field:\n", "Unknown field(s) \"batch_stats\" in state dict while restoring an instance of TrainState at path ./model\n" ] } ], "source": [ "class CustomTrainState(train_state.TrainState):\n", "    batch_stats: Any = None\n", "\n", "custom_state = CustomTrainState.create(\n", "    apply_fn=state.apply_fn,\n", "    params=state.params,\n", "    tx=state.tx,\n", "    batch_stats=np.arange(10),\n", ")\n", "\n", "# Use a custom state to read the old `TrainState` checkpoint.\n", "custom_target = {'model': custom_state, 'config': None, 'data': [jnp.zeros_like(x1)]}\n", "try:\n", "    checkpoints.restore_checkpoint(ckpt_dir, target=custom_target, step=0)\n", "except ValueError as e:\n", "    print('ValueError when target state has an unmentioned field:')\n", "    print(e)\n", "    print('')\n", "\n", "\n", "# Use the old `TrainState` to read the custom state checkpoint.\n", "custom_ckpt = {'model': custom_state, 'config': config, 'data': [x1]}\n", "checkpoints.save_checkpoint(ckpt_dir, custom_ckpt, step=1, overwrite=True,\n", "                            keep=2, orbax_checkpointer=orbax_checkpointer)\n", "try:\n", "    checkpoints.restore_checkpoint(ckpt_dir, target=target, step=1)\n", "except ValueError as e:\n", "    print('ValueError when target state misses a recorded field:')\n", "    print(e)\n" ] }, { "cell_type": "markdown", "id": "379c2255", "metadata": {}, "source": [ "It is recommended to keep your checkpoints up to date with your pytree dataclass definitions. 
For example, you can keep a copy of your code along with your checkpoints.\n", "\n", "But if you must restore checkpoints and Flax dataclasses with incompatible fields, you can manually add/remove corresponding fields before passing in the correct target structure:" ] }, { "cell_type": "code", "execution_count": 9, "id": "29fd1e33", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'model': CustomTrainState(step=0, apply_fn=<bound method Module.apply of Dense(...)>, params=FrozenDict({\n", " bias: array([-0.001, -0.001, -0.001], dtype=float32),\n", " kernel: array([[ 0.26048955, -0.61399287, -0.23458514],\n", " [ 0.11050402, -0.8765793 , 0.9800635 ],\n", " [ 0.36260957, 0.18276349, -0.6856061 ],\n", " [-0.8519373 , -0.6416717 , -0.4818122 ],\n", " [-0.6886102 , -0.33987316, -0.05898903]], dtype=float32),\n", " }), tx=GradientTransformation(init=<function chain.<locals>.init_fn at 0x14d093b80>, update=<function chain.<locals>.update_fn at 0x14d0938b0>), opt_state=(EmptyState(), EmptyState()), batch_stats=array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])),\n", " 'config': {'dimensions': array([5, 3]), 'name': 'dense'},\n", " 'data': [array([0.59902626, 0.2172144 , 2.4202902 , 0.03266738, 1.2164948 ],\n", " dtype=float32)]}" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Pass no target to get a raw state dictionary first.\n", "raw_state_dict = checkpoints.restore_checkpoint(ckpt_dir, target=None, step=0)\n", "# Add/remove fields as needed.\n", "raw_state_dict['model']['batch_stats'] = np.arange(10)\n", "# Restore the classes with the correct target now.\n", "serialization.from_state_dict(custom_target, raw_state_dict)" ] }, { "cell_type": "markdown", "id": "a6b39501", "metadata": {}, "source": [ "## Asynchronous checkpointing\n", "\n", "Checkpointing is I/O heavy, and if you have a large amount of data to save, it may be worthwhile to run the save in a background thread while continuing with your training.\n", "\n", "You can do this by creating an [`orbax.AsyncCheckpointer`](https://github.com/google/orbax/blob/main/orbax/checkpoint/async_checkpointer.py) (as demonstrated in the code cell below) and letting it track your save thread.\n", "\n", "Note: You should use the same `async_checkpointer` to handle all your async saves across your training steps, so that it can ensure that a previous async save is done before the next one begins. This keeps bookkeeping options, such as `keep` (the number of checkpoints) and `overwrite`, consistent across steps.\n", "\n", "Whenever you want to explicitly wait until an async save is done, you can call `async_checkpointer.wait_until_finished()`. Alternatively, you can pass in `orbax_checkpointer=async_checkpointer` when running `restore_checkpoint`, and Flax will automatically wait and restore safely."
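, "\n", "\n", "As a sketch of that second option (it uses the same `orbax_checkpointer` argument demonstrated later in this guide):\n", "\n", "```python\n", "# Flax waits for any in-flight async save before reading the checkpoint.\n", "restored = checkpoints.restore_checkpoint(\n", "    ckpt_dir, target=None, step=2, orbax_checkpointer=async_checkpointer)\n", "```"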
] }, { "cell_type": "code", "execution_count": 10, "id": "85be68a6", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'config': {'dimensions': array([5, 3]), 'name': 'dense'},\n", " 'data': {'0': array([0.59902626, 0.2172144 , 2.4202902 , 0.03266738, 1.2164948 ],\n", " dtype=float32)},\n", " 'model': {'opt_state': {'0': {}, '1': {}},\n", " 'params': {'bias': array([-0.001, -0.001, -0.001], dtype=float32),\n", " 'kernel': array([[ 0.26048955, -0.61399287, -0.23458514],\n", " [ 0.11050402, -0.8765793 , 0.9800635 ],\n", " [ 0.36260957, 0.18276349, -0.6856061 ],\n", " [-0.8519373 , -0.6416717 , -0.4818122 ],\n", " [-0.6886102 , -0.33987316, -0.05898903]], dtype=float32)},\n", " 'step': 1}}" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# `orbax.AsyncCheckpointer` needs some multi-process initialization, because it was\n", "# originally designed for multi-process large model checkpointing.\n", "# For Python notebooks or other single-process settings, just set it up with `num_processes=1`.\n", "# Refer to https://jax.readthedocs.io/en/latest/multi_process.html#initializing-the-cluster\n", "# for how to set it up in multi-process scenarios.\n", "jax.distributed.initialize(\"localhost:8889\", num_processes=1, process_id=0)\n", "\n", "async_checkpointer = orbax.AsyncCheckpointer(orbax.PyTreeCheckpointHandler(), timeout_secs=50)\n", "\n", "# Mimic a training loop here:\n", "for step in range(2, 3):\n", "    checkpoints.save_checkpoint(ckpt_dir, ckpt, step=step, overwrite=True, keep=3,\n", "                                orbax_checkpointer=async_checkpointer)\n", "    # ... Continue with your work...\n", "\n", "# ... Until a time when you want to wait until the save completes:\n", "async_checkpointer.wait_until_finished()  # Blocks until the checkpoint saving is completed.\n", "checkpoints.restore_checkpoint(ckpt_dir, target=None, step=2)" ] }, { "cell_type": "markdown", "id": "13e93db6", "metadata": {}, "source": [ "## Multi-host/multi-process checkpointing\n", "\n", "JAX provides a few ways to scale up your code on multiple hosts at the same time. This usually happens when the number of devices (CPU/GPU/TPU) is so large that different devices are managed by different hosts (CPUs). To get started on JAX in multi-process settings, check out [Using JAX in multi-host and multi-process environments](https://jax.readthedocs.io/en/latest/multi_process.html) and the [distributed array guide](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html).\n", "\n", "In the [Single Program Multi Data (SPMD)](https://jax.readthedocs.io/en/latest/glossary.html#term-SPMD) paradigm with JAX [`pjit`](https://jax.readthedocs.io/en/latest/jax.experimental.pjit.html), a large multi-process array can have its data sharded across different devices (check out the `pjit` [JAX-101 tutorial](https://jax.readthedocs.io/en/latest/jax-101/08-pjit.html)). When a multi-process array is serialized, each host dumps its data shards to a single shared storage, such as a Google Cloud bucket.\n", "\n", "Orbax supports saving and loading pytrees with multi-process arrays in the same fashion as single-process pytrees. 
However, it's recommended to use the asynchronous [`orbax.AsyncCheckpointer`](https://github.com/google/orbax/blob/main/orbax/checkpoint/async_checkpointer.py) to save large multi-process arrays on another thread, so that you can perform computation alongside the saves.\n", "\n", "To save multi-process arrays, use [`flax.training.checkpoints.save_checkpoint_multiprocess()`](https://flax.readthedocs.io/en/latest/api_reference/flax.training.html#flax.training.checkpoints.save_checkpoint_multiprocess) in place of `save_checkpoint()`, with the same arguments.\n", "\n", "Unfortunately, Python Jupyter notebooks are single-host only and cannot activate the multi-host mode. You can treat the following code as a reference for running multi-host checkpointing:" ] }, { "cell_type": "code", "execution_count": 11, "id": "d199c8fa", "metadata": {}, "outputs": [], "source": [ "# Multi-host related imports.\n", "from jax.experimental import maps, PartitionSpec, pjit" ] }, { "cell_type": "code", "execution_count": 12, "id": "ubdUvyMrhD-1", "metadata": {}, "outputs": [], "source": [ "# Create a multi-process array.\n", "mesh_shape = (4, 2)\n", "devices = np.asarray(jax.devices()).reshape(*mesh_shape)\n", "mesh = maps.Mesh(devices, ('x', 'y'))\n", "\n", "f = pjit.pjit(\n", "    lambda x: x,\n", "    in_axis_resources=None,\n", "    out_axis_resources=PartitionSpec('x', 'y'))\n", "\n", "with maps.Mesh(mesh.devices, mesh.axis_names):\n", "    mp_array = f(np.arange(8 * 2).reshape(8, 2))\n", "\n", "# Make it a pytree as usual.\n", "mp_ckpt = {'model': mp_array}" ] }, { "cell_type": "markdown", "id": "edc355ce", "metadata": {}, "source": [ "### Example: Save a checkpoint in a multi-process setting with `save_checkpoint_multiprocess`\n", "\n", "The arguments in [`flax.training.checkpoints.save_checkpoint_multiprocess`](https://flax.readthedocs.io/en/latest/api_reference/flax.training.html#flax.training.checkpoints.save_checkpoint_multiprocess) are the same as in [`flax.training.checkpoints.save_checkpoint`](https://flax.readthedocs.io/en/latest/api_reference/flax.training.html#flax.training.checkpoints.save_checkpoint).\n", "\n", "If your checkpoint is too large, you can increase `timeout_secs` when creating the `orbax.AsyncCheckpointer` to give it more time to finish writing." ] }, { "cell_type": "code", "execution_count": 13, "id": "5d10039b", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'tmp/flax-checkpointing/checkpoint_3'" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "async_checkpointer = orbax.AsyncCheckpointer(orbax.PyTreeCheckpointHandler(), timeout_secs=50)\n", "checkpoints.save_checkpoint_multiprocess(ckpt_dir,\n", "                                         mp_ckpt,\n", "                                         step=3,\n", "                                         overwrite=True,\n", "                                         keep=4,\n", "                                         orbax_checkpointer=async_checkpointer)" ] }, { "cell_type": "markdown", "id": "d954c3c7", "metadata": {}, "source": [ "### Example: Restore a checkpoint with `flax.training.checkpoints.restore_checkpoint`\n", "\n", "Note that, when using [`flax.training.checkpoints.restore_checkpoint`](https://flax.readthedocs.io/en/latest/api_reference/flax.training.html#flax.training.checkpoints.restore_checkpoint), you need to pass a `target` with valid multi-process arrays at the correct structural location. Flax only uses the `target` arrays' meshes and mesh axes to restore the checkpoint. 
This means the multi-process array in the `target` arg doesn't have to be as large as the array in your checkpoint; its shape doesn't need to match, because only its mesh and mesh axes are used." ] }, { "cell_type": "code", "execution_count": 14, "id": "a9f9724c", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'model': array([[ 0, 1],\n", " [ 2, 3],\n", " [ 4, 5],\n", " [ 6, 7],\n", " [ 8, 9],\n", " [10, 11],\n", " [12, 13],\n", " [14, 15]], dtype=int32)}" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "with maps.Mesh(mesh.devices, mesh.axis_names):\n", "    mp_smaller_array = f(np.zeros(8).reshape(4, 2))\n", "\n", "mp_target = {'model': mp_smaller_array}\n", "mp_restored = checkpoints.restore_checkpoint(ckpt_dir,\n", "                                             target=mp_target,\n", "                                             step=3,\n", "                                             orbax_checkpointer=async_checkpointer)\n", "mp_restored" ] } ], "metadata": { "colab": { "collapsed_sections": [], "provenance": [] }, "gpuClass": "standard", "jupytext": { "formats": "ipynb,md" }, "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.15" } }, "nbformat": 4, "nbformat_minor": 5 }