Optimizer

class flax.experimental.nnx.optimizer.Optimizer(*args, **kwargs)

Simple train state for the common case with a single Optax optimizer.

Example usage:

>>> import jax, jax.numpy as jnp
>>> from flax.experimental import nnx
>>> import optax
...
>>> class Model(nnx.Module):
...   def __init__(self, rngs):
...     self.linear1 = nnx.Linear(2, 3, rngs=rngs)
...     self.linear2 = nnx.Linear(3, 4, rngs=rngs)
...   def __call__(self, x):
...     return self.linear2(self.linear1(x))
...
>>> x = jax.random.normal(jax.random.key(0), (1, 2))
>>> y = jnp.ones((1, 4))
...
>>> model = Model(nnx.Rngs(0))
>>> tx = optax.adam(1e-3)
>>> state = nnx.Optimizer(model, tx)
...
>>> loss_fn = lambda model: ((model(x) - y) ** 2).mean()
>>> loss_fn(model)
Array(1.7055722, dtype=float32)
>>> grads = nnx.grad(loss_fn, wrt=nnx.Param)(state.model)
>>> state.update(grads)
>>> loss_fn(model)
Array(1.6925814, dtype=float32)

Note that you can easily extend this class by subclassing it to store additional data (e.g. metrics).

Example usage:

>>> class TrainState(nnx.Optimizer):
...   def __init__(self, model, tx, metrics):
...     self.metrics = metrics
...     super().__init__(model, tx)
...   def update(self, *, grads, **updates):
...     self.metrics.update(**updates)
...     super().update(grads)
...
>>> metrics = nnx.metrics.Average()
>>> state = TrainState(model, tx, metrics)
...
>>> grads = nnx.grad(loss_fn, wrt=nnx.Param)(state.model)
>>> state.update(grads=grads, values=loss_fn(state.model))
>>> state.metrics.compute()
Array(1.6925814, dtype=float32)
>>> state.update(grads=grads, values=loss_fn(state.model))
>>> state.metrics.compute()
Array(1.68612, dtype=float32)

For more exotic use cases (e.g. multiple optimizers), it's probably best to fork the class and modify it.
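
Before forking, note that some composite setups can still be expressed as a single Optax transformation, for example by composing transformations with optax.chain. A minimal, illustrative sketch (the clipping threshold and learning rate are placeholders):

>>> # Compose gradient clipping with Adam into one transformation; this
>>> # still fits the single-`tx` design of `nnx.Optimizer`.
>>> tx = optax.chain(
...   optax.clip_by_global_norm(1.0),
...   optax.adam(1e-3),
... )
>>> state = nnx.Optimizer(model, tx)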

Parameters
  • model – An NNX Module.

  • tx – An Optax gradient transformation.

update(grads)

Updates step, the model's parameters, and opt_state in place by applying grads.

Note that internally this function calls .tx.update() followed by a call to optax.apply_updates() to update params and opt_state.

Parameters
  • grads – Gradients with the same pytree structure as the model's nnx.Param state (e.g. as returned by nnx.grad).

Returns

None. As the examples above show, the optimizer is updated in place: step is incremented by one, and the model's parameters and opt_state are updated by applying grads.
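
To make the note about the internals concrete, here is a minimal sketch of the two Optax calls that update() wraps, shown on a plain parameter pytree rather than the model's NNX state (the names and values are purely illustrative):

>>> import jax.numpy as jnp
>>> import optax
...
>>> params = {'w': jnp.ones((3,)), 'b': jnp.zeros((3,))}
>>> grads = {'w': jnp.ones((3,)), 'b': jnp.ones((3,))}
>>> sgd = optax.sgd(0.1)
>>> opt_state = sgd.init(params)
...
>>> # `.tx.update()` transforms the raw gradients into parameter updates ...
>>> updates, opt_state = sgd.update(grads, opt_state, params)
>>> # ... and `optax.apply_updates()` applies them to the parameters.
>>> params = optax.apply_updates(params, updates)
>>> params['w']
Array([0.9, 0.9, 0.9], dtype=float32)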