flax.linen.jvp

flax.linen.jvp(fn, mdl, primals, tangents, variable_tangents, variables=True, rngs=True)

A lifted version of jax.jvp.

See jax.jvp for the unlifted Jacobian-vector product (forward gradient).
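
For reference, the unlifted jax.jvp takes a function, a tuple of primals and a tuple of tangents with the same structure. It returns the function's output together with the directional derivative, computed in a single forward pass. This is a minimal sketch; the function f and the input values are illustrative only:

```python
import jax
import jax.numpy as jnp


def f(x):
    # f(x) = x * sin(x), so f'(x) = sin(x) + x * cos(x)
    return x * jnp.sin(x)


x = jnp.array(0.5)
v = jnp.array(1.0)  # tangent (direction) for x

# Computes f(x) and the Jacobian-vector product J_f(x) @ v together.
y, y_t = jax.jvp(f, (x,), (v,))
```

For a scalar input and a unit tangent, y_t is simply f'(x).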

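Note that no tangents are returned for variables. When variable tangents are required, their values should be returned explicitly by fn, using Module.variables. The example below also uses Module.variables to build the input tangents for a submodule's parameters: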

>>> import jax
>>> import flax.linen as nn
>>> import jax.numpy as jnp

>>> class LearnScale(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     p = self.param('test', nn.initializers.zeros_init(), ())
...     return p * x

>>> class Foo(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     scale = LearnScale()
...     vars_t = jax.tree_util.tree_map(jnp.ones_like,
...                                     scale.variables.get('params', {}))
...     _, out_t = nn.jvp(
...         lambda mdl, x: mdl(x), scale, (x,), (jnp.zeros_like(x),),
...         variable_tangents={'params': vars_t})
...     return out_t

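Example using the functional core API, flax.core.lift.jvp, where fn receives a Scope instead of a Module:

>>> import jax
>>> import jax.numpy as jnp
>>> import flax.linen as nn
>>> from flax.core import lift

>>> def learn_scale(scope, x):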
...   p = scope.param('scale', nn.initializers.zeros_init(), ())
...   return p * x

>>> def f(scope, x):
...   vars_t = jax.tree_util.tree_map(jnp.ones_like, scope.variables().get('params', {}))
...   _, out_t = lift.jvp(
...       learn_scale, scope, (x,), (jnp.zeros_like(x),),
...       variable_tangents={'params': vars_t})
...   return out_t

Parameters
  • fn – Function to be differentiated. It receives the module mdl as its first argument, followed by the primals. Apart from the module, its arguments should be arrays, scalars, or standard Python containers of arrays or scalars. It should return an array, scalar, or standard Python container of arrays or scalars.

  • mdl – The module whose variables will be differentiated.

  • primals – The primal values at which the Jacobian of fn should be evaluated. Should be either a tuple or a list of arguments, and its length should equal the number of positional parameters of fn, excluding the module argument.

  • tangents – The tangent vector for which the Jacobian-vector product should be evaluated. Should be either a tuple or a list of tangents, with the same tree structure and array shapes as primals.

  • variable_tangents – A dict, or a PyTree of dicts, with the same structure as the variables of mdl. Each entry in the dict specifies the tangents for one variable collection (e.g. 'params'). Omitting a collection from variable_tangents is equivalent to passing a zero vector as its tangent.

  • variables – Other variable collections that are available in fn but do not receive a tangent.

  • rngs – The PRNG sequences that are available inside fn.
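
Conceptually, the collections listed in variable_tangents are treated as extra primals. The variables and the primals are differentiated together, with variable_tangents and tangents as the matching tangent vectors. The following pure-JAX sketch makes this explicit by passing parameters as an ordinary dict. learn_scale_apply and jvp_with_variable_tangents are hypothetical helpers for illustration, not Flax APIs. The sketch also covers the zero-tangent case, which the parameter description above says is equivalent to omitting a collection.

```python
import jax
import jax.numpy as jnp


def learn_scale_apply(params, x):
    # Hypothetical explicit-parameter version of a learnable scale.
    return params['scale'] * x


def jvp_with_variable_tangents(fn, params, primals, tangents, params_tangents):
    # Differentiate jointly with respect to params and the primal inputs.
    def wrapper(p, args):
        return fn(p, *args)

    return jax.jvp(wrapper,
                   (params, tuple(primals)),
                   (params_tangents, tuple(tangents)))


params = {'scale': jnp.array(2.0)}
x = jnp.arange(1.0, 4.0)

# Tangent 1 for the parameter, 0 for x: d(p * x) = 1 * x + p * 0 = x.
y, y_t = jvp_with_variable_tangents(
    learn_scale_apply, params, (x,), (jnp.zeros_like(x),),
    {'scale': jnp.array(1.0)})

# Zero parameter tangent, tangent 1 for x: d(p * x) = p * 1.
zero_t = jax.tree_util.tree_map(jnp.zeros_like, params)
y2, y2_t = jvp_with_variable_tangents(
    learn_scale_apply, params, (x,), (jnp.ones_like(x),), zero_t)
```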

Returns

A (primals_out, tangents_out) pair, where primals_out is fn(mdl, *primals) and tangents_out is the Jacobian-vector product of fn evaluated at primals with tangents (and at the variables of mdl with variable_tangents). The tangents_out value has the same Python tree structure and shapes as primals_out.