Optimizer
- class flax.nnx.optimizer.Optimizer(self, model, tx, *, wrt, graph=None)
Simple train state for the common case with a single Optax optimizer.
Example usage:
>>> import jax, jax.numpy as jnp
>>> from flax import nnx
>>> import optax
...
>>> class Model(nnx.Module):
...   def __init__(self, rngs):
...     self.linear1 = nnx.Linear(2, 3, rngs=rngs)
...     self.linear2 = nnx.Linear(3, 4, rngs=rngs)
...   def __call__(self, x):
...     return self.linear2(self.linear1(x))
...
>>> x = jax.random.normal(jax.random.key(0), (1, 2))
>>> y = jnp.ones((1, 4))
...
>>> model = Model(nnx.Rngs(0))
>>> tx = optax.adam(1e-3)
>>> optimizer = nnx.Optimizer(model, tx, wrt=nnx.Param)
...
>>> loss_fn = lambda model: ((model(x) - y) ** 2).mean()
>>> loss_fn(model)
Array(2.3359997, dtype=float32)
>>> grads = nnx.grad(loss_fn)(model)
>>> _ = optimizer.update(model, grads)
>>> loss_fn(model)
Array(2.310461, dtype=float32)
- step
An OptState Variable that tracks the step count.
- tx
An Optax gradient transformation.
- opt_state
The Optax optimizer state.
- __init__(model, tx, *, wrt, graph=None)
Instantiate the class and wrap the Module and Optax gradient transformation. Instantiate the optimizer state to keep track of the Variable types specified in wrt. Set the step count to 0.
- Parameters:
model – An NNX Module.
tx – An Optax gradient transformation.
wrt – A filter specifying which Variables to keep track of in the optimizer state. These should be the Variables you plan on updating; i.e. this filter should match the one used in the nnx.grad call that generates the gradients passed to the grads argument of the update() method.
graph – If True, uses graph mode, which supports the full NNX feature set including shared references. If False, uses tree mode, which treats Modules as regular JAX pytrees and avoids the overhead of the graph protocol. If None (default), the value is determined by the current nnx.set_graph_mode context.
- update(model, grads, /, **kwargs)
Updates the optimizer state and model parameters given the gradients.
Example:
>>> from flax import nnx
>>> import jax, jax.numpy as jnp
>>> import optax
...
>>> class Model(nnx.Module):
...   def __init__(self, rngs):
...     self.linear = nnx.Linear(2, 3, rngs=rngs)
...     self.count = nnx.Variable(jnp.array(0))
...
...   def __call__(self, x):
...     self.count[...] += 1
...     return self.linear(x)
...
>>> model = Model(rngs=nnx.Rngs(0))
...
>>> loss_fn = lambda model, x, y: ((model(x) - y) ** 2).mean()
>>> optimizer = nnx.Optimizer(model, optax.adam(1e-3), wrt=nnx.Param)
>>> grads = nnx.grad(loss_fn, argnums=0)(
...   model, jnp.ones((1, 2)), jnp.ones((1, 3))
... )
>>> _ = optimizer.update(model, grads)
Note that internally this function calls .tx.update() followed by a call to optax.apply_updates() to update params and opt_state.
- Parameters:
model – The NNX Module whose parameters (those matching wrt) are updated in place.
grads – The gradients derived from nnx.grad.
**kwargs – Additional keyword arguments passed to tx.update, to support GradientTransformationExtraArgs such as optax.scale_by_backtracking_linesearch.
- Returns:
The updates PyTree containing the parameter updates applied to the model. Its structure matches that of the model parameters filtered by wrt.