Optimizer#
- class flax.experimental.nnx.optimizer.Optimizer(*args, **kwargs)#
Simple train state for the common case with a single Optax optimizer.
Example usage:
>>> import jax, jax.numpy as jnp
>>> from flax.experimental import nnx
>>> import optax
...
>>> class Model(nnx.Module):
...   def __init__(self, rngs):
...     self.linear1 = nnx.Linear(2, 3, rngs=rngs)
...     self.linear2 = nnx.Linear(3, 4, rngs=rngs)
...   def __call__(self, x):
...     return self.linear2(self.linear1(x))
...
>>> x = jax.random.normal(jax.random.key(0), (1, 2))
>>> y = jnp.ones((1, 4))
...
>>> model = Model(nnx.Rngs(0))
>>> tx = optax.adam(1e-3)
>>> state = nnx.Optimizer(model, tx)
...
>>> loss_fn = lambda model: ((model(x) - y) ** 2).mean()
>>> loss_fn(model)
Array(1.7055722, dtype=float32)
>>> grads = nnx.grad(loss_fn, wrt=nnx.Param)(state.model)
>>> state.update(grads)
>>> loss_fn(model)
Array(1.6925814, dtype=float32)
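In a full training loop, the gradient computation and the state.update() call are typically wrapped in a single jitted step. The following is a minimal sketch only; it assumes that nnx.jit (not documented on this page) propagates the in-place updates made to its NNX arguments:
>>> @nnx.jit
... def train_step(state, x, y):
...   def loss_fn(model):
...     return ((model(x) - y) ** 2).mean()
...   grads = nnx.grad(loss_fn, wrt=nnx.Param)(state.model)
...   state.update(grads)  # apply the gradients, as in the example above
...
>>> train_step(state, x, y)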
Note that you can easily extend this class by subclassing it to store additional data (e.g. to add metrics).
Example usage:
>>> class TrainState(nnx.Optimizer):
...   def __init__(self, model, tx, metrics):
...     self.metrics = metrics
...     super().__init__(model, tx)
...   def update(self, *, grads, **updates):
...     self.metrics.update(**updates)
...     super().update(grads)
...
>>> metrics = nnx.metrics.Average()
>>> state = TrainState(model, tx, metrics)
...
>>> grads = nnx.grad(loss_fn, wrt=nnx.Param)(state.model)
>>> state.update(grads=grads, values=loss_fn(state.model))
>>> state.metrics.compute()
Array(1.6925814, dtype=float32)
>>> state.update(grads=grads, values=loss_fn(state.model))
>>> state.metrics.compute()
Array(1.68612, dtype=float32)
For more exotic use cases (e.g. multiple optimizers), it’s probably best to fork the class and modify it.
- Parameters
model – An NNX Module.
tx – An Optax gradient transformation.
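Since tx accepts any Optax gradient transformation, a composed transformation can be passed as well. For example (an illustrative choice of transforms, not taken from this page):
>>> tx = optax.chain(optax.clip_by_global_norm(1.0), optax.adamw(1e-3))
>>> state = nnx.Optimizer(model, tx)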
- update(grads)#
Updates step, params, opt_state and **kwargs in return value.
Note that internally this function calls .tx.update() followed by a call to optax.apply_updates() to update params and opt_state.
- Parameters
grads – Gradients that have the same pytree structure as .params.
**kwargs – Additional dataclass attributes that should be .replace()-ed.
- Returns
An updated instance of self with step incremented by one, params and opt_state updated by applying grads, and additional attributes replaced as specified by kwargs.
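The internal sequence described above can be illustrated with plain Optax, independent of NNX. Below is a minimal sketch using stand-in pytrees; toy_params, toy_grads, toy_tx and toy_opt_state are hypothetical names used only for this illustration:
>>> import jax.numpy as jnp
>>> import optax
...
>>> toy_params = {'w': jnp.ones((3,)), 'b': jnp.zeros((3,))}
>>> toy_grads = {'w': jnp.full((3,), 0.1), 'b': jnp.full((3,), -0.2)}
>>> toy_tx = optax.adam(1e-3)
>>> toy_opt_state = toy_tx.init(toy_params)
...
>>> # 1. tx.update() turns raw gradients into parameter updates and a new optimizer state.
>>> updates, toy_opt_state = toy_tx.update(toy_grads, toy_opt_state, toy_params)
>>> # 2. optax.apply_updates() applies those updates to the parameters.
>>> toy_params = optax.apply_updates(toy_params, updates)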