Optimizer#

class flax.nnx.optimizer.Optimizer(self, model, tx, *, wrt, graph=None)#

Simple train state for the common case with a single Optax optimizer.

Example usage:

>>> import jax, jax.numpy as jnp
>>> from flax import nnx
>>> import optax
...
>>> class Model(nnx.Module):
...   def __init__(self, rngs):
...     self.linear1 = nnx.Linear(2, 3, rngs=rngs)
...     self.linear2 = nnx.Linear(3, 4, rngs=rngs)
...   def __call__(self, x):
...     return self.linear2(self.linear1(x))
...
>>> x = jax.random.normal(jax.random.key(0), (1, 2))
>>> y = jnp.ones((1, 4))
...
>>> model = Model(nnx.Rngs(0))
>>> tx = optax.adam(1e-3)
>>> optimizer = nnx.Optimizer(model, tx, wrt=nnx.Param)
...
>>> loss_fn = lambda model: ((model(x) - y) ** 2).mean()
>>> loss_fn(model)
Array(2.3359997, dtype=float32)
>>> grads = nnx.grad(loss_fn)(model)
>>> _ = optimizer.update(model, grads)
>>> loss_fn(model)
Array(2.310461, dtype=float32)

step#

An OptState Variable that tracks the step count.

tx#

An Optax gradient transformation.

opt_state#

The Optax optimizer state.

__init__(model, tx, *, wrt, graph=None)#

Instantiate the class and wrap the Module and Optax gradient transformation. Instantiate the optimizer state to keep track of Variable types specified in wrt. Set the step count to 0.

Parameters:
  • model – An NNX Module.

  • tx – An Optax gradient transformation.

  • wrt – filter specifying which Variables the optimizer state keeps track of. These should be the Variables that you plan on updating; i.e. the filter should match the one used in the nnx.grad call that generates the gradients passed as the grads argument of the update() method.

  • graph – If True, uses graph-mode which supports the full NNX feature set including shared references. If False, uses tree-mode which treats Modules as regular JAX pytrees, avoiding the overhead of the graph protocol. If None (default), the value is determined by the current nnx.set_graph_mode context.

update(model, grads, /, **kwargs)#

Updates the optimizer state and model parameters given the gradients.

Example:

>>> from flax import nnx
>>> import jax, jax.numpy as jnp
>>> import optax
...
>>> class Model(nnx.Module):
...   def __init__(self, rngs):
...     self.linear = nnx.Linear(2, 3, rngs=rngs)
...     self.count = nnx.Variable(jnp.array(0))
...
...   def __call__(self, x):
...     self.count[...] += 1
...     return self.linear(x)
...
>>> model = Model(rngs=nnx.Rngs(0))
...
>>> loss_fn = lambda model, x, y: ((model(x) - y) ** 2).mean()
>>> optimizer = nnx.Optimizer(model, optax.adam(1e-3), wrt=nnx.Param)
>>> grads = nnx.grad(loss_fn, argnums=0)(
...   model, jnp.ones((1, 2)), jnp.ones((1, 3))
... )
>>> _ = optimizer.update(model, grads)

Note that internally this function calls .tx.update() followed by a call to optax.apply_updates() to update params and opt_state.

Parameters:
  • grads – the gradients derived from nnx.grad.

  • **kwargs – additional keyword arguments passed to tx.update, to support GradientTransformationExtraArgs, such as optax.scale_by_backtracking_linesearch.

Returns:

The updates PyTree containing the parameter updates applied to the model. This matches the structure of the model parameters filtered by wrt.