flax.linen.grad
- flax.linen.grad(fn, mdl, *primals, has_aux=False, reduce_axes=(), variables=True, rngs=True)
A limited, lifted equivalent of `jax.grad`.

Note that for this convenience function, gradients are only calculated for the function inputs, and not with respect to any module variables. The target function must return a scalar-valued output. For the more general lifted vector-Jacobian product, see `nn.vjp`.

Example:
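```python
import flax.linen as nn


class LearnScale(nn.Module):
  @nn.compact
  def __call__(self, x, y):
    # A single learnable scalar, initialized to zero.
    p = self.param('scale', nn.initializers.zeros_init(), ())
    return p * x * y


class Foo(nn.Module):
  @nn.compact
  def __call__(self, x, y):
    # fn receives the module followed by the primals x and y; the result
    # holds one gradient per primal.
    x_grad, y_grad = nn.grad(
        lambda mdl, x, y: mdl(x, y), LearnScale(), x, y)
    return x_grad, y_grad
```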
- Parameters
fn – Function to be differentiated. It is called with `mdl` followed by the primals; the primals should be arrays, scalars, or standard Python containers of arrays or scalars. It must return a scalar-valued output (or, if `has_aux` is True, a pair whose first element is scalar-valued).
mdl – The module passed as the first argument to `fn`. Its variables are available inside `fn`, but no gradients are computed with respect to them.
*primals – A sequence of primal values at which the gradient of `fn` should be evaluated. The number of primals should equal the number of positional parameters of `fn` that follow `mdl`. Each primal value should be an array, a scalar, or a standard Python container thereof.
has_aux – Optional, bool. Indicates whether `fn` returns a pair where the first element is considered the output of the mathematical function to be differentiated and the second element is auxiliary data. Default False.
reduce_axes – Optional, tuple of axis names. If an axis is listed here, and `fn` implicitly broadcasts a value over that axis, the backward pass will perform a `psum` of the corresponding gradient. Otherwise, the gradient will be per-example over named axes. For example, if `'batch'` is a named batch axis, `grad(fn, mdl, *primals, reduce_axes=('batch',))` computes a gradient summed over the batch, while `grad(fn, mdl, *primals)` computes per-example gradients.
variables – Variable collections that are available inside `fn` but do not receive a cotangent.
rngs – The PRNG sequences that are available inside `fn`.
- Returns
If `has_aux` is False, returns `grads`, a tuple with one entry per primal holding the gradient for the corresponding primal; it does not include gradients for module variables. If `has_aux` is True, returns a `(grads, aux)` tuple where `aux` is the auxiliary data returned by `fn`.