flax.training package#
Checkpoints#
Checkpointing helper functions.
Handles saving and restoring optimizer checkpoints based on a step number or other numerical metric encoded in the filename. Cleans up older or worse-performing checkpoint files.
- flax.training.checkpoints.save_checkpoint(ckpt_dir, target, step, prefix='checkpoint_', keep=1, overwrite=False, keep_every_n_steps=None, async_manager=None, gda_manager=None)[source]#
Save a checkpoint of the model.
Attempts to be pre-emption safe by writing to a temporary file before a final rename and cleanup of past files.
- Parameters
ckpt_dir – str or pathlib-like path to store checkpoint files in.
target – serializable flax object, usually a flax optimizer.
step – int or float: training step number or other metric number.
prefix – str: checkpoint file name prefix.
keep – number of past checkpoint files to keep.
overwrite – overwrite existing checkpoint files if a checkpoint at the current or a later step already exists (default: False).
keep_every_n_steps – if defined, keep a checkpoint every n steps (in addition to keeping the last ‘keep’ checkpoints).
async_manager – if defined, the save will run without blocking the main thread. Only works for single host. Note that an ongoing save will still block subsequent saves, to make sure overwrite/keep logic works correctly.
gda_manager – required if target contains a JAX GlobalDeviceArray. Will save the GDAs to a separate subdirectory with postfix “_gda” asynchronously. Same as async_manager, this will block subsequent saves.
- Returns
Filename of saved checkpoint.
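The pre-emption safety described above relies on the standard temp-file-then-rename pattern. The following is a plain-Python sketch of that general technique, not Flax’s actual implementation (the name atomic_save is hypothetical):

```python
import os
import tempfile

def atomic_save(path, data: bytes):
    # Write to a temporary file in the same directory, then rename it
    # into place. os.replace is atomic on POSIX filesystems, so a
    # pre-empted job never observes a partially written file at `path`.
    dirname = os.path.dirname(path) or '.'
    fd, tmp = tempfile.mkstemp(dir=dirname)
    try:
        with os.fdopen(fd, 'wb') as f:
            f.write(data)
        os.replace(tmp, path)  # atomic rename
    except BaseException:
        if os.path.exists(tmp):
            os.remove(tmp)  # clean up the temporary on failure
        raise
```

Writing to the same directory matters: `os.replace` is only guaranteed to be atomic when source and destination are on the same filesystem.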
- flax.training.checkpoints.latest_checkpoint(ckpt_dir, prefix='checkpoint_')[source]#
Retrieve the path of the latest checkpoint in a directory.
- Parameters
ckpt_dir – str: directory of checkpoints to restore from.
prefix – str: name prefix of checkpoint files.
- Returns
The latest checkpoint path or None if no checkpoints were found.
- flax.training.checkpoints.restore_checkpoint(ckpt_dir, target, step=None, prefix='checkpoint_', parallel=True, gda_manager=None)[source]#
Restore last/best checkpoint from checkpoints in path.
Sorts the checkpoint files naturally, returning the highest-valued file, e.g.:
ckpt_1, ckpt_2, ckpt_3 --> ckpt_3
ckpt_0.01, ckpt_0.1, ckpt_0.001 --> ckpt_0.1
ckpt_-1.0, ckpt_1.0, ckpt_1e5 --> ckpt_1e5
- Parameters
ckpt_dir – str: checkpoint file or directory of checkpoints to restore from.
target – matching object to rebuild via deserialized state-dict. If None, the deserialized state-dict is returned as-is.
step – int: step number to load or None to load latest. If specified, ckpt_dir must be a directory.
prefix – str: name prefix of checkpoint files.
parallel – bool: whether to load seekable checkpoints in parallel, for speed.
gda_manager – required if checkpoint contains a JAX GlobalDeviceArray. Will read the GDAs from the separate subdirectory with postfix “_gda”.
- Returns
Restored target updated from checkpoint file, or if no step specified and no checkpoint files present, returns the passed-in target unchanged. If a file path is specified and is not found, the passed-in target will be returned. This is to match the behavior of the case where a directory path is specified but the directory has not yet been created.
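The “natural” sort used to pick the highest-valued checkpoint (shown in the examples above) compares numeric chunks of the filename as numbers rather than strings. A pure-Python sketch of this idea, illustrative only and not the library’s exact implementation:

```python
import re

# Capture ints, decimals, and scientific notation (e.g. 1e5) as one chunk.
_NUM = r'([-+]?\d+(?:\.\d+)?(?:[eE][-+]?\d+)?)'

def natural_sort_key(name):
    # Split the name into text and numeric chunks; convert numeric
    # chunks to floats so "ckpt_10" sorts after "ckpt_2".
    return [float(c) if re.match(r'[-+]?\d', c) else c
            for c in re.split(_NUM, name)]

ckpts = ['ckpt_0.01', 'ckpt_0.1', 'ckpt_0.001']
latest = sorted(ckpts, key=natural_sort_key)[-1]  # 'ckpt_0.1'
```

With this key, all three examples from the docstring above sort as documented.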
- flax.training.checkpoints.convert_pre_linen(params)[source]#
Converts a pre-Linen parameter pytree.
In pre-Linen API submodules were numbered incrementally, independent of the submodule class. With Linen this behavior has changed to keep separate submodule counts per module class.
Consider the following module:
class Model(nn.Module):
  @nn.compact
  def __call__(self, x):
    x = nn.Conv(1, 1)(x)
    x = nn.Dense(1)(x)
    return x
In pre-Linen the resulting params would have had the structure:
{'Conv_0': { ... }, 'Dense_1': { ... } }
With Linen the resulting params would instead have had the structure:
{'Conv_0': { ... }, 'Dense_0': { ... } }
To convert from pre-Linen format to Linen simply call:
params = convert_pre_linen(pre_linen_params)
Note that you can also use this utility to convert pre-Linen collections, because they follow the same module naming. Note though that collections were “flat” in pre-Linen and first need to be unflattened before they can be used with this function:
batch_stats = convert_pre_linen(flax.traverse_util.unflatten_dict({
    tuple(k.split('/')[1:]): v
    for k, v in pre_linen_model_state.as_dict().items()
}))
Then Linen variables can be defined from these converted collections:
variables = {'params': params, 'batch_stats': batch_stats}
- Parameters
params – Parameter pytree in pre-Linen format. If the pytree is already in Linen format, then the returned pytree is unchanged (i.e. this function can safely be called on any loaded checkpoint for use with Linen).
- Returns
Parameter pytree with Linen submodule naming.
Learning rate schedules#
Learning rate schedules used in FLAX image classification examples.
Note that with FLIP #1009 learning rate schedules in flax.training are effectively deprecated in favor of Optax schedules. Please refer to Optimizer Schedules for more information.
- flax.training.lr_schedule.create_constant_learning_rate_schedule(base_learning_rate, steps_per_epoch, warmup_length=0.0)[source]#
Create a constant learning rate schedule with optional warmup.
Holds the learning rate constant. This function also offers a learning rate warmup as per https://arxiv.org/abs/1706.02677, for the purpose of training with large mini-batches.
- Parameters
base_learning_rate – the base learning rate
steps_per_epoch – the number of iterations per epoch
warmup_length – if > 0, the learning rate will be modulated by a warmup factor that will linearly ramp up from 0 to 1 over the first warmup_length epochs
- Returns
Function f(step) -> lr that computes the learning rate for a given step.
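The behavior can be sketched in plain Python. This is an illustration of the documented semantics (constant rate, linear warmup over warmup_length epochs), not the Flax implementation:

```python
def constant_schedule(base_learning_rate, steps_per_epoch, warmup_length=0.0):
    # Returns f(step) -> lr, matching the documented signature above.
    def schedule(step):
        lr = base_learning_rate
        if warmup_length > 0.0:
            epoch = step / steps_per_epoch
            # Linear ramp from 0 to 1 over the first warmup_length epochs.
            lr *= min(1.0, epoch / warmup_length)
        return lr
    return schedule

lr_fn = constant_schedule(0.1, steps_per_epoch=100, warmup_length=2.0)
lr_fn(100)  # epoch 1, halfway through warmup: 0.05
lr_fn(500)  # epoch 5, warmup finished: 0.1
```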
- flax.training.lr_schedule.create_stepped_learning_rate_schedule(base_learning_rate, steps_per_epoch, lr_sched_steps, warmup_length=0.0)[source]#
Create a stepped learning rate schedule with optional warmup.
A stepped learning rate schedule decreases the learning rate by specified amounts at specified epochs. The steps are given as the lr_sched_steps parameter. A common ImageNet schedule decays the learning rate by a factor of 0.1 at epochs 30, 60 and 80. This would be specified as:
[ [30, 0.1], [60, 0.01], [80, 0.001] ]
This function also offers a learning rate warmup as per https://arxiv.org/abs/1706.02677, for the purpose of training with large mini-batches.
- Parameters
base_learning_rate – the base learning rate
steps_per_epoch – the number of iterations per epoch
lr_sched_steps – the schedule as a list of steps, each of which is an [epoch, lr_factor] pair; the step occurs at epoch epoch and sets the learning rate to base_learning_rate * lr_factor
warmup_length – if > 0, the learning rate will be modulated by a warmup factor that will linearly ramp up from 0 to 1 over the first warmup_length epochs
- Returns
Function f(step) -> lr that computes the learning rate for a given step.
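The stepped decay can likewise be sketched in plain Python, using the ImageNet schedule from the text. Again an illustration of the documented semantics, not the Flax implementation:

```python
def stepped_schedule(base_learning_rate, steps_per_epoch, lr_sched_steps,
                     warmup_length=0.0):
    # Returns f(step) -> lr. lr_sched_steps is a list of
    # [epoch, lr_factor] pairs, assumed sorted by epoch.
    def schedule(step):
        epoch = step / steps_per_epoch
        factor = 1.0
        for boundary_epoch, lr_factor in lr_sched_steps:
            if epoch >= boundary_epoch:
                factor = lr_factor  # last boundary passed wins
        lr = base_learning_rate * factor
        if warmup_length > 0.0:
            lr *= min(1.0, epoch / warmup_length)
        return lr
    return schedule

# The ImageNet schedule described above: decay by 10x at epochs 30, 60, 80.
lr_fn = stepped_schedule(0.1, 1000, [[30, 0.1], [60, 0.01], [80, 0.001]])
lr_fn(29_000)  # epoch 29: base rate 0.1
lr_fn(30_000)  # epoch 30: 0.1 * 0.1
```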
- flax.training.lr_schedule.create_cosine_learning_rate_schedule(base_learning_rate, steps_per_epoch, halfcos_epochs, warmup_length=0.0)[source]#
Create a cosine learning rate schedule with optional warmup.
A cosine learning rate schedule modulates the learning rate with half a cosine wave, gradually scaling it to 0 at the end of training.
This function also offers a learning rate warmup as per https://arxiv.org/abs/1706.02677, for the purpose of training with large mini-batches.
- Parameters
base_learning_rate – the base learning rate
steps_per_epoch – the number of iterations per epoch
halfcos_epochs – the number of epochs to complete half a cosine wave; normally the number of epochs used for training
warmup_length – if > 0, the learning rate will be modulated by a warmup factor that will linearly ramp up from 0 to 1 over the first warmup_length epochs
- Returns
Function f(step) -> lr that computes the learning rate for a given step.
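The half-cosine decay follows the standard formula lr = base * 0.5 * (1 + cos(pi * epoch / halfcos_epochs)). A plain-Python sketch of that formula (illustrative; the Flax implementation may handle the warmup/decay interaction differently):

```python
import math

def cosine_schedule(base_learning_rate, steps_per_epoch, halfcos_epochs,
                    warmup_length=0.0):
    # Returns f(step) -> lr: half a cosine wave from base rate down to 0.
    def schedule(step):
        epoch = step / steps_per_epoch
        lr = base_learning_rate * 0.5 * (
            1.0 + math.cos(math.pi * epoch / halfcos_epochs))
        if warmup_length > 0.0:
            lr *= min(1.0, epoch / warmup_length)
        return lr
    return schedule

lr_fn = cosine_schedule(0.1, 1000, halfcos_epochs=90)
lr_fn(0)       # start of training: 0.1
lr_fn(45_000)  # epoch 45, quarter wave: 0.05
lr_fn(90_000)  # epoch 90, end of training: ~0.0
```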
Train state#
- class flax.training.train_state.TrainState(step, apply_fn, params, tx, opt_state)[source]#
Simple train state for the common case with a single Optax optimizer.
Synopsis:
state = TrainState.create(
    apply_fn=model.apply,
    params=variables['params'],
    tx=tx)
grad_fn = jax.grad(make_loss_fn(state.apply_fn))
for batch in data:
  grads = grad_fn(state.params, batch)
  state = state.apply_gradients(grads=grads)
Note that you can easily extend this dataclass by subclassing it for storing additional data (e.g. additional variable collections).
For more exotic use cases (e.g. multiple optimizers) it’s probably best to fork the class and modify it.
- Parameters
step – Counter that starts at 0 and is incremented by 1 on every call to .apply_gradients().
apply_fn – Usually set to model.apply(). Kept in this dataclass for convenience to have a shorter params list for the train_step() function in your training loop.
params – The parameters to be updated by tx and used by apply_fn.
tx – An Optax gradient transformation.
opt_state – The state for tx.
- apply_gradients(*, grads, **kwargs)[source]#
Updates step, params, opt_state and **kwargs in return value.
Note that internally this function calls .tx.update() followed by a call to optax.apply_updates() to update params and opt_state.
- Parameters
grads – Gradients that have the same pytree structure as .params.
**kwargs – Additional dataclass attributes that should be .replace()-ed.
- Returns
An updated instance of self with step incremented by one, params and opt_state updated by applying grads, and additional attributes replaced as specified by kwargs.
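The semantics of apply_gradients can be sketched with plain dicts and hand-rolled SGD in place of an Optax transformation. This is a simplified illustration (the real class is an immutable dataclass operating on arbitrary pytrees via tx.update and optax.apply_updates; MiniTrainState is a hypothetical name):

```python
class MiniTrainState:
    # Minimal stand-in for TrainState: tracks step and params only,
    # with a fixed learning rate instead of an Optax tx/opt_state pair.
    def __init__(self, step, params, learning_rate):
        self.step = step
        self.params = params
        self.learning_rate = learning_rate

    def apply_gradients(self, grads):
        # params <- params - lr * grads; step <- step + 1.
        # Like the real method, returns a new instance rather than
        # mutating self.
        new_params = {k: v - self.learning_rate * grads[k]
                      for k, v in self.params.items()}
        return MiniTrainState(self.step + 1, new_params, self.learning_rate)

state = MiniTrainState(step=0, params={'w': 1.0}, learning_rate=0.1)
state = state.apply_gradients({'w': 2.0})
state.step    # 1
state.params  # {'w': 0.8}
```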
Early Stopping#
- class flax.training.early_stopping.EarlyStopping(min_delta=0, patience=0, best_metric=inf, patience_count=0, should_stop=False)[source]#
Early stopping to avoid overfitting during training.
The following example stops training early if the loss recorded in the current epoch improves on the previous epoch’s loss by less than 1e-3 for 2 consecutive epochs:
early_stop = EarlyStopping(min_delta=1e-3, patience=2)
for epoch in range(1, num_epochs+1):
  rng, input_rng = jax.random.split(rng)
  optimizer, train_metrics = train_epoch(
      optimizer, train_ds, config.batch_size, epoch, input_rng)
  _, early_stop = early_stop.update(train_metrics['loss'])
  if early_stop.should_stop:
    print('Met early stopping criteria, breaking...')
    break
- min_delta#
Minimum delta between updates to be considered an improvement.
- Type
float
- patience#
Number of steps of no improvement before stopping.
- Type
int
- best_metric#
Current best metric value.
- Type
float
- patience_count#
Number of steps since last improving update.
- Type
int
- should_stop#
Whether the training loop should stop to avoid overfitting.
- Type
bool
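The attributes above interact in a simple way, sketched here in plain Python. This illustrates the documented semantics (lower metric is better, as with a loss); the real class is an immutable dataclass whose update returns (has_improved, new_state), and off-by-one details of the patience counting may differ:

```python
class MiniEarlyStopping:
    # Hypothetical mutable stand-in for EarlyStopping, for illustration.
    def __init__(self, min_delta=0.0, patience=0):
        self.min_delta = min_delta
        self.patience = patience
        self.best_metric = float('inf')
        self.patience_count = 0
        self.should_stop = False

    def update(self, metric):
        # Improvement: the metric beat the best seen so far by more
        # than min_delta. Reset the patience counter.
        if self.best_metric - metric > self.min_delta:
            self.best_metric = metric
            self.patience_count = 0
            return True
        # No improvement: count toward patience and flag a stop once
        # patience steps have passed without improvement.
        self.patience_count += 1
        if self.patience_count >= self.patience:
            self.should_stop = True
        return False

es = MiniEarlyStopping(min_delta=1e-3, patience=2)
for loss in [1.0, 0.5, 0.5, 0.5]:
    es.update(loss)
es.should_stop  # True: two consecutive non-improving epochs
```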