Optimizer#
- class flax.nnx.optimizer.Optimizer(*args, **kwargs)#
Simple train state for the common case with a single Optax optimizer.
Example usage:
>>> import jax, jax.numpy as jnp
>>> from flax import nnx
>>> import optax
...
>>> class Model(nnx.Module):
...   def __init__(self, rngs):
...     self.linear1 = nnx.Linear(2, 3, rngs=rngs)
...     self.linear2 = nnx.Linear(3, 4, rngs=rngs)
...   def __call__(self, x):
...     return self.linear2(self.linear1(x))
...
>>> x = jax.random.normal(jax.random.key(0), (1, 2))
>>> y = jnp.ones((1, 4))
...
>>> model = Model(nnx.Rngs(0))
>>> tx = optax.adam(1e-3)
>>> state = nnx.Optimizer(model, tx)
...
>>> loss_fn = lambda model: ((model(x) - y) ** 2).mean()
>>> loss_fn(model)
Array(1.7055722, dtype=float32)
>>> grads = nnx.grad(loss_fn)(state.model)
>>> state.update(grads)
>>> loss_fn(model)
Array(1.6925814, dtype=float32)
Note that you can easily extend this class by subclassing it to store additional data (e.g. metrics).
Example usage:
>>> class TrainState(nnx.Optimizer):
...   def __init__(self, model, tx, metrics):
...     self.metrics = metrics
...     super().__init__(model, tx)
...   def update(self, *, grads, **updates):
...     self.metrics.update(**updates)
...     super().update(grads)
...
>>> metrics = nnx.metrics.Average()
>>> state = TrainState(model, tx, metrics)
...
>>> grads = nnx.grad(loss_fn)(state.model)
>>> state.update(grads=grads, values=loss_fn(state.model))
>>> state.metrics.compute()
Array(1.6925814, dtype=float32)
>>> state.update(grads=grads, values=loss_fn(state.model))
>>> state.metrics.compute()
Array(1.68612, dtype=float32)
For more exotic use cases (e.g. multiple optimizers) it’s probably best to fork the class and modify it.
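Before forking, one lightweight pattern worth sketching is to coordinate two plain Optimizer instances that track disjoint Variable types. This is only a rough sketch; SlowVariable and TwoSpeedModel below are illustrative names, not part of Flax:
>>> from flax import nnx
>>> import jax.numpy as jnp
>>> import optax

>>> class SlowVariable(nnx.Variable):  # illustrative custom Variable type
...   pass

>>> class TwoSpeedModel(nnx.Module):  # illustrative model with two parameter groups
...   def __init__(self, rngs):
...     self.linear = nnx.Linear(2, 3, rngs=rngs)
...     self.scale = SlowVariable(jnp.ones((3,)))
...   def __call__(self, x):
...     return self.linear(x) * self.scale

>>> model = TwoSpeedModel(nnx.Rngs(0))
>>> # one optimizer per variable group, each with its own learning rate
>>> fast = nnx.Optimizer(model, optax.adam(1e-3), wrt=nnx.Param)
>>> slow = nnx.Optimizer(model, optax.adam(1e-5), wrt=SlowVariable)

>>> loss_fn = lambda model: (model(jnp.ones((1, 2))) ** 2).mean()
>>> fast_grads = nnx.grad(loss_fn, argnums=nnx.DiffState(0, nnx.Param))(model)
>>> slow_grads = nnx.grad(loss_fn, argnums=nnx.DiffState(0, SlowVariable))(model)
>>> fast.update(fast_grads)
>>> slow.update(slow_grads)
Each Optimizer only tracks and updates the Variable types named in its wrt argument, so the two instances can share the same model without interfering with each other.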
- step#
An OptState Variable that tracks the step count.
- model#
The wrapped Module.
- tx#
An Optax gradient transformation.
- opt_state#
The Optax optimizer state.
- __init__(model, tx, wrt=<class 'flax.nnx.variables.Param'>)#
Instantiate the class and wrap the Module and Optax gradient transformation. Instantiate the optimizer state to keep track of the Variable types specified in wrt. Set the step count to 0 (a short sanity check follows the parameter list below).
- Parameters
model – An NNX Module.
tx – An Optax gradient transformation.
wrt – optional argument to filter which Variable types to keep track of in the optimizer state. These should be the Variable types that you plan on updating; i.e. this argument value should match the wrt argument passed to the nnx.grad call that will generate the gradients that will be passed into the grads argument of the update() method.
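As a quick sanity check of the defaults described above, here is a minimal sketch; it assumes the step counter exposes its array through .value, as nnx Variables generally do:
>>> from flax import nnx
>>> import optax

>>> model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
>>> state = nnx.Optimizer(model, optax.sgd(1e-2), wrt=nnx.Param)
>>> int(state.step.value)  # a freshly created optimizer starts at step 0
0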
- update(grads)#
Updates step, params and opt_state in place. The grads must be derived from nnx.grad(..., wrt=self.wrt), where the gradients are with respect to the same Variable types as defined in self.wrt during instantiation of this Optimizer.
For example:
>>> from flax import nnx
>>> import jax, jax.numpy as jnp
>>> import optax

>>> class CustomVariable(nnx.Variable):
...   pass

>>> class Model(nnx.Module):
...   def __init__(self, rngs):
...     self.linear = nnx.Linear(2, 3, rngs=rngs)
...     self.custom_variable = CustomVariable(jnp.ones((1, 3)))
...   def __call__(self, x):
...     return self.linear(x) + self.custom_variable

>>> model = Model(rngs=nnx.Rngs(0))
>>> jax.tree.map(jnp.shape, nnx.state(model))
State({
  'custom_variable': VariableState(
    type=CustomVariable,
    value=(1, 3)
  ),
  'linear': {
    'bias': VariableState(
      type=Param,
      value=(3,)
    ),
    'kernel': VariableState(
      type=Param,
      value=(2, 3)
    )
  }
})

>>> # update:
>>> # - only Linear layer parameters
>>> # - only CustomVariable parameters
>>> # - both Linear layer and CustomVariable parameters
>>> loss_fn = lambda model, x, y: ((model(x) - y) ** 2).mean()
>>> for variable in (nnx.Param, CustomVariable, (nnx.Param, CustomVariable)):
...   # make sure `wrt` arguments match for `nnx.Optimizer` and `nnx.grad`
...   state = nnx.Optimizer(model, optax.adam(1e-3), wrt=variable)
...   grads = nnx.grad(loss_fn, argnums=nnx.DiffState(0, variable))(
...     state.model, jnp.ones((1, 2)), jnp.ones((1, 3))
...   )
...   state.update(grads=grads)
Note that internally this function calls .tx.update() followed by a call to optax.apply_updates() to update params and opt_state; a rough manual equivalent is sketched after the parameter list below.
- Parameters
grads – the gradients derived from nnx.grad.
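The following sketch shows roughly equivalent manual steps using the public nnx and Optax APIs. It is an illustration of the sequence described above, not the exact internal implementation, and it assumes the optimizer state is initialized from the wrt-filtered params:
>>> from flax import nnx
>>> import jax.numpy as jnp
>>> import optax

>>> model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
>>> tx = optax.adam(1e-3)
>>> params = nnx.state(model, nnx.Param)   # the wrt-filtered state
>>> opt_state = tx.init(params)            # roughly what __init__ does

>>> loss_fn = lambda model: (model(jnp.ones((1, 2))) ** 2).mean()
>>> grads = nnx.grad(loss_fn)(model)

>>> updates, opt_state = tx.update(grads, opt_state, params)  # .tx.update()
>>> params = optax.apply_updates(params, updates)             # optax.apply_updates()
>>> nnx.update(model, params)                                 # write new params back into the model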