- flax.linen.custom_vjp(fn, forward_fn, backward_fn, grad_vars='params', nondiff_argnums=())
Lifted version of jax.custom_vjp.
forward_fn and backward_fn together define a custom vjp for fn. If no vjp (backward gradient) is computed, the original fn is run instead of forward_fn.
The forward_fn receives the same arguments as fn but is expected to return a tuple containing the output of fn(mdl, *args) and the residuals that are passed to backward_fn.
The backward_fn receives the nondiff arguments, the residuals, and the output tangents. It should return a tuple whose first element holds the variable tangents and whose remaining elements are the input tangents.
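For comparison, plain jax.custom_vjp, which this transform lifts, uses the same contract without a module: the forward rule returns (output, residuals) and the backward rule returns one tangent per differentiable input. The following is a minimal pure-JAX sketch, not the Flax API; `scale` and its dict-shaped `params` argument are hypothetical names chosen so that the first returned tangent plays the role of the variable tangents.

```python
import jax
import jax.numpy as jnp


@jax.custom_vjp
def scale(params, x):
    # Primal computation: multiply the input by a learnable scalar.
    return params["w"] * x


def scale_fwd(params, x):
    # Return the primal output plus the residuals the backward rule needs.
    return scale(params, x), (params["w"], x)


def scale_bwd(residuals, y_t):
    w, x = residuals
    # First the tangent for params (same dict structure as the primal),
    # then the tangent for x.
    params_t = {"w": jnp.sum(y_t * x)}
    x_t = y_t * w
    return params_t, x_t


scale.defvjp(scale_fwd, scale_bwd)

# Hypothetical usage: differentiate a summed output w.r.t. both arguments.
params = {"w": jnp.array(2.0)}
x = jnp.array([1.0, 2.0, 3.0])
loss = lambda p, x: scale(p, x).sum()
params_grad, x_grad = jax.grad(loss, argnums=(0, 1))(params, x)
```

Here the residuals are simply saved inputs; in the lifted version, a snapshot of module variables can be carried the same way.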
Note that the vjp function returned by nn.vjp can be passed as a residual and used in backward_fn. The scope is unavailable during the backward pass, so if the module is needed in backward_fn, take a snapshot of its variables in forward_fn and return it as a residual.
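```python
import jax
import jax.numpy as jnp
from jax import random
import flax.linen as nn


class Foo(nn.Module):
    @nn.compact
    def __call__(self, x):
        def f(mdl, x):
            return mdl(x)

        def fwd(mdl, x):
            # nn.vjp returns (output, vjp_fn); vjp_fn becomes the residual.
            return nn.vjp(f, mdl, x)

        def bwd(vjp_fn, y_t):
            # vjp_fn yields the variable tangents followed by the input tangents.
            params_t, *inputs_t = vjp_fn(y_t)
            # Replace each parameter gradient by its sign.
            params_t = jax.tree_util.tree_map(jnp.sign, params_t)
            return (params_t, *inputs_t)

        sign_grad = nn.custom_vjp(f, forward_fn=fwd, backward_fn=bwd)
        return sign_grad(nn.Dense(1), x).reshape(())


x = jnp.ones((2,))
variables = Foo().init(random.PRNGKey(0), x)
grad = jax.grad(Foo().apply)(variables, x)
```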
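For comparison, the sign-gradient example above can be sketched in plain JAX, with the parameters passed as an explicit pytree argument instead of living in module variables. This is a hedged illustration, not how Flax implements the lifted transform: `dense` and `sign_grad_dense` are hypothetical names, the dense layer stands in for nn.Dense(1), and, as in the Flax version, the function returned by jax.vjp is assumed to be usable as the residual.

```python
import jax
import jax.numpy as jnp


def dense(params, x):
    # A single dense layer with explicit parameters (stand-in for nn.Dense(1)).
    return x @ params["kernel"] + params["bias"]


@jax.custom_vjp
def sign_grad_dense(params, x):
    # Without differentiation, this is just the ordinary dense layer.
    return dense(params, x)


def fwd(params, x):
    # Keep the vjp function itself as the residual.
    y, vjp_fn = jax.vjp(dense, params, x)
    return y, vjp_fn


def bwd(vjp_fn, y_t):
    params_t, x_t = vjp_fn(y_t)
    # Replace each parameter tangent by its sign; leave the input tangent alone.
    params_t = jax.tree_util.tree_map(jnp.sign, params_t)
    return params_t, x_t


sign_grad_dense.defvjp(fwd, bwd)

params = {"kernel": jnp.array([[2.0], [-3.0]]), "bias": jnp.array([0.5])}
x = jnp.array([3.0, -0.5])
loss = lambda p, x: sign_grad_dense(p, x).sum()
params_grad, x_grad = jax.grad(loss, argnums=(0, 1))(params, x)
```

The unmodified kernel gradient would be the input itself, [[3.0], [-0.5]]; the custom rule reduces it to its sign, while the input gradient still equals the kernel column.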
Parameters:
fn – The function to define a custom_vjp for.
forward_fn – A function with the same arguments as fn, returning a tuple with the original output and the residuals that will be passed to backward_fn.
backward_fn – Receives its arguments as (*nondiff_args, residuals, tangents). It should return a tuple containing the tangents for the variables in the collections specified by grad_vars, followed by the tangents for the input arguments (except the module and the nondiff args).
grad_vars – The collections for which a vjp will be computed (default: “params”).
nondiff_argnums – Arguments for which no vjp is computed.
Returns: A function with the same signature as fn that uses the custom vjp.
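The argument order listed for backward_fn, (*nondiff_args, residuals, tangents), follows plain jax.custom_vjp. Below is a minimal pure-JAX sketch of that convention, adapted from the gradient-clipping pattern in the JAX custom-derivative documentation; it is not the Flax API, and `clip_gradient` is a hypothetical name. The non-differentiable threshold arrives first in the backward rule, and no tangent is returned for it. The sketch does not show how the lifted version counts the module argument when indexing nondiff_argnums.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.custom_vjp, nondiff_argnums=(0,))
def clip_gradient(threshold, x):
    # Identity on the forward pass.
    return x


def clip_gradient_fwd(threshold, x):
    # No residuals are needed, so None is returned in their place.
    return x, None


def clip_gradient_bwd(threshold, residuals, y_t):
    # Nondiff args come first, then residuals, then the output tangent.
    # Only x is differentiable, so the returned tuple has a single entry.
    return (jnp.clip(y_t, -threshold, threshold),)


clip_gradient.defvjp(clip_gradient_fwd, clip_gradient_bwd)

# The incoming tangent 5.0 is clipped to 1.0 before it reaches x.
grad = jax.grad(lambda x: 5.0 * clip_gradient(1.0, x))(2.0)
```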