Optimizer

class flax.experimental.nnx.optimizer.Optimizer(*args, **kwargs)

Simple train state for the common case with a single Optax optimizer.

Example usage:

>>> import jax, jax.numpy as jnp
>>> from flax.experimental import nnx
>>> import optax

>>> class Model(nnx.Module):
...   def __init__(self, rngs):
...     self.linear1 = nnx.Linear(2, 3, rngs=rngs)
...     self.linear2 = nnx.Linear(3, 4, rngs=rngs)
...   def __call__(self, x):
...     return self.linear2(self.linear1(x))

>>> x = jax.random.normal(jax.random.key(0), (1, 2))
>>> y = jnp.ones((1, 4))

>>> model = Model(nnx.Rngs(0))
>>> tx = optax.adam(1e-3)
>>> state = nnx.Optimizer(model, tx)

>>> loss_fn = lambda model: ((model(x)-y)**2).mean()
>>> loss_fn(state.model)
1.7055722
>>> grads = nnx.grad(loss_fn, wrt=nnx.Param)(state.model)
>>> state.update(grads)
>>> loss_fn(state.model)
1.6925814
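
In practice, the gradient computation and the call to update() are usually wrapped in a jitted train step. A minimal sketch (the train_step function below is illustrative, not part of the API; it assumes nnx.jit propagates the in-place state updates of its NNX arguments, as in the NNX training examples):

>>> @nnx.jit
... def train_step(state, x, y):
...   loss_fn = lambda model: ((model(x) - y) ** 2).mean()
...   loss = loss_fn(state.model)  # loss before the update
...   grads = nnx.grad(loss_fn, wrt=nnx.Param)(state.model)
...   state.update(grads)  # in-place update of params and opt_state
...   return loss

>>> loss = train_step(state, x, y)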

Note that you can easily extend this class by subclassing it to store additional data (e.g. metrics).

Example usage:

>>> class TrainState(nnx.Optimizer):
...   def __init__(self, model, tx, metrics):
...     self.metrics = metrics
...     super().__init__(model, tx)
...   def update(self, *, grads, **updates):
...     self.metrics.update(**updates)
...     super().update(grads)

>>> metrics = nnx.metrics.Average()
>>> state = TrainState(model, tx, metrics)

>>> grads = nnx.grad(loss_fn, wrt=nnx.Param)(state.model)
>>> state.update(grads=grads, values=loss_fn(state.model))
>>> state.metrics.compute()
1.6925814
>>> state.update(grads=grads, values=loss_fn(state.model))
>>> state.metrics.compute()
1.68612

For more exotic use cases (e.g. multiple optimizers), it's probably best to fork the class and modify it.
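
As an illustration only (this pattern is not part of the documented API), one lightweight alternative to forking is to wrap each sub-module in its own Optimizer and index the gradient State by sub-module name. The sketch below reuses the Model, loss_fn, x and y defined above and assumes the gradient State can be indexed this way:

>>> model = Model(nnx.Rngs(0))
>>> opt1 = nnx.Optimizer(model.linear1, optax.adam(1e-3))  # Adam for the first layer
>>> opt2 = nnx.Optimizer(model.linear2, optax.sgd(1e-2))   # SGD for the second layer

>>> grads = nnx.grad(loss_fn, wrt=nnx.Param)(model)
>>> opt1.update(grads['linear1'])  # apply only linear1's gradients
>>> opt2.update(grads['linear2'])  # apply only linear2's gradients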

Parameters
  • model – An NNX Module.

  • tx – An Optax gradient transformation.

update(grads)

Updates step, the model's params and opt_state in place by applying grads.

Note that internally this function calls .tx.update() followed by a call to optax.apply_updates() to update params and opt_state.
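
For orientation, here is a rough sketch of the equivalent manual Optax steps. It mirrors the behavior described above but is not the literal implementation; it assumes the parameters of interest are the model's nnx.Param state and uses nnx.split / nnx.update to move that state out of and back into the module:

>>> graphdef, params = nnx.split(model, nnx.Param)  # extract the Param state
>>> opt_state = tx.init(params)                     # fresh optimizer state
>>> updates, opt_state = tx.update(grads, opt_state, params)
>>> new_params = optax.apply_updates(params, updates)
>>> nnx.update(model, new_params)                   # write the new params back into the module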

Parameters
  • grads – Gradients that have the same pytree structure as the model's nnx.Param state (e.g. as returned by nnx.grad with wrt=nnx.Param).

Returns

None. The optimizer is updated in place: step is incremented by one, and params and opt_state are updated by applying grads.