flax.linen.custom_vjp

flax.linen.custom_vjp(fn, forward_fn, backward_fn, grad_vars='params', nondiff_argnums=())

Lifted version of jax.custom_vjp.

forward_fn and backward_fn together define a custom vjp for fn. When no vjp (backward gradient) is computed, for example in a plain forward pass, the original fn is run instead.

The forward_fn receives the same arguments as fn but is expected to return a tuple containing the output of fn(mdl, *args) and the residuals that are passed to backward_fn.

The backward_fn receives the nondiff arguments, the residuals, and the output tangents, in that order. It should return a tuple containing the variable tangents followed by the input tangents.

Note that the vjp function returned by nn.vjp can be passed as a residual and used in backward_fn. The scope is unavailable during the backward pass, so if the module is required in backward_fn, a snapshot of its variables can be taken in forward_fn and returned as a residual.

Example:

>>> import flax.linen as nn
>>> import jax, jax.numpy as jnp

>>> class Foo(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     def f(mdl, x):
...       return mdl(x)
...
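...     def fwd(mdl, x):
...       # nn.vjp returns (output, vjp_fn); vjp_fn is passed on as the residual.
...       return nn.vjp(f, mdl, x)
...
...     def bwd(vjp_fn, y_t):
...       params_t, *inputs_t = vjp_fn(y_t)
...       # Replace every parameter gradient with its sign.
...       params_t = jax.tree_util.tree_map(jnp.sign, params_t)
...       return (params_t, *inputs_t)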
...
...     sign_grad = nn.custom_vjp(
...         f, forward_fn=fwd, backward_fn=bwd)
...     return sign_grad(nn.Dense(1), x).reshape(())

>>> x = jnp.ones((2,))
>>> variables = Foo().init(jax.random.key(0), x)
>>> grad = jax.grad(Foo().apply)(variables, x)

Parameters

  • fn – The function to define a custom_vjp for.

  • forward_fn – A function with the same arguments as fn, returning a tuple with the original output and the residuals that will be passed to backward_fn.

  • backward_fn – Called with the arguments (*nondiff_args, residuals, tangents). The function should return a tuple containing the tangents for the variables in the collections specified by grad_vars, followed by the tangents for the input arguments (except the module and the nondiff args).

  • grad_vars – The collections for which a vjp will be computed (default: “params”).

  • nondiff_argnums – Indices of the arguments for which no vjp is computed.

Returns

A function with the same signature as fn with the custom vjp.