flax.training package#
Checkpoints#
Checkpointing helper functions.
Handles saving and restoring optimizer checkpoints based on a step number or other numerical metric encoded in the filename. Cleans up older or worse-performing checkpoint files.
- flax.training.checkpoints.save_checkpoint(ckpt_dir, target, step, prefix='checkpoint_', keep=1, overwrite=False, keep_every_n_steps=None, async_manager=None, orbax_checkpointer=None)[source]#
Save a checkpoint of the model. Suitable for single-host.
In this method, every JAX process saves the checkpoint on its own. Do not use it if you have multiple processes and you intend for them to save data to a common directory (e.g., a GCloud bucket). To save multi-process checkpoints to a shared storage, or to save ``GlobalDeviceArray``s, use ``save_checkpoint_multiprocess()`` instead.
Pre-emption safe: the checkpoint is written to a temporary location before a final rename and cleanup of past files. However, if ``async_manager`` is used, the final commit will happen inside an async callback, which can be explicitly waited on by calling ``async_manager.wait_previous_save()``.
Example usage:
>>> from flax.training import checkpoints
>>> import jax.numpy as jnp
>>> import tempfile
>>> with tempfile.TemporaryDirectory() as dir_path:
...   test_object = {
...     'a': jnp.array([1, 2, 3], jnp.int32),
...     'b': jnp.array([1, 1, 1], jnp.int32),
...   }
...   file_path = checkpoints.save_checkpoint(
...     dir_path, target=test_object, step=0, prefix='test_', keep=1
...   )
...   restored_object = checkpoints.restore_checkpoint(
...     file_path, target=None
...   )
>>> restored_object
{'a': array([1, 2, 3], dtype=int32), 'b': array([1, 1, 1], dtype=int32)}
- Parameters
ckpt_dir – str or pathlib-like path to store checkpoint files in.
target – serializable flax object, usually a flax optimizer.
step – int or float: training step number or other metric number.
prefix – str: checkpoint file name prefix.
keep – number of past checkpoint files to keep.
overwrite – overwrite existing checkpoint files if a checkpoint at the current or a later step already exists (default: False).
keep_every_n_steps – if defined, keep one checkpoint every n steps (in addition to keeping the last ``keep`` checkpoints).
async_manager – if defined, the save will run without blocking the main thread. Only works for single host. Note that an ongoing save will still block subsequent saves, to make sure overwrite/keep logic works correctly.
orbax_checkpointer – if defined, the save will be done by Orbax. In the future, all Flax checkpointing features will be migrated to Orbax, and starting to use an ``orbax_checkpointer`` is recommended. Please check out the checkpointing guide (https://flax.readthedocs.io/en/latest/guides/training_techniques/use_checkpointing.html#save-checkpoints) for how to use Orbax checkpointers.
- Returns
Filename of saved checkpoint.
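As a supplement to the doctest above, a minimal sketch (not from the original docstring; the ``ocp`` alias for ``orbax.checkpoint`` and the pytree contents are illustrative) of routing the save through an Orbax checkpointer together with the keep options:

import tempfile
import jax.numpy as jnp
import orbax.checkpoint as ocp
from flax.training import checkpoints

ckpt = {'params': {'w': jnp.zeros((2, 2))}, 'step': 0}  # illustrative pytree
with tempfile.TemporaryDirectory() as ckpt_dir:
    # Passing an Orbax checkpointer makes Orbax handle the underlying save,
    # which is the recommended path going forward.
    orbax_checkpointer = ocp.PyTreeCheckpointer()
    checkpoints.save_checkpoint(
        ckpt_dir, target=ckpt, step=0,
        keep=3, keep_every_n_steps=1000,
        orbax_checkpointer=orbax_checkpointer)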
- flax.training.checkpoints.save_checkpoint_multiprocess(ckpt_dir, target, step, prefix='checkpoint_', keep=1, overwrite=False, keep_every_n_steps=None, async_manager=None, gda_manager=None, orbax_checkpointer=None)[source]#
Save a checkpoint of the model in a multi-process environment.
Use this method to save ``GlobalDeviceArray``s, or to save data to a common directory. Only process 0 will save the main checkpoint file and remove old checkpoint files.
Pre-emption safe: the checkpoint is written to a temporary location before a final rename and cleanup of past files. However, if ``async_manager`` or ``gda_manager`` is used, the final commit will happen inside an async callback, which can be explicitly waited on by calling ``async_manager.wait_previous_save()`` or ``gda_manager.wait_until_finished()``.
- Parameters
ckpt_dir – str or pathlib-like path to store checkpoint files in.
target – serializable flax object, usually a flax optimizer.
step – int or float: training step number or other metric number.
prefix – str: checkpoint file name prefix.
keep – number of past checkpoint files to keep.
overwrite – overwrite existing checkpoint files if a checkpoint at the current or a later step already exists (default: False).
keep_every_n_steps – if defined, keep one checkpoint every n steps (in addition to keeping the last ``keep`` checkpoints).
async_manager – if defined, the save will run without blocking the main thread. Only works for single host. Note that an ongoing save will still block subsequent saves, to make sure overwrite/keep logic works correctly.
gda_manager – required if target contains a JAX GlobalDeviceArray. Will save the GDAs to a separate subdirectory with postfix “_gda” asynchronously. Same as async_manager, this will block subsequent saves.
orbax_checkpointer – if defined, the save will be done by Orbax. In the future, all Flax checkpointing features will be migrated to Orbax, and starting to use an ``orbax_checkpointer`` is recommended. Please check out the checkpointing guide (https://flax.readthedocs.io/en/latest/guides/training_techniques/use_checkpointing.html#save-checkpoints) for how to use Orbax checkpointers.
- Returns
Filename of saved checkpoint.
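A sketch only (the bucket path and the pytree being saved are hypothetical); ``checkpoints.AsyncManager`` is the manager whose ``wait_previous_save()`` method is mentioned above:

import jax.numpy as jnp
from flax.training import checkpoints

state = {'params': {'w': jnp.zeros((2, 2))}}  # illustrative pytree to save
async_manager = checkpoints.AsyncManager()
# Every process calls this; only process 0 writes the main checkpoint file.
checkpoints.save_checkpoint_multiprocess(
    'gs://my-bucket/ckpts',   # hypothetical shared directory
    target=state, step=100, keep=3,
    async_manager=async_manager)
# Block until the previous async save has committed, if needed:
async_manager.wait_previous_save()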
- flax.training.checkpoints.latest_checkpoint(ckpt_dir, prefix='checkpoint_')[source]#
Retrieve the path of the latest checkpoint in a directory.
- Parameters
ckpt_dir – str: directory of checkpoints to restore from.
prefix – str: name prefix of checkpoint files.
- Returns
The latest checkpoint path or None if no checkpoints were found.
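A small illustrative sketch (the directory path is hypothetical):

from flax.training import checkpoints

latest = checkpoints.latest_checkpoint('/tmp/ckpts', prefix='checkpoint_')
if latest is None:
    # No checkpoint found; initialize training from scratch.
    pass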
- flax.training.checkpoints.restore_checkpoint(ckpt_dir, target, step=None, prefix='checkpoint_', parallel=True, gda_manager=None, allow_partial_mpa_restoration=False, orbax_checkpointer=None, orbax_transforms=None)[source]#
Restore last/best checkpoint from checkpoints in path.
Sorts the checkpoint files naturally, returning the highest-valued file, e.g.:
ckpt_1, ckpt_2, ckpt_3 --> ckpt_3
ckpt_0.01, ckpt_0.1, ckpt_0.001 --> ckpt_0.1
ckpt_-1.0, ckpt_1.0, ckpt_1e5 --> ckpt_1e5
Example usage:
>>> from flax.training import checkpoints
>>> import jax.numpy as jnp
>>> import tempfile
>>> with tempfile.TemporaryDirectory() as dir_path:
...   test_object = {
...     'a': jnp.array([1, 2, 3], jnp.int32),
...     'b': jnp.array([1, 1, 1], jnp.int32),
...   }
...   file_path = checkpoints.save_checkpoint(
...     dir_path, target=test_object, step=0, prefix='test_', keep=1
...   )
...   restored_object = checkpoints.restore_checkpoint(
...     file_path, target=None
...   )
>>> restored_object
{'a': array([1, 2, 3], dtype=int32), 'b': array([1, 1, 1], dtype=int32)}
- Parameters
ckpt_dir – str: checkpoint file or directory of checkpoints to restore from.
target – matching object to rebuild via deserialized state-dict. If None, the deserialized state-dict is returned as-is.
step – int or float: step number to load or None to load latest. If specified, ckpt_dir must be a directory.
prefix – str: name prefix of checkpoint files.
parallel – bool: whether to load seekable checkpoints in parallel, for speed.
gda_manager – required if checkpoint contains a multiprocess array (GlobalDeviceArray or jax Array from pjit). Will read the arrays from the separate subdirectory with postfix “_gda”.
allow_partial_mpa_restoration – If true, the given ``target`` doesn’t have to contain all valid multiprocess arrays. As a result, the restored Pytree may have some MPAs not restored correctly. Use this if you cannot provide a fully valid ``target`` and don’t need all the MPAs in the checkpoint to be restored.
orbax_checkpointer – the ``ocp.Checkpointer`` that handles the underlying restore, if the given checkpoint is saved with ocp.
orbax_transforms – the Orbax transformations that will be passed into the ``orbax_checkpointer.restore()`` call.
- Returns
Restored ``target`` updated from the checkpoint file, or, if no step is specified and no checkpoint files are present, the passed-in ``target`` unchanged. If a file path is specified and is not found, the passed-in ``target`` will be returned. This is to match the behavior of the case where a directory path is specified but the directory has not yet been created.
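A short sketch of restoring into an existing object so the result keeps its structure and types (the directory and ``initial_state`` pytree are hypothetical):

import jax.numpy as jnp
from flax.training import checkpoints

initial_state = {'params': {'w': jnp.zeros((2, 2))}}  # hypothetical matching structure
restored_state = checkpoints.restore_checkpoint(
    ckpt_dir='/tmp/ckpts',   # hypothetical checkpoint directory
    target=initial_state,
    step=None,               # None restores the latest checkpoint
    prefix='checkpoint_')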
- flax.training.checkpoints.convert_pre_linen(params)[source]#
Converts a pre-Linen parameter pytree.
In pre-Linen API submodules were numbered incrementally, independent of the submodule class. With Linen this behavior has changed to keep separate submodule counts per module class.
Consider the following module:
class Model(nn.Module):
  @nn.compact
  def __call__(self, x):
    x = nn.Conv(1, 1)(x)
    x = nn.Dense(1)(x)
    return x
In pre-Linen the resulting params would have had the structure:
{'Conv_0': { ... }, 'Dense_1': { ... } }
With Linen the resulting params would instead have had the structure:
{'Conv_0': { ... }, 'Dense_0': { ... } }
To convert from pre-Linen format to Linen simply call:
params = convert_pre_linen(pre_linen_params)
Note that you can also use this utility to convert pre-Linen collections because they follow the same module naming. Note, though, that collections were “flat” in pre-Linen and first need to be unflattened before they can be used with this function:
batch_stats = convert_pre_linen(flax.traverse_util.unflatten_dict({
    tuple(k.split('/')[1:]): v
    for k, v in pre_linen_model_state.as_dict().items()
}))
Then Linen variables can be defined from these converted collections:
variables = {'params': params, 'batch_stats': batch_stats}
- Parameters
params – Parameter pytree in pre-Linen format. If the pytree is already in Linen format, then the returned pytree is unchanged (i.e. this function can safely be called on any loaded checkpoint for use with Linen).
- Returns
Parameter pytree with Linen submodule naming.
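A small self-contained sketch mirroring the structure shown above (array shapes are illustrative):

import jax.numpy as jnp
from flax.training import checkpoints

# Illustrative pre-Linen parameter pytree (incremental submodule numbering).
pre_linen_params = {
    'Conv_0': {'kernel': jnp.zeros((1, 1, 1, 1)), 'bias': jnp.zeros((1,))},
    'Dense_1': {'kernel': jnp.zeros((1, 1)), 'bias': jnp.zeros((1,))},
}
params = checkpoints.convert_pre_linen(pre_linen_params)
# Expected result keys: 'Conv_0' and 'Dense_0' (per-class counters).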
Learning rate schedules#
Learning rate schedules used in FLAX image classification examples.
Note that with FLIP #1009 learning rate schedules in ``flax.training`` are effectively deprecated in favor of Optax schedules. Please refer to Optimizer Schedules for more information.
- flax.training.lr_schedule.create_constant_learning_rate_schedule(base_learning_rate, steps_per_epoch, warmup_length=0.0)[source]#
Create a constant learning rate schedule with optional warmup.
Note that with FLIP #1009 learning rate schedules in ``flax.training`` are effectively deprecated in favor of Optax schedules. Please refer to Optimizer Schedules for more information.
Holds the learning rate constant. This function also offers a learning rate warmup as per https://arxiv.org/abs/1706.02677, for the purpose of training with large mini-batches.
- Parameters
base_learning_rate – the base learning rate
steps_per_epoch – the number of iterations per epoch
warmup_length – if > 0, the learning rate will be modulated by a warmup factor that will linearly ramp up from 0 to 1 over the first ``warmup_length`` epochs
- Returns
Function ``f(step) -> lr`` that computes the learning rate for a given step.
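A brief sketch (the hyperparameter values are illustrative):

from flax.training import lr_schedule

# Constant learning rate of 0.1 with a half-epoch linear warmup.
schedule_fn = lr_schedule.create_constant_learning_rate_schedule(
    base_learning_rate=0.1, steps_per_epoch=1000, warmup_length=0.5)
lr = schedule_fn(250)  # query the learning rate at a given step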
- flax.training.lr_schedule.create_stepped_learning_rate_schedule(base_learning_rate, steps_per_epoch, lr_sched_steps, warmup_length=0.0)[source]#
Create a stepped learning rate schedule with optional warmup.
Note that with FLIP #1009 learning rate schedules in ``flax.training`` are effectively deprecated in favor of Optax schedules. Please refer to Optimizer Schedules for more information.
A stepped learning rate schedule decreases the learning rate by specified amounts at specified epochs. The steps are given as the ``lr_sched_steps`` parameter. A common ImageNet schedule decays the learning rate by a factor of 0.1 at epochs 30, 60 and 80. This would be specified as:
[[30, 0.1], [60, 0.01], [80, 0.001]]
This function also offers a learning rate warmup as per https://arxiv.org/abs/1706.02677, for the purpose of training with large mini-batches.
- Parameters
base_learning_rate – the base learning rate
steps_per_epoch – the number of iterations per epoch
lr_sched_steps – the schedule as a list of steps, each of which is an ``[epoch, lr_factor]`` pair; the step occurs at epoch ``epoch`` and sets the learning rate to ``base_learning_rate * lr_factor``
warmup_length – if > 0, the learning rate will be modulated by a warmup factor that will linearly ramp up from 0 to 1 over the first ``warmup_length`` epochs
- Returns
Function ``f(step) -> lr`` that computes the learning rate for a given step.
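A brief sketch using the ImageNet-style schedule from above (values are illustrative):

from flax.training import lr_schedule

schedule_fn = lr_schedule.create_stepped_learning_rate_schedule(
    base_learning_rate=0.1,
    steps_per_epoch=1000,
    lr_sched_steps=[[30, 0.1], [60, 0.01], [80, 0.001]],
    warmup_length=5)
# After warmup, steps in epochs 30-59 should use 0.1 * 0.1 = 0.01.
lr = schedule_fn(35 * 1000)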
- flax.training.lr_schedule.create_cosine_learning_rate_schedule(base_learning_rate, steps_per_epoch, halfcos_epochs, warmup_length=0.0)[source]#
Create a cosine learning rate schedule with optional warmup.
Note that with FLIP #1009 learning rate schedules in ``flax.training`` are effectively deprecated in favor of Optax schedules. Please refer to Optimizer Schedules for more information.
A cosine learning rate schedule modulates the learning rate with half a cosine wave, gradually scaling it to 0 at the end of training.
This function also offers a learning rate warmup as per https://arxiv.org/abs/1706.02677, for the purpose of training with large mini-batches.
- Parameters
base_learning_rate – the base learning rate
steps_per_epoch – the number of iterations per epoch
halfcos_epochs – the number of epochs to complete half a cosine wave; normally the number of epochs used for training
warmup_length – if > 0, the learning rate will be modulated by a warmup factor that will linearly ramp up from 0 to 1 over the first ``warmup_length`` epochs
- Returns
Function ``f(step) -> lr`` that computes the learning rate for a given step.
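A brief sketch, plus a roughly comparable Optax schedule as suggested by the deprecation note (hyperparameters are illustrative and the two schedules are not guaranteed to match exactly):

import optax
from flax.training import lr_schedule

# flax.training version: cosine decay over 90 epochs with a 5-epoch warmup.
schedule_fn = lr_schedule.create_cosine_learning_rate_schedule(
    base_learning_rate=0.1, steps_per_epoch=1000,
    halfcos_epochs=90, warmup_length=5)

# Approximate Optax counterpart (decay_steps includes the warmup steps).
optax_schedule = optax.warmup_cosine_decay_schedule(
    init_value=0.0, peak_value=0.1,
    warmup_steps=5 * 1000, decay_steps=95 * 1000)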
Train state#
- class flax.training.train_state.TrainState(step, apply_fn, params, tx, opt_state)[source]#
Simple train state for the common case with a single Optax optimizer.
Example usage:
>>> import flax.linen as nn
>>> from flax.training.train_state import TrainState
>>> import jax, jax.numpy as jnp
>>> import optax

>>> x = jnp.ones((1, 2))
>>> y = jnp.ones((1, 2))
>>> model = nn.Dense(2)
>>> variables = model.init(jax.random.key(0), x)
>>> tx = optax.adam(1e-3)

>>> state = TrainState.create(
...     apply_fn=model.apply,
...     params=variables['params'],
...     tx=tx)

>>> def loss_fn(params, x, y):
...   predictions = state.apply_fn({'params': params}, x)
...   loss = optax.l2_loss(predictions=predictions, targets=y).mean()
...   return loss
>>> loss_fn(state.params, x, y)
Array(3.3514676, dtype=float32)

>>> grads = jax.grad(loss_fn)(state.params, x, y)
>>> state = state.apply_gradients(grads=grads)
>>> loss_fn(state.params, x, y)
Array(3.343844, dtype=float32)
Note that you can easily extend this dataclass by subclassing it for storing additional data (e.g. additional variable collections).
For more exotic use cases (e.g. multiple optimizers) it’s probably best to fork the class and modify it.
- Parameters
step – Counter starts at 0 and is incremented by every call to ``.apply_gradients()``.
apply_fn – Usually set to ``model.apply()``. Kept in this dataclass for convenience to have a shorter params list for the ``train_step()`` function in your training loop.
params – The parameters to be updated by ``tx`` and used by ``apply_fn``.
tx – An Optax gradient transformation.
opt_state – The state for ``tx``.
- apply_gradients(*, grads, **kwargs)[source]#
Updates ``step``, ``params``, ``opt_state`` and ``**kwargs`` in the return value.
Note that internally this function calls ``.tx.update()`` followed by a call to ``optax.apply_updates()`` to update ``params`` and ``opt_state``.
- Parameters
grads – Gradients that have the same pytree structure as ``.params``.
**kwargs – Additional dataclass attributes that should be ``.replace()``-ed.
- Returns
An updated instance of ``self`` with ``step`` incremented by one, ``params`` and ``opt_state`` updated by applying ``grads``, and additional attributes replaced as specified by ``kwargs``.
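As noted above, the dataclass can be extended by subclassing to store additional variable collections. A minimal sketch (the ``TrainStateWithStats`` name and the BatchNorm module are illustrative):

import jax, jax.numpy as jnp
import flax.linen as nn
import optax
from typing import Any
from flax.training import train_state

class TrainStateWithStats(train_state.TrainState):
  # Extra variable collection carried alongside params and opt_state.
  batch_stats: Any

model = nn.BatchNorm(use_running_average=False)  # illustrative module with batch_stats
variables = model.init(jax.random.key(0), jnp.ones((1, 2)))
state = TrainStateWithStats.create(
    apply_fn=model.apply,
    params=variables['params'],
    batch_stats=variables['batch_stats'],  # extra field passed through create()
    tx=optax.adam(1e-3))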
Early Stopping#
- class flax.training.early_stopping.EarlyStopping(min_delta=0, patience=0, best_metric=inf, patience_count=0, should_stop=False, has_improved=False)[source]#
Early stopping to avoid overfitting during training.
The following example stops training early if the difference between losses recorded in the current epoch and previous epoch is less than 1e-3 for 2 consecutive epochs:
>>> from flax.training.early_stopping import EarlyStopping

>>> def train_epoch(optimizer, train_ds, batch_size, epoch, input_rng):
...   ...
...   loss = [4, 3, 3, 3, 2, 2, 2, 2, 1, 1][epoch]
...   return None, {'loss': loss}

>>> early_stop = EarlyStopping(min_delta=1e-3, patience=2)
>>> optimizer = None
>>> for epoch in range(10):
...   optimizer, train_metrics = train_epoch(
...     optimizer=optimizer, train_ds=None, batch_size=None, epoch=epoch, input_rng=None)
...   early_stop = early_stop.update(train_metrics['loss'])
...   if early_stop.should_stop:
...     print(f'Met early stopping criteria, breaking at epoch {epoch}')
...     break
Met early stopping criteria, breaking at epoch 7
- min_delta#
Minimum delta between updates to be considered an improvement.
- Type
float
- patience#
Number of steps of no improvement before stopping.
- Type
int
- best_metric#
Current best metric value.
- Type
float
- patience_count#
Number of steps since last improving update.
- Type
int
- should_stop#
Whether the training loop should stop to avoid overfitting.
- Type
bool
- has_improved#
Whether the metric has improved by at least ``min_delta`` in the last ``.update`` call.
- Type
bool
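In addition to ``should_stop``, the ``has_improved`` flag can drive best-model checkpointing. A small sketch (the per-epoch validation losses are illustrative):

from flax.training.early_stopping import EarlyStopping

val_losses = [0.9, 0.8, 0.8, 0.8]   # illustrative per-epoch validation losses
early_stop = EarlyStopping(min_delta=1e-3, patience=2)
for val_loss in val_losses:
    early_stop = early_stop.update(val_loss)   # returns an updated instance
    if early_stop.has_improved:
        pass   # e.g. checkpoint the current best model here
    if early_stop.should_stop:
        break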
Common Utilities#
- flax.training.common_utils.shard(xs)[source]#
Helper for pmap to shard a pytree of arrays by local_device_count.
- Parameters
xs – a pytree of arrays.
- Returns
A matching pytree with arrays’ leading dimensions sharded by the local device count.
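A short sketch (the batch contents are illustrative; the leading dimension must be divisible by the local device count):

import jax.numpy as jnp
from flax.training import common_utils

batch = {'image': jnp.zeros((32, 28, 28, 1)), 'label': jnp.zeros((32,), jnp.int32)}
# Leading dim 32 becomes (num_local_devices, 32 // num_local_devices, ...).
sharded_batch = common_utils.shard(batch)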
- flax.training.common_utils.shard_prng_key(prng_key)[source]#
Helper to shard (aka split) a PRNGKey for use with pmap’d functions.
PRNG keys can be used at train time to drive stochastic modules e.g. Dropout. We would like a different PRNG key for each local device so that we end up with different random numbers on each one, hence we split our PRNG key.
- Parameters
prng_key – JAX PRNGKey
- Returns
A new array of PRNGKeys with leading dimension equal to local device count.
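A short sketch (``p_train_step`` is a hypothetical pmapped function):

import jax
from flax.training import common_utils

rng = jax.random.PRNGKey(0)
# One PRNG key per local device, stacked along a new leading axis.
dropout_rngs = common_utils.shard_prng_key(rng)
# state = p_train_step(state, batch, dropout_rng=dropout_rngs)  # hypothetical call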
- flax.training.common_utils.stack_forest(forest)[source]#
Helper function to stack the leaves of a sequence of pytrees.
- Parameters
forest – a sequence of pytrees (e.g tuple or list) of matching structure whose leaves are arrays with individually matching shapes.
- Returns
A single pytree of the same structure whose leaves are individually stacked arrays.
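A short illustrative sketch:

import jax.numpy as jnp
from flax.training import common_utils

forest = [{'loss': jnp.array(1.0)}, {'loss': jnp.array(2.0)}]
stacked = common_utils.stack_forest(forest)   # expected: {'loss': Array([1., 2.])}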
- flax.training.common_utils.get_metrics(device_metrics)[source]#
Helper utility for pmap, gathering replicated timeseries metric data.
- Parameters
device_metrics – replicated, device-resident pytree of metric data, whose leaves are presumed to be a sequence of arrays recorded over time.
- Returns
A pytree of unreplicated, host-resident, stacked-over-time arrays useful for computing host-local statistics and logging.
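A sketch of the typical pmap training-loop pattern, with the per-step device metrics simulated here so the snippet runs standalone (in practice they come out of a pmapped train step, e.g. averaged with ``jax.lax.pmean``; the metric names are illustrative):

import jax
import jax.numpy as jnp
from flax.training import common_utils

# Simulate two training steps' worth of device-resident, replicated metrics.
n_dev = jax.local_device_count()
metrics_history = [
    {'loss': jnp.full((n_dev,), 0.9), 'accuracy': jnp.full((n_dev,), 0.4)},
    {'loss': jnp.full((n_dev,), 0.7), 'accuracy': jnp.full((n_dev,), 0.5)},
]
train_metrics = common_utils.get_metrics(metrics_history)
summary = {k: v.mean() for k, v in train_metrics.items()}  # host-side aggregation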
- flax.training.common_utils.onehot(labels, num_classes, on_value=1.0, off_value=0.0)[source]#
Create a dense one-hot version of an indexed array.
NB: consider using the more standard ``jax.nn.one_hot`` instead.
- Parameters
labels – an n-dim JAX array whose last dimension contains integer indices.
num_classes – the maximum possible index.
on_value – the “on” value for the one-hot array, defaults to 1.0.
off_value – the “off” value for the one-hot array, defaults to 0.0.
- Returns
An (n+1)-dim array whose last dimension contains one-hot vectors of length ``num_classes``.
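A short illustrative sketch:

import jax.numpy as jnp
from flax.training import common_utils

labels = jnp.array([0, 2, 1])
one_hot = common_utils.onehot(labels, num_classes=3)
# expected: [[1., 0., 0.], [0., 0., 1.], [0., 1., 0.]]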