flax.linen.grad

flax.linen.grad(fn, mdl, *primals, has_aux=False, reduce_axes=(), variables=True, rngs=True)

A limited, lifted equivalent of jax.grad.

Note that for this convenience function, gradients are only calculated for the function inputs, not with respect to any module variables. The target function must return a scalar-valued output. For a more general lifted transform, see nn.vjp, the lifted vector-Jacobian product.

Example:

import flax.linen as nn

class LearnScale(nn.Module):
  @nn.compact
  def __call__(self, x, y):
    p = self.param('scale', nn.initializers.zeros_init(), ())
    return p * x * y

class Foo(nn.Module):
  @nn.compact
  def __call__(self, x, y):
    # nn.grad differentiates w.r.t. the primals x and y only,
    # not w.r.t. the 'scale' parameter of LearnScale.
    x_grad, y_grad = nn.grad(
        lambda mdl, x, y: mdl(x, y), LearnScale(), x, y)
    return x_grad, y_grad
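
As a usage sketch (the seed and input values below are illustrative assumptions), initializing and applying Foo evaluates those gradients:

import jax
import jax.numpy as jnp

model = Foo()
x, y = jnp.array(2.0), jnp.array(3.0)
variables = model.init(jax.random.PRNGKey(0), x, y)
x_grad, y_grad = model.apply(variables, x, y)
# 'scale' is initialized to zeros, so p = 0 and both
# gradients (p * y and p * x) evaluate to 0.0 here.
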
Parameters
  • fn – Function to be differentiated. Its arguments should be arrays, scalars, or standard Python containers of arrays or scalars. It should return an array, scalar, or standard Python container of arrays or scalars. It will receive the module instance and the primals as arguments.

  • mdl – The module whose variables are available inside fn. As noted above, the variables themselves are not differentiated.

  • *primals – A sequence of primal values at which the Jacobian of fn should be evaluated. The length of primals should be equal to the number of positional parameters of fn, not counting the module argument. Each primal value should be an array, a scalar, or a standard Python container thereof.

  • has_aux – Optional, bool. Indicates whether fn returns a pair where the first element is considered the output of the mathematical function to be differentiated and the second element is auxiliary data. Default False. A short sketch follows this list.

  • reduce_axes – Optional, tuple of axis names. If an axis is listed here, and fn implicitly broadcasts a value over that axis, the backward pass will perform a psum of the corresponding gradient. Otherwise, the gradient will be per-example over named axes. For example, if 'batch' is a named batch axis, grad(f, *args, reduce_axes=('batch',)) will create a grad function that sums over the batch, while grad(f, *args) will create a per-example grad.

  • variables – variable collections that are available inside fn but do not receive a cotangent.

  • rngs – the PRNGs that are available inside fn.
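
For has_aux=True, a minimal sketch (reusing LearnScale from the example above; the key 'out' is an illustrative choice) that passes the function output through as auxiliary data:

class FooAux(nn.Module):
  @nn.compact
  def __call__(self, x, y):
    def fn(mdl, x, y):
      out = mdl(x, y)
      # First element is differentiated; second is returned as aux.
      return out, {'out': out}
    (x_grad, y_grad), aux = nn.grad(
        fn, LearnScale(), x, y, has_aux=True)
    return x_grad, y_grad, aux['out']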

Returns

If has_aux is False, returns grads, the gradients with respect to the corresponding primals; gradients for module variables are not included. If has_aux is True, returns a (grads, aux) tuple, where aux is the auxiliary data returned by fn.