flax.linen.cond
- flax.linen.cond(pred, true_fun, false_fun, mdl, *operands, variables=True, rngs=True)
Lifted version of jax.lax.cond. The values returned from true_fun and false_fun must have the same Pytree structure, shapes, and dtypes. The variables created or updated inside the branches must also have the same structure. Note that this constraint is violated when variables or submodules are created in only one branch, because initializing variables in just one branch makes the parameter structure differ between the branches. Example:
    import flax.linen as nn

    class CondExample(nn.Module):
      @nn.compact
      def __call__(self, x, pred):
        self.variable('state', 'true_count', lambda: 0)
        self.variable('state', 'false_count', lambda: 0)

        def true_fn(mdl, x):
          mdl.variable('state', 'true_count').value += 1
          return nn.Dense(2, name='dense')(x)

        def false_fn(mdl, x):
          mdl.variable('state', 'false_count').value += 1
          return -nn.Dense(2, name='dense')(x)

        return nn.cond(pred, true_fn, false_fn, self, x)
- Parameters
pred – Determines whether true_fun or false_fun is evaluated.
true_fun – The function evaluated when pred is True. The signature is (module, *operands) -> T.
false_fun – The function evaluated when pred is False. The signature is (module, *operands) -> T.
mdl – The Module target; it is passed to true_fun and false_fun as their first argument.
*operands – The arguments passed to true_fun and false_fun.
variables – The variable collections passed to the conditional branches (default: all).
rngs – The PRNG sequences passed to the conditional branches (default: all).
- Returns
The result of the evaluated branch (true_fun or false_fun).
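The requirement that both branches return the same Pytree structure, shapes, and dtypes comes from jax.lax.cond itself. The sketch below uses plain jax.lax.cond (no Flax) to contrast a pair of branches that satisfies it with one that does not; the helper names matching, mismatched, and raises_type_error are illustrative. Current JAX releases report the mismatch as a TypeError at trace time, though the exact message varies between versions.

```python
import jax
import jax.numpy as jnp


def matching(x):
  # Both branches return a float32 array of shape (3,).
  return jax.lax.cond(x.sum() > 0, lambda v: v * 2.0, lambda v: -v, x)


def mismatched(x):
  # The false branch returns a scalar, so the branch output shapes differ.
  return jax.lax.cond(x.sum() > 0, lambda v: v * 2.0, lambda v: v.sum(), x)


def raises_type_error(fn, *args):
  # Returns True if calling fn(*args) raises a TypeError.
  try:
    fn(*args)
  except TypeError:
    return True
  return False


x = jnp.arange(3.0)
print(matching(x))                       # [0. 2. 4.]
print(raises_type_error(mismatched, x))  # True
```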