flax.training package#

Train state#

class flax.training.train_state.TrainState(step, apply_fn, params, tx, opt_state)[source]#

Simple train state for the common case with a single Optax optimizer.

Example usage:

>>> import flax.linen as nn
>>> from flax.training.train_state import TrainState
>>> import jax, jax.numpy as jnp
>>> import optax

>>> x = jnp.ones((1, 2))
>>> y = jnp.ones((1, 2))
>>> model = nn.Dense(2)
>>> variables = model.init(jax.random.key(0), x)
>>> tx = optax.adam(1e-3)

>>> state = TrainState.create(
...     apply_fn=model.apply,
...     params=variables['params'],
...     tx=tx)

>>> def loss_fn(params, x, y):
...   predictions = state.apply_fn({'params': params}, x)
...   loss = optax.l2_loss(predictions=predictions, targets=y).mean()
...   return loss
>>> loss_fn(state.params, x, y)
Array(1.8136346, dtype=float32)

>>> grads = jax.grad(loss_fn)(state.params, x, y)
>>> state = state.apply_gradients(grads=grads)
>>> loss_fn(state.params, x, y)
Array(1.8079796, dtype=float32)

Note that you can easily extend this dataclass by subclassing it to store additional data (e.g. additional variable collections).

For more exotic use cases (e.g. multiple optimizers) it's probably best to fork the class and modify it.

Parameters
  • step – Counter that starts at 0 and is incremented by 1 on every call to .apply_gradients().

  • apply_fn – Usually set to model.apply. Kept in this dataclass for convenience, so the train_step() function in your training loop has a shorter parameter list.

  • params – The parameters to be updated by tx and used by apply_fn.

  • tx – An Optax gradient transformation.

  • opt_state – The state for tx.

apply_gradients(*, grads, **kwargs)[source]#

Updates step, params, opt_state and any additional attributes passed via **kwargs, returning a new instance.

Note that internally this function calls .tx.update() followed by a call to optax.apply_updates() to update params and opt_state.

Parameters
  • grads – Gradients that have the same pytree structure as .params.

  • **kwargs – Additional dataclass attributes that should be .replace()-ed.

Returns

An updated instance of self with step incremented by one, params and opt_state updated by applying grads, and additional attributes replaced as specified by kwargs.

classmethod create(*, apply_fn, params, tx, **kwargs)[source]#

Creates a new instance with step=0 and initialized opt_state.