Using Filters, grouping NNX variables#
Flax NNX uses Filters extensively as a way to create nnx.State groups in APIs, such as nnx.split, nnx.state(), and many of the Flax NNX transformations (transforms).
In this guide you will learn how to:
- Use Filters to group Flax NNX variables and states into subgroups;
- Understand relationships between types, such as nnx.Param or nnx.BatchStat, and Filters;
- Express your Filters flexibly with the nnx.filterlib.Filter language.
In the following example, nnx.Param and nnx.BatchStat are used as Filters to split the model into two groups: one with the parameters and the other with the batch statistics:
from flax import nnx

class Foo(nnx.Module):
  def __init__(self):
    self.a = nnx.Param(0)
    self.b = nnx.BatchStat(True)

foo = Foo()
graphdef, params, batch_stats = nnx.split(foo, nnx.Param, nnx.BatchStat)

print(f'{params = }')
print(f'{batch_stats = }')
params = State({
  'a': VariableState(
    type=Param,
    value=0
  )
})
batch_stats = State({
  'b': VariableState(
    type=BatchStat,
    value=True
  )
})
Let’s dive deeper into Filters.
The Filter Protocol#
In general, Flax Filters are predicate functions of the form:
(path: tuple[Key, ...], value: Any) -> bool
where:
- Key is a hashable and comparable type;
- path is a tuple of Keys representing the path to the value in a nested structure; and
- value is the value at the path.
The function returns True if the value should be included in the group, and False otherwise.
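To make the protocol concrete, here is a minimal sketch in plain Python, with no Flax imports. The names in_encoder and flat, and the example paths, are hypothetical and exist only for illustration. The point is that a function with this signature can be applied to (path, value) pairs directly.

```python
from typing import Any, Hashable

Key = Hashable  # stand-in for the hashable, comparable key type


def in_encoder(path: tuple[Key, ...], value: Any) -> bool:
  """Hypothetical predicate: match values stored under an 'encoder' key."""
  return 'encoder' in path


# A toy flat mapping from paths to values (not a real nnx.State).
flat = {
  ('encoder', 'kernel'): 1.0,
  ('encoder', 'bias'): 0.0,
  ('decoder', 'kernel'): 2.0,
}

# Keep only the entries for which the predicate returns True.
matched = {path: value for path, value in flat.items() if in_encoder(path, value)}
print(matched)
```

This predicate looks only at the path and ignores the value; a predicate is free to inspect either argument, or both.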
Types are not functions of this form. They are treated as Filters because, as you will learn in the next section, types and some other literals are converted to predicates. For example, nnx.Param is roughly converted to a predicate like this:
def is_param(path, value) -> bool:
  return isinstance(value, nnx.Param) or (
    hasattr(value, 'type') and issubclass(value.type, nnx.Param)
  )

print(f'{is_param((), nnx.Param(0)) = }')
print(f'{is_param((), nnx.VariableState(type=nnx.Param, value=0)) = }')
is_param((), nnx.Param(0)) = True
is_param((), nnx.VariableState(type=nnx.Param, value=0)) = True
Such a function matches any value that is an instance of nnx.Param, or any value that has a type attribute that is a subclass of nnx.Param. Internally, Flax NNX uses OfType, which defines a callable of this form for a given type:
is_param = nnx.OfType(nnx.Param)
print(f'{is_param((), nnx.Param(0)) = }')
print(f'{is_param((), nnx.VariableState(type=nnx.Param, value=0)) = }')
is_param((), nnx.Param(0)) = True
is_param((), nnx.VariableState(type=nnx.Param, value=0)) = True
The Filter DSL#
Flax NNX exposes a small domain-specific language (DSL), formalized as the nnx.filterlib.Filter type. This means users don’t have to write predicate functions by hand, as in the previous section.
Here is a list of all the callable Filters included in Flax NNX, and their corresponding DSL literals (when available):
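Literal | Callable | Description |
---|---|---|
... or True | Everything() | Matches all values |
None or False | Nothing() | Matches no values |
type | OfType(type) | Matches values that are instances of type, or that have a type attribute that is a subclass of type |
 | PathContains(key) | Matches values that have an associated path that contains the given key |
'{filter}' (a str) | WithTag('{filter}') | Matches values that have a string tag attribute equal to '{filter}' |
(*filters) tuple or [*filters] list | Any(*filters) | Matches values that match any of the inner filters |
 | All(*filters) | Matches values that match all of the inner filters |
 | Not(filter) | Matches values that do not match the inner filter |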
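One way to read the combinators in this table is as ordinary higher-order functions over predicates. The sketch below is plain Python and does not import Flax; any_of, all_of, not_ and path_contains are hypothetical stand-ins that only mimic the described behavior of Any, All, Not and PathContains, and the toy flat mapping is invented for illustration.

```python
def any_of(*preds):
  """Stand-in for Any: true if at least one inner predicate matches."""
  return lambda path, value: any(p(path, value) for p in preds)


def all_of(*preds):
  """Stand-in for All: true only if every inner predicate matches."""
  return lambda path, value: all(p(path, value) for p in preds)


def not_(pred):
  """Stand-in for Not: inverts the inner predicate."""
  return lambda path, value: not pred(path, value)


def path_contains(key):
  """Stand-in for PathContains: true if `key` appears in the path."""
  return lambda path, value: key in path


# Toy data: paths mapped to values (not a real nnx.State).
flat = {
  ('encoder', 'kernel'): 1.0,
  ('encoder', 'bias'): 0.0,
  ('decoder', 'kernel'): 2.0,
}

# "Everything under 'encoder' except biases."
pred = all_of(path_contains('encoder'), not_(path_contains('bias')))
selected = [path for path, value in flat.items() if pred(path, value)]
print(selected)
```

Because every combinator returns another predicate with the same (path, value) signature, they can be nested to any depth.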
Let’s check out the DSL in action by using nnx.vmap as an example. Consider the following:
- You want to vectorize all parameters;
- Apply 'dropout' Rng(Keys|Counts) on the 0th axis; and
- Broadcast the rest.
To do this, you can use the following Filters to define a nnx.StateAxes object that you can pass to nnx.vmap’s in_axes to specify how the model’s various sub-states should be vectorized:
state_axes = nnx.StateAxes({(nnx.Param, 'dropout'): 0, ...: None})

@nnx.vmap(in_axes=(state_axes, 0))
def forward(model, x):
  ...
Here (nnx.Param, 'dropout') expands to Any(OfType(nnx.Param), WithTag('dropout')) and ... expands to Everything().
If you wish to manually convert a literal into a predicate, you can use nnx.filterlib.to_predicate:
is_param = nnx.filterlib.to_predicate(nnx.Param)
everything = nnx.filterlib.to_predicate(...)
nothing = nnx.filterlib.to_predicate(False)
params_or_dropout = nnx.filterlib.to_predicate((nnx.Param, 'dropout'))
print(f'{is_param = }')
print(f'{everything = }')
print(f'{nothing = }')
print(f'{params_or_dropout = }')
is_param = OfType(<class 'flax.nnx.nnx.variables.Param'>)
everything = Everything()
nothing = Nothing()
params_or_dropout = Any(OfType(<class 'flax.nnx.nnx.variables.Param'>), WithTag('dropout'))
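To illustrate the kind of dispatch such a conversion involves, here is a simplified plain-Python sketch. It is an assumption-laden stand-in, not Flax's implementation: to_predicate_sketch and the Tagged class are hypothetical names, and the rules are reconstructed only from the literals described in this guide.

```python
def to_predicate_sketch(literal):
  """Simplified stand-in for nnx.filterlib.to_predicate (not Flax's code)."""
  if isinstance(literal, str):
    # A string matches values whose `tag` attribute equals it.
    return lambda path, value: getattr(value, 'tag', None) == literal
  if isinstance(literal, type):
    # A type matches instances of it, or values whose `type` attribute
    # is a subclass of it (mirroring the is_param example).
    def of_type(path, value):
      inner = getattr(value, 'type', None)
      return isinstance(value, literal) or (
        isinstance(inner, type) and issubclass(inner, literal)
      )
    return of_type
  if literal is ... or literal is True:
    return lambda path, value: True   # matches everything
  if literal is None or literal is False:
    return lambda path, value: False  # matches nothing
  if isinstance(literal, (list, tuple)):
    # A tuple or list matches if any inner filter matches.
    inner_preds = [to_predicate_sketch(f) for f in literal]
    return lambda path, value: any(p(path, value) for p in inner_preds)
  if callable(literal):
    return literal  # already a predicate
  raise TypeError(f'Invalid filter literal: {literal!r}')


class Tagged:
  """Hypothetical value type carrying a `tag` attribute."""
  def __init__(self, tag):
    self.tag = tag


pred = to_predicate_sketch((int, 'dropout'))
print(pred((), 3))
print(pred((), Tagged('dropout')))
print(pred((), 'text'))
```

The string and type checks come before the generic callable check because classes are themselves callable; checking callable first would return the class unchanged instead of wrapping it in a type test.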
Grouping States#
With the knowledge of Filters from previous sections at hand, let’s learn how to roughly implement nnx.split. Here are the key ideas:
1. Use nnx.graph.flatten to get the GraphDef and nnx.State representation of the node.
2. Convert all the Filters to predicates.
3. Use State.flat_state to get the flat representation of the state.
4. Traverse all the (path, value) pairs in the flat state and group them according to the predicates.
5. Use State.from_flat_path to convert the flat states to nested nnx.States.
from typing import Any

KeyPath = tuple[nnx.graph.Key, ...]

def split(node, *filters):
  graphdef, state = nnx.graph.flatten(node)
  predicates = [nnx.filterlib.to_predicate(f) for f in filters]
  # One flat dict per filter, in the same order as the filters.
  flat_states: list[dict[KeyPath, Any]] = [{} for p in predicates]

  for path, value in state.flat_state():
    for i, predicate in enumerate(predicates):
      if predicate(path, value):
        # The first matching filter keeps the value.
        flat_states[i][path] = value
        break
    else:
      # The inner loop finished without a break: no filter matched.
      raise ValueError(f'No filter matched {path = } {value = }')

  states: tuple[nnx.GraphState, ...] = tuple(
    nnx.State.from_flat_path(flat_state) for flat_state in flat_states
  )
  return graphdef, *states

# Let's test it.
foo = Foo()
graphdef, params, batch_stats = split(foo, nnx.Param, nnx.BatchStat)

print(f'{params = }')
print(f'{batch_stats = }')
params = State({
  'a': VariableState(
    type=Param,
    value=0
  )
})
batch_stats = State({
  'b': VariableState(
    type=BatchStat,
    value=True
  )
})
Note: Filtering is order-dependent. The first Filter that matches a value keeps it, so you should place more specific Filters before more general Filters.
For example, as demonstrated below, if you:
- Create a SpecialParam type that is a subclass of nnx.Param, and a Bar object (subclassing nnx.Module) that contains both types of parameters; and
- Try to split the nnx.Params before the SpecialParams
then all the values will be placed in the nnx.Param group, and the SpecialParam group will be empty because all SpecialParams are also nnx.Params:
class SpecialParam(nnx.Param):
  pass

class Bar(nnx.Module):
  def __init__(self):
    self.a = nnx.Param(0)
    self.b = SpecialParam(0)

bar = Bar()
graphdef, params, special_params = split(bar, nnx.Param, SpecialParam) # wrong!

print(f'{params = }')
print(f'{special_params = }')
params = State({
  'a': VariableState(
    type=Param,
    value=0
  ),
  'b': VariableState(
    type=SpecialParam,
    value=0
  )
})
special_params = State({})
Reversing the order ensures that the SpecialParams are captured first:
graphdef, special_params, params = split(bar, SpecialParam, nnx.Param) # correct!
print(f'{params = }')
print(f'{special_params = }')
params = State({
  'a': VariableState(
    type=Param,
    value=0
  )
})
special_params = State({
  'b': VariableState(
    type=SpecialParam,
    value=0
  )
})
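Because this ordering mistake fails silently (the later group is simply empty), it can help to check type-based filters before splitting. The following is a minimal sketch of one possible check, written in plain Python: misordered_pairs is a hypothetical helper, not part of Flax, and the Param, SpecialParam and BatchStat classes here are local stand-ins for the nnx types so the snippet runs without Flax installed.

```python
def misordered_pairs(*types):
  """Return (general, specific) pairs where the general type is listed first.

  Hypothetical helper: with first-match-wins grouping, a subclass listed
  after its superclass can never receive any values.
  """
  pairs = []
  for i, earlier in enumerate(types):
    for later in types[i + 1:]:
      if later is not earlier and issubclass(later, earlier):
        pairs.append((earlier, later))
  return pairs


# Local stand-ins for nnx.Param, a subclass of it, and nnx.BatchStat.
class Param: ...
class SpecialParam(Param): ...
class BatchStat: ...


bad = misordered_pairs(Param, SpecialParam, BatchStat)
good = misordered_pairs(SpecialParam, Param, BatchStat)
print(f'{len(bad) = }')
print(f'{len(good) = }')
```

This sketch only handles filters that are plain types; strings, tuples, and custom predicates would need their own treatment.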