flax.training package#

Checkpoints#

Checkpointing helper functions.

Handles saving and restoring optimizer checkpoints based on the step number or another numerical metric encoded in the filename. Cleans up older or worse-performing checkpoint files.

flax.training.checkpoints.save_checkpoint(ckpt_dir, target, step, prefix='checkpoint_', keep=1, overwrite=False, keep_every_n_steps=None, async_manager=None, gda_manager=None)[source]#

Save a checkpoint of the model.

Attempts to be pre-emption safe by writing to a temporary file before a final rename and cleanup of past files.

Parameters
  • ckpt_dir – str or pathlib-like path to store checkpoint files in.

  • target – serializable flax object, usually a flax optimizer.

  • step – int or float: training step number or other metric number.

  • prefix – str: checkpoint file name prefix.

  • keep – number of past checkpoint files to keep.

  • overwrite – overwrite existing checkpoint files if a checkpoint at the current or a later step already exists (default: False).

  • keep_every_n_steps – if defined, keep a checkpoint every n steps (in addition to keeping the last ‘keep’ checkpoints).

  • async_manager – if defined, the save will run without blocking the main thread. Only works on a single host. Note that an ongoing save will still block subsequent saves, to make sure the overwrite/keep logic works correctly.

  • gda_manager – required if target contains a JAX GlobalDeviceArray. The GDAs will be saved asynchronously to a separate subdirectory with the postfix “_gda”. As with async_manager, an ongoing save will block subsequent saves.

Returns

Filename of saved checkpoint.
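
A minimal usage sketch, where state stands in for any serializable Flax object (e.g. a TrainState) and the directory path is purely illustrative:

from flax.training import checkpoints

# Write a checkpoint for the current step, keeping the three most recent
# checkpoint files (older files are cleaned up automatically).
ckpt_path = checkpoints.save_checkpoint(
    ckpt_dir='/tmp/my_ckpts', target=state, step=100, keep=3)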

flax.training.checkpoints.latest_checkpoint(ckpt_dir, prefix='checkpoint_')[source]#

Retrieve the path of the latest checkpoint in a directory.

Parameters
  • ckpt_dir – str: directory of checkpoints to restore from.

  • prefix – str: name prefix of checkpoint files.

Returns

The latest checkpoint path or None if no checkpoints were found.
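
A minimal usage sketch (the directory path is illustrative):

from flax.training import checkpoints

# Returns e.g. '/tmp/my_ckpts/checkpoint_100', or None if the directory
# contains no checkpoint files with the given prefix.
latest = checkpoints.latest_checkpoint(ckpt_dir='/tmp/my_ckpts')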

flax.training.checkpoints.restore_checkpoint(ckpt_dir, target, step=None, prefix='checkpoint_', parallel=True, gda_manager=None)[source]#

Restore last/best checkpoint from checkpoints in path.

Sorts the checkpoint files naturally, returning the highest-valued file, e.g.:

  • ckpt_1, ckpt_2, ckpt_3 --> ckpt_3

  • ckpt_0.01, ckpt_0.1, ckpt_0.001 --> ckpt_0.1

  • ckpt_-1.0, ckpt_1.0, ckpt_1e5 --> ckpt_1e5

Parameters
  • ckpt_dir – str: checkpoint file or directory of checkpoints to restore from.

  • target – matching object to rebuild via deserialized state-dict. If None, the deserialized state-dict is returned as-is.

  • step – int: step number to load or None to load latest. If specified, ckpt_dir must be a directory.

  • prefix – str: name prefix of checkpoint files.

  • parallel – bool: whether to load seekable checkpoints in parallel, for speed.

  • gda_manager – required if checkpoint contains a JAX GlobalDeviceArray. Will read the GDAs from the separate subdirectory with postfix “_gda”.

Returns

Restored target updated from checkpoint file, or if no step specified and no checkpoint files present, returns the passed-in target unchanged. If a file path is specified and is not found, the passed-in target will be returned. This is to match the behavior of the case where a directory path is specified but the directory has not yet been created.
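
A minimal usage sketch, assuming state is an object with the same structure as the saved checkpoint; the directory path is illustrative:

from flax.training import checkpoints

# Restore the latest checkpoint in the directory; pass step= to pick a
# specific checkpoint instead.
state = checkpoints.restore_checkpoint(ckpt_dir='/tmp/my_ckpts', target=state)
state_at_5 = checkpoints.restore_checkpoint(
    ckpt_dir='/tmp/my_ckpts', target=state, step=5)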

flax.training.checkpoints.convert_pre_linen(params)[source]#

Converts a pre-Linen parameter pytree.

In the pre-Linen API, submodules were numbered incrementally, independent of the submodule class. With Linen this behavior has changed to keep separate submodule counts per module class.

Consider the following module:

class Model(nn.Module):
  @nn.compact
  def __call__(self, x):
    x = nn.Conv(1, 1)(x)
    x = nn.Dense(1)(x)
    return x

In pre-Linen the resulting params would have had the structure:

{'Conv_0': { ... }, 'Dense_1': { ... } }

With Linen the resulting params would instead have had the structure:

{'Conv_0': { ... }, 'Dense_0': { ... } }

To convert from pre-Linen format to Linen simply call:

params = convert_pre_linen(pre_linen_params)

Note that you can also use this utility to convert pre-Linen collections, because they follow the same module naming. Note though that collections were “flat” in pre-Linen and first need to be unflattened before they can be used with this function:

batch_stats = convert_pre_linen(flax.traverse_util.unflatten_dict({
    tuple(k.split('/')[1:]): v
    for k, v in pre_linen_model_state.as_dict().items()
}))

Then Linen variables can be defined from these converted collections:

variables = {'params': params, 'batch_stats': batch_stats}

Parameters

params – Parameter pytree in pre-Linen format. If the pytree is already in Linen format, then the returned pytree is unchanged (i.e. this function can safely be called on any loaded checkpoint for use with Linen).

Returns

Parameter pytree with Linen submodule naming.

Learning rate schedules#

Learning rate schedules used in FLAX image classification examples.

Note that with FLIP #1009 learning rate schedules in flax.training are effectively deprecated in favor of Optax schedules. Please refer to Optimizer Schedules for more information.

flax.training.lr_schedule.create_constant_learning_rate_schedule(base_learning_rate, steps_per_epoch, warmup_length=0.0)[source]#

Create a constant learning rate schedule with optional warmup.

Holds the learning rate constant. This function also offers a learning rate warmup as per https://arxiv.org/abs/1706.02677, for the purpose of training with large mini-batches.

Parameters
  • base_learning_rate – the base learning rate

  • steps_per_epoch – the number of iterations per epoch

  • warmup_length – if > 0, the learning rate will be modulated by a warmup factor that linearly ramps up from 0 to 1 over the first warmup_length epochs

Returns

Function f(step) -> lr that computes the learning rate for a given step.
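
A minimal sketch with illustrative numbers (100 steps per epoch, one epoch of warmup):

from flax.training import lr_schedule

lr_fn = lr_schedule.create_constant_learning_rate_schedule(
    base_learning_rate=0.1, steps_per_epoch=100, warmup_length=1.0)
lr_fn(50)    # roughly 0.05, halfway through the warmup
lr_fn(500)   # 0.1 once warmup is complete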

flax.training.lr_schedule.create_stepped_learning_rate_schedule(base_learning_rate, steps_per_epoch, lr_sched_steps, warmup_length=0.0)[source]#

Create a stepped learning rate schedule with optional warmup.

A stepped learning rate schedule decreases the learning rate by specified amounts at specified epochs. The steps are given as the lr_sched_steps parameter. A common ImageNet schedule decays the learning rate by a factor of 0.1 at epochs 30, 60 and 80. This would be specified as:

[
  [30, 0.1],
  [60, 0.01],
  [80, 0.001]
]

This function also offers a learning rate warmup as per https://arxiv.org/abs/1706.02677, for the purpose of training with large mini-batches.

Parameters
  • base_learning_rate – the base learning rate

  • steps_per_epoch – the number of iterations per epoch

  • lr_sched_steps – the schedule as a list of steps, each of which is an [epoch, lr_factor] pair; the step occurs at the given epoch and sets the learning rate to base_learning_rate * lr_factor

  • warmup_length – if > 0, the learning rate will be modulated by a warmup factor that linearly ramps up from 0 to 1 over the first warmup_length epochs

Returns

Function f(step) -> lr that computes the learning rate for a given step.
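
A minimal sketch reproducing the ImageNet-style schedule above, with illustrative numbers:

from flax.training import lr_schedule

lr_fn = lr_schedule.create_stepped_learning_rate_schedule(
    base_learning_rate=0.1, steps_per_epoch=100,
    lr_sched_steps=[[30, 0.1], [60, 0.01], [80, 0.001]],
    warmup_length=5.0)
lr_fn(3500)   # epoch 35: 0.1 * 0.1 = 0.01
lr_fn(9000)   # epoch 90: 0.1 * 0.001 = 0.0001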

flax.training.lr_schedule.create_cosine_learning_rate_schedule(base_learning_rate, steps_per_epoch, halfcos_epochs, warmup_length=0.0)[source]#

Create a cosine learning rate schedule with optional warmup.

A cosine learning rate schedule modulates the learning rate with half a cosine wave, gradually scaling it to 0 at the end of training.

This function also offers a learning rate warmup as per https://arxiv.org/abs/1706.02677, for the purpose of training with large mini-batches.

Parameters
  • base_learning_rate – the base learning rate

  • steps_per_epoch – the number of iterations per epoch

  • halfcos_epochs – the number of epochs to complete half a cosine wave; normally the number of epochs used for training

  • warmup_length – if > 0, the learning rate will be modulated by a warmup factor that linearly ramps up from 0 to 1 over the first warmup_length epochs

Returns

Function f(step) -> lr that computes the learning rate for a given step.
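
A minimal sketch with illustrative numbers (90 training epochs of 100 steps each):

from flax.training import lr_schedule

lr_fn = lr_schedule.create_cosine_learning_rate_schedule(
    base_learning_rate=0.1, steps_per_epoch=100, halfcos_epochs=90)
lr_fn(0)      # 0.1 at the start of training
lr_fn(9000)   # approaches 0 at the end of training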

Train state#

class flax.training.train_state.TrainState(step, apply_fn, params, tx, opt_state)[source]#

Simple train state for the common case with a single Optax optimizer.

Synopsis:

state = TrainState.create(
    apply_fn=model.apply,
    params=variables['params'],
    tx=tx)
grad_fn = jax.grad(make_loss_fn(state.apply_fn))
for batch in data:
  grads = grad_fn(state.params, batch)
  state = state.apply_gradients(grads=grads)

Note that you can easily extend this dataclass by subclassing it to store additional data (e.g. additional variable collections).

For more exotic use cases (e.g. multiple optimizers) it’s probably best to fork the class and modify it.

Parameters
  • step – Counter starts at 0 and is incremented by every call to .apply_gradients().

  • apply_fn – Usually set to model.apply(). Kept in this dataclass for convenience to have a shorter params list for the train_step() function in your training loop.

  • params – The parameters to be updated by tx and used by apply_fn.

  • tx – An Optax gradient transformation.

  • opt_state – The state for tx.

apply_gradients(*, grads, **kwargs)[source]#

Updates step, params, opt_state and **kwargs in return value.

Note that internally this function calls .tx.update() followed by a call to optax.apply_updates() to update params and opt_state.

Parameters
  • grads – Gradients that have the same pytree structure as .params.

  • **kwargs – Additional dataclass attributes that should be .replace()-ed.

Returns

An updated instance of self with step incremented by one, params and opt_state updated by applying grads, and additional attributes replaced as specified by kwargs.
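
A minimal sketch; grads is assumed to have the same pytree structure as state.params, and batch_stats is a hypothetical extra attribute defined on a TrainState subclass:

# Standard update: increments step and applies the optimizer update.
state = state.apply_gradients(grads=grads)

# With a subclass that stores an extra collection (hypothetical here),
# the new value can be swapped in via the same call:
state = state.apply_gradients(grads=grads, batch_stats=new_batch_stats)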

classmethod create(*, apply_fn, params, tx, **kwargs)[source]#

Creates a new instance with step=0 and initialized opt_state.
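
A minimal sketch, assuming model and variables come from an earlier nn.Module initialization (hypothetical names):

import optax
from flax.training import train_state

state = train_state.TrainState.create(
    apply_fn=model.apply,
    params=variables['params'],
    tx=optax.sgd(learning_rate=0.01))
# state.step is 0 and state.opt_state has been initialized from params.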

Early Stopping#

class flax.training.early_stopping.EarlyStopping(min_delta=0, patience=0, best_metric=inf, patience_count=0, should_stop=False)[source]#

Early stopping to avoid overfitting during training.

The following example stops training early if the difference between losses recorded in the current epoch and the previous epoch is less than 1e-3 for 2 consecutive epochs:

early_stop = EarlyStopping(min_delta=1e-3, patience=2)
for epoch in range(1, num_epochs+1):
  rng, input_rng = jax.random.split(rng)
  optimizer, train_metrics = train_epoch(
      optimizer, train_ds, config.batch_size, epoch, input_rng)
  _, early_stop = early_stop.update(train_metrics['loss'])
  if early_stop.should_stop:
    print('Met early stopping criteria, breaking...')
    break

min_delta#

Minimum delta between updates to be considered an improvement.

Type

float

patience#

Number of steps of no improvement before stopping.

Type

int

best_metric#

Current best metric value.

Type

float

patience_count#

Number of steps since last improving update.

Type

int

should_stop#

Whether the training loop should stop to avoid overfitting.

Type

bool

update(metric)[source]#

Update the state based on metric.

Returns

A pair (has_improved, early_stop), where has_improved is True when there was an improvement greater than min_delta from the previous best_metric, and early_stop is the updated EarlyStopping object.
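
A minimal sketch with illustrative metric values:

from flax.training.early_stopping import EarlyStopping

early_stop = EarlyStopping(min_delta=1e-3, patience=2)
has_improved, early_stop = early_stop.update(0.5)     # True: the first metric always improves
has_improved, early_stop = early_stop.update(0.4999)  # False: improvement below min_delta
early_stop.should_stop  # still False: the patience of 2 is not yet exhausted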