flax.linen.custom_vjp

flax.linen.custom_vjp(fn, forward_fn, backward_fn, grad_vars='params', nondiff_argnums=())

Lifted version of jax.custom_vjp.

forward_fn and backward_fn together define a custom vjp for fn. When no vjp (backward gradient) is computed, the original fn runs instead.

The forward_fn receives the same arguments as fn but is expected to return a tuple containing the output of fn(mdl, *args) and the residuals that are passed to backward_fn.
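nn.custom_vjp lifts jax.custom_vjp, so the same forward/backward protocol can be sketched in plain JAX. This is a minimal illustration of jax.custom_vjp itself, not of the Flax wrapper. The function name square and the deliberately "wrong" factor 10.0 in the backward rule are hypothetical choices that make it visible which rule ran.

```python
import jax
import jax.numpy as jnp


@jax.custom_vjp
def square(x):
  # Primal function: this is all that runs when no VJP is requested.
  return x ** 2


def square_fwd(x):
  # Same arguments as the primal; returns (primal output, residuals).
  # Here the residual is simply the input x.
  return x ** 2, x


def square_bwd(x, y_bar):
  # Receives the residuals and the output cotangent, and returns one
  # cotangent per input as a tuple. The factor 10.0 (instead of the true
  # derivative 2.0) is hypothetical, chosen to show this rule is used.
  return (10.0 * x * y_bar,)


square.defvjp(square_fwd, square_bwd)

value = square(3.0)           # plain evaluation, no VJP involved
grad = jax.grad(square)(3.0)  # goes through square_fwd / square_bwd
```

The plain call returns the ordinary value of x ** 2, while jax.grad picks up the custom backward rule.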

The backward_fn receives the nondiff arguments, the residuals, and the output tangents. It should return a tuple containing the variable tangents followed by the input tangents.
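The argument order seen by the backward function mirrors jax.custom_vjp with nondiff_argnums. A sketch in plain JAX, assuming a hypothetical power(x, n) whose exponent n is marked as non-differentiable:

```python
import functools

import jax


@functools.partial(jax.custom_vjp, nondiff_argnums=(1,))
def power(x, n):
  # n is a non-differentiable (static) argument.
  return x ** n


def power_fwd(x, n):
  # The forward rule receives the arguments in the primal's order.
  return x ** n, x


def power_bwd(n, x, y_bar):
  # The backward rule receives the nondiff argument(s) first, then the
  # residuals, then the output cotangent. It returns cotangents only for
  # the differentiable argument x.
  return (n * x ** (n - 1) * y_bar,)


power.defvjp(power_fwd, power_bwd)

value = power(2.0, 3)
grad = jax.grad(power)(2.0, 3)  # differentiates with respect to x only
```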

Note that the vjp function returned by nn.vjp can be passed as a residual and used in backward_fn. The scope is unavailable during the backward pass; if the module is needed in backward_fn, take a snapshot of its variables in forward_fn and return it as a residual.
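A plain-JAX analogue of passing a vjp function as a residual, sketched under the assumption that a bias-free linear map stands in for the module call; linear and linear_sign_grad are hypothetical names. As in the Flax example on this page, the backward rule replaces the weight gradient with its sign and leaves the input gradient unchanged.

```python
import jax
import jax.numpy as jnp


def linear(w, x):
  # Hypothetical stand-in for the module call: a bias-free linear map.
  return jnp.dot(x, w)


@jax.custom_vjp
def linear_sign_grad(w, x):
  return linear(w, x)


def linear_fwd(w, x):
  # jax.vjp returns the output and a vjp function. The vjp function is a
  # pytree, so it can be returned directly as the residual.
  y, vjp_fn = jax.vjp(linear, w, x)
  return y, vjp_fn


def linear_bwd(vjp_fn, y_bar):
  # Reuse the saved vjp function, then replace the weight gradient by its sign.
  w_bar, x_bar = vjp_fn(y_bar)
  return jnp.sign(w_bar), x_bar


linear_sign_grad.defvjp(linear_fwd, linear_bwd)

w = jnp.array([0.5, -2.0])
x = jnp.array([3.0, -4.0])
w_grad, x_grad = jax.grad(linear_sign_grad, argnums=(0, 1))(w, x)
```

The true weight gradient here is x = [3.0, -4.0], so the custom rule yields its sign, while the input gradient stays equal to w.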

Example:

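import jax
import jax.numpy as jnp
from jax import random
import flax.linen as nn

class Foo(nn.Module):
  @nn.compact
  def __call__(self, x):
    def f(mdl, x):
      return mdl(x)

    def fwd(mdl, x):
      # nn.vjp returns (output, vjp_fn); vjp_fn becomes the residual.
      return nn.vjp(f, mdl, x)

    def bwd(vjp_fn, y_t):
      # vjp_fn returns the parameter tangents first, then the input tangents.
      params_t, *inputs_t = vjp_fn(y_t)
      # Replace each parameter gradient with its sign.
      params_t = jax.tree_util.tree_map(jnp.sign, params_t)
      return (params_t, *inputs_t)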

    sign_grad = nn.custom_vjp(
        f, forward_fn=fwd, backward_fn=bwd)
    return sign_grad(nn.Dense(1), x).reshape(())

x = jnp.ones((2,))
variables = Foo().init(random.PRNGKey(0), x)
grad = jax.grad(Foo().apply)(variables, x)
Parameters
  • fn – The function to define a custom_vjp for.

  • forward_fn – A function with the same arguments as fn that returns a tuple of the original output and the residuals to be passed to backward_fn.

  • backward_fn – Receives its arguments as (*nondiff_args, residuals, tangents). It should return a tuple containing the tangents for the variables in the collections specified by grad_vars, followed by the tangents for the input arguments (excluding the module and the nondiff args).

  • grad_vars – The collections for which a vjp will be computed (default: “params”).

  • nondiff_argnums – Indices of the arguments for which no vjp is computed. These arguments are passed first to backward_fn.

Returns

A function with the same signature as fn that uses the custom vjp.