flax.linen.jvp
- flax.linen.jvp(fn, mdl, primals, tangents, variable_tangents, variables=True, rngs=True)
A lifted version of `jax.jvp`. See `jax.jvp` for the unlifted Jacobian-vector product (forward gradient).

Note that no tangents are returned for variables. When variable tangents are required, their value should be returned explicitly by `fn` using `Module.variables`:
```python
import jax
import jax.numpy as jnp
import flax.linen as nn

class LearnScale(nn.Module):
  @nn.compact
  def __call__(self, x):
    p = self.param('test', nn.initializers.zeros_init(), ())
    return p * x

class Foo(nn.Module):
  @nn.compact
  def __call__(self, x):
    scale = LearnScale()
    # Use a unit tangent for every parameter of the submodule.
    vars_t = jax.tree_util.tree_map(jnp.ones_like,
                                    scale.variables.get('params', {}))
    _, out_t = nn.jvp(
        lambda mdl, x: mdl(x), scale, (x,), (jnp.zeros_like(x),),
        variable_tangents={'params': vars_t})
    return out_t
```
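For comparison, the unlifted computation can be sketched with plain `jax.jvp` by passing the parameter as an explicit argument instead of storing it in a Flax variable collection. The function `scaled` is a hypothetical stand-in for `LearnScale`, and the parameter value 0.0 assumes a zeros initializer:

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for LearnScale: the parameter is an explicit argument.
def scaled(p, x):
  return p * x

p = jnp.zeros(())     # assumed value from a zeros initializer
x = jnp.arange(3.0)

# Unit tangent for the parameter, zero tangent for the input (as in Foo).
y, y_t = jax.jvp(scaled, (p, x), (jnp.ones_like(p), jnp.zeros_like(x)))
# y -> [0. 0. 0.], y_t -> [0. 1. 2.]
```

The lifted `nn.jvp` performs the same product, but the parameter and its tangent come from the module's variable collections rather than from the argument list.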
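Example using the functional core API (`flax.core.lift`):
```python
import jax
import jax.numpy as jnp
import flax.linen as nn
from flax.core import lift

def learn_scale(scope, x):
  p = scope.param('scale', nn.initializers.zeros_init(), ())
  return p * x

def f(scope, x):
  # Use a unit tangent for every parameter visible to this scope.
  vars_t = jax.tree_util.tree_map(jnp.ones_like,
                                  scope.variables().get('params', {}))
  _, out_t = lift.jvp(
      learn_scale, scope, (x,), (jnp.zeros_like(x),),
      variable_tangents={'params': vars_t})
  return out_t
```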
- Parameters
fn – Function to be differentiated. Its arguments should be arrays, scalars, or standard Python containers of arrays or scalars. It should return an array, scalar, or standard Python container of arrays or scalars. It will receive the module (`mdl`) followed by the primals as arguments.
mdl – The module whose variables will be differentiated.
primals – The primal values at which the Jacobian of `fn` should be evaluated. Should be either a tuple or a list of arguments, and its length should equal the number of positional parameters of `fn` after the module argument.
tangents – The tangent vector for which the Jacobian-vector product should be evaluated. Should be either a tuple or a list of tangents, with the same tree structure and array shapes as `primals`.
variable_tangents – A dict, or a PyTree of dicts, with the same structure as the scopes. Each entry in the dict specifies the tangents for one variable collection (for example `{'params': ...}`). Omitting a collection from `variable_tangents` is equivalent to passing a zero vector as its tangent.
variables – Other variable collections that are available in `fn` but do not receive a tangent.
rngs – The PRNG sequences that are available inside `fn`.
- Returns
A `(primals_out, tangents_out)` pair, where `primals_out` is `fn(mdl, *primals)` and `tangents_out` is the Jacobian-vector product of `fn` evaluated at `primals` with `tangents`. The `tangents_out` value has the same Python tree structure and shapes as `primals_out`.