Using Filters
Attention: This page relates to the new Flax NNX API.
Filters are used extensively in Flax NNX as a way to create `State` groups in APIs such as `nnx.split`, `nnx.state`, and many of the Flax NNX transforms. For example:
```python
from flax import nnx

class Foo(nnx.Module):
  def __init__(self):
    self.a = nnx.Param(0)
    self.b = nnx.BatchStat(True)

foo = Foo()

graphdef, params, batch_stats = nnx.split(foo, nnx.Param, nnx.BatchStat)

print(f'{params = }')
print(f'{batch_stats = }')
```
```
params = State({
  'a': VariableState(
    type=Param,
    value=0
  )
})
batch_stats = State({
  'b': VariableState(
    type=BatchStat,
    value=True
  )
})
```
Here `nnx.Param` and `nnx.BatchStat` are used as Filters to split the model into two groups: one with the parameters and the other with the batch statistics. However, this raises the following questions:
- What is a Filter?
- Why are types, such as `Param` or `BatchStat`, Filters?
- How is `State` grouped / filtered?
The Filter Protocol
In general, Filters are predicate functions of the form:

```python
(path: tuple[Key, ...], value: Any) -> bool
```

where `Key` is a hashable and comparable type, `path` is a tuple of `Key`s representing the path to the value in a nested structure, and `value` is the value at that path. The function returns `True` if the value should be included in the group and `False` otherwise.
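To make the protocol concrete without depending on Flax, here is a minimal pure-Python sketch. A plain flat dictionary stands in for a flattened state. `path_contains` and `is_positive_int` are hypothetical predicates written only for illustration. Any function with the `(path, value) -> bool` shape fits the protocol.

```python
from collections.abc import Callable, Hashable
from typing import Any

# A Filter predicate receives the path to a value and the value itself.
Key = Hashable
Predicate = Callable[[tuple[Key, ...], Any], bool]


def path_contains(key: Key) -> Predicate:
    """Hypothetical helper: match any value whose path includes `key`."""
    def predicate(path: tuple[Key, ...], value: Any) -> bool:
        return key in path
    return predicate


def is_positive_int(path: tuple[Key, ...], value: Any) -> bool:
    """Hypothetical predicate that ignores the path and inspects the value."""
    return isinstance(value, int) and value > 0


# A flat mapping from paths to values, similar in spirit to a flattened State.
flat = {
    ('encoder', 'kernel'): 3,
    ('encoder', 'bias'): 0,
    ('decoder', 'kernel'): -1,
}

in_encoder = path_contains('encoder')
print([p for p, v in flat.items() if in_encoder(p, v)])
# [('encoder', 'kernel'), ('encoder', 'bias')]
print([p for p, v in flat.items() if is_positive_int(p, v)])
# [('encoder', 'kernel')]
```

A predicate may inspect only the path, only the value, or both.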
Types are not functions of this form. They are treated as Filters because, as we will see next, types and some other literals are converted to predicates. For example, `Param` is roughly converted to a predicate like this:
```python
def is_param(path, value) -> bool:
  return isinstance(value, nnx.Param) or (
    hasattr(value, 'type') and issubclass(value.type, nnx.Param)
  )

print(f'{is_param((), nnx.Param(0)) = }')
print(f'{is_param((), nnx.VariableState(type=nnx.Param, value=0)) = }')
```
```
is_param((), nnx.Param(0)) = True
is_param((), nnx.VariableState(type=nnx.Param, value=0)) = True
```
Such a function matches any value that is an instance of `Param`, or any value that has a `type` attribute that is a subclass of `Param`. Internally, Flax NNX uses `OfType`, which defines a callable of this form for a given type:
```python
is_param = nnx.OfType(nnx.Param)

print(f'{is_param((), nnx.Param(0)) = }')
print(f'{is_param((), nnx.VariableState(type=nnx.Param, value=0)) = }')
```
```
is_param((), nnx.Param(0)) = True
is_param((), nnx.VariableState(type=nnx.Param, value=0)) = True
```
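For intuition, the following sketch re-creates an `OfType`-style callable in plain Python. `Param`, `BatchStat`, and `VariableStateStub` are local stand-ins, not the Flax classes. `OfTypeSketch` is a hypothetical name. The point is only that a filter can be an object with a `__call__(path, value)` method, which also gives it a readable repr.

```python
from dataclasses import dataclass
from typing import Any


# Local stand-ins for Flax NNX classes, defined only for illustration.
class Variable:
    def __init__(self, value: Any):
        self.value = value


class Param(Variable):
    pass


class BatchStat(Variable):
    pass


@dataclass
class VariableStateStub:
    type: type  # the Variable subclass this state was created from
    value: Any


@dataclass(frozen=True)
class OfTypeSketch:
    """Hypothetical OfType-style filter implemented as a callable object."""
    type: type

    def __call__(self, path: tuple, value: Any) -> bool:
        # Match instances of the type directly...
        if isinstance(value, self.type):
            return True
        # ...or objects carrying a `type` attribute that is a subclass of it.
        value_type = getattr(value, 'type', None)
        return isinstance(value_type, type) and issubclass(value_type, self.type)


is_param = OfTypeSketch(Param)
print(is_param((), Param(0)))                         # True
print(is_param((), VariableStateStub(Param, 0)))      # True
print(is_param((), BatchStat(True)))                  # False
print(is_param((), VariableStateStub(BatchStat, 1)))  # False
```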
The Filter DSL
To spare users from writing these functions by hand, Flax NNX exposes a small DSL, formalized as the `nnx.filterlib.Filter` type. It lets users pass types, booleans, ellipsis, tuples/lists, etc., and converts them to the appropriate predicate internally.
Here is a list of all the callable Filters included in Flax NNX and their DSL literals (when available):
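| Literal | Callable | Description |
|---|---|---|
| `...` or `True` | `Everything()` | Matches all values |
| `None` or `False` | `Nothing()` | Matches no values |
| `type` | `OfType(type)` | Matches values that are instances of `type`, or that have a `type` attribute that is a subclass of `type` |
| | `PathContains(key)` | Matches values that have an associated `path` that contains the given `key` |
| `'{filter}'` (str) | `WithTag('{filter}')` | Matches values that have a string `tag` attribute equal to `'{filter}'` |
| `(*filters)` (tuple) or `[*filters]` (list) | `Any(*filters)` | Matches values that match any of the inner `filters` |
| | `All(*filters)` | Matches values that match all of the inner `filters` |
| | `Not(filter)` | Matches values that do not match the inner `filter` |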
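The conversion from literals to predicates can be sketched in a few lines of plain Python. `to_predicate_sketch` is a hypothetical function that follows the rows of the table above for the literals it supports. It is not Flax's implementation, which returns named filter objects such as `OfType` and `Everything` rather than anonymous lambdas. `Param` and `RngKey` are local stand-ins, and the `tag` attribute on `RngKey` is an assumption made for the `WithTag` case.

```python
from typing import Any, Callable

Predicate = Callable[[tuple, Any], bool]


def to_predicate_sketch(filter_: Any) -> Predicate:
    """Hypothetical converter from Filter DSL literals to predicates."""
    if filter_ is ... or filter_ is True:
        # Everything: match all values.
        return lambda path, value: True
    if filter_ is None or filter_ is False:
        # Nothing: match no values.
        return lambda path, value: False
    if isinstance(filter_, type):
        # OfType: instances of the type, or objects whose `type` attr subclasses it.
        t = filter_
        return lambda path, value: isinstance(value, t) or (
            isinstance(getattr(value, 'type', None), type)
            and issubclass(value.type, t)
        )
    if isinstance(filter_, str):
        # WithTag: objects whose `tag` attribute equals the string.
        return lambda path, value: getattr(value, 'tag', None) == filter_
    if isinstance(filter_, (tuple, list)):
        # Any: match if any of the converted inner filters matches.
        inner = [to_predicate_sketch(f) for f in filter_]
        return lambda path, value: any(p(path, value) for p in inner)
    if callable(filter_):
        # Already a (path, value) -> bool predicate; checked after `type`
        # because classes are themselves callable.
        return filter_
    raise TypeError(f'Unsupported filter: {filter_!r}')


# Local stand-ins, not the Flax classes.
class Param:
    def __init__(self, value: Any):
        self.value = value


class RngKey:
    def __init__(self, tag: str):
        self.tag = tag


params_or_dropout = to_predicate_sketch((Param, 'dropout'))
print(params_or_dropout((), Param(0)))           # True
print(params_or_dropout((), RngKey('dropout')))  # True
print(params_or_dropout((), RngKey('params')))   # False
print(to_predicate_sketch(...)((), 42))          # True
print(to_predicate_sketch(False)((), 42))        # False
```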
Let's see the DSL in action with an `nnx.vmap` example. Say we want to vectorize all parameters and the `dropout` Rng(Keys|Counts) on axis 0, and broadcast the rest. To do so, we can use the following filters:
```python
from functools import partial

@partial(nnx.vmap, in_axes=(None, 0), state_axes={(nnx.Param, 'dropout'): 0, ...: None})
def forward(model, x):
  ...
```
Here `(nnx.Param, 'dropout')` expands to `Any(OfType(nnx.Param), WithTag('dropout'))`, and `...` expands to `Everything()`.
If you wish to manually convert a literal into a predicate, you can use `nnx.filterlib.to_predicate`:
```python
is_param = nnx.filterlib.to_predicate(nnx.Param)
everything = nnx.filterlib.to_predicate(...)
nothing = nnx.filterlib.to_predicate(False)
params_or_dropout = nnx.filterlib.to_predicate((nnx.Param, 'dropout'))

print(f'{is_param = }')
print(f'{everything = }')
print(f'{nothing = }')
print(f'{params_or_dropout = }')
```
```
is_param = OfType(<class 'flax.nnx.nnx.variables.Param'>)
everything = Everything()
nothing = Nothing()
params_or_dropout = Any(OfType(<class 'flax.nnx.nnx.variables.Param'>), WithTag('dropout'))
```
Grouping States
With the knowledge of Filters at hand, let's see how `nnx.split` is roughly implemented. Key ideas:

- Use `nnx.graph.flatten` to get the `GraphDef` and `State` representation of the node.
- Convert all the filters to predicates.
- Use `State.flat_state` to get the flat representation of the state.
- Traverse all the `(path, value)` pairs in the flat state and group them according to the predicates.
- Use `State.from_flat_path` to convert the flat states back to nested `State`s.
```python
from typing import Any

KeyPath = tuple[nnx.graph.Key, ...]

def split(node, *filters):
  # Flatten the node into its static GraphDef and its State.
  graphdef, state, _ = nnx.graph.flatten(node)
  # Convert every Filter literal into a predicate.
  predicates = [nnx.filterlib.to_predicate(f) for f in filters]
  flat_states: list[dict[KeyPath, Any]] = [{} for _ in predicates]

  # Walk the flat (path, value) pairs; the first matching predicate wins.
  for path, value in state.flat_state().items():
    for i, predicate in enumerate(predicates):
      if predicate(path, value):
        flat_states[i][path] = value
        break
    else:
      raise ValueError(f'No filter matched {path = } {value = }')

  # Rebuild a nested State from each flat group.
  states: tuple[nnx.GraphState, ...] = tuple(
    nnx.State.from_flat_path(flat_state) for flat_state in flat_states
  )
  return graphdef, *states

# let's test it...
foo = Foo()

graphdef, params, batch_stats = split(foo, nnx.Param, nnx.BatchStat)

print(f'{params = }')
print(f'{batch_stats = }')
```
```
params = State({
  'a': VariableState(
    type=Param,
    value=0
  )
})
batch_stats = State({
  'b': VariableState(
    type=BatchStat,
    value=True
  )
})
```
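The flat-state round trip that `split` relies on can be modeled with plain nested dicts standing in for `State`. `flatten_nested`, `unflatten_nested`, and `group_by_predicates` are hypothetical helpers for illustration only. They mimic the roles of `State.flat_state`, `State.from_flat_path`, and the grouping loop in `split`, but they are not their actual implementations.

```python
from typing import Any

KeyPath = tuple[str, ...]


def flatten_nested(nested: dict, prefix: KeyPath = ()) -> dict[KeyPath, Any]:
    """Turn {'a': {'b': 1}} into {('a', 'b'): 1}."""
    flat: dict[KeyPath, Any] = {}
    for key, value in nested.items():
        path = (*prefix, key)
        if isinstance(value, dict):
            flat.update(flatten_nested(value, path))
        else:
            flat[path] = value
    return flat


def unflatten_nested(flat: dict[KeyPath, Any]) -> dict:
    """Inverse of flatten_nested: rebuild nested dicts from key paths."""
    nested: dict = {}
    for path, value in flat.items():
        node = nested
        for key in path[:-1]:
            node = node.setdefault(key, {})
        node[path[-1]] = value
    return nested


def group_by_predicates(flat: dict[KeyPath, Any], predicates) -> list[dict]:
    """First-match grouping, following the same loop as split."""
    groups: list[dict] = [{} for _ in predicates]
    for path, value in flat.items():
        for group, predicate in zip(groups, predicates):
            if predicate(path, value):
                group[path] = value
                break
        else:
            raise ValueError(f'No filter matched {path = } {value = }')
    return groups


def is_bn(path, value):
    return path[0] == 'bn'


def everything(path, value):
    return True


state = {'linear': {'kernel': 1.0, 'bias': 0.0}, 'bn': {'mean': 0.5}}
bn_flat, rest_flat = group_by_predicates(flatten_nested(state), [is_bn, everything])
print(unflatten_nested(bn_flat))    # {'bn': {'mean': 0.5}}
print(unflatten_nested(rest_flat))  # {'linear': {'kernel': 1.0, 'bias': 0.0}}
```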
One very important thing to note is that filtering is order-dependent. The first filter that matches a value keeps it, so you should place more specific filters before more general ones. For example, suppose we create a `SpecialParam` type that is a subclass of `Param`, and a `Bar` object that contains both types of parameters. If we split the `Param`s before the `SpecialParam`s, all the values will be placed in the `Param` group and the `SpecialParam` group will be empty, because every `SpecialParam` is also a `Param`:
```python
class SpecialParam(nnx.Param):
  pass

class Bar(nnx.Module):
  def __init__(self):
    self.a = nnx.Param(0)
    self.b = SpecialParam(0)

bar = Bar()

graphdef, params, special_params = split(bar, nnx.Param, SpecialParam) # wrong!
print(f'{params = }')
print(f'{special_params = }')
```
```
params = State({
  'a': VariableState(
    type=Param,
    value=0
  ),
  'b': VariableState(
    type=SpecialParam,
    value=0
  )
})
special_params = State({})
```
Reversing the order ensures that the `SpecialParam`s are captured first:
```python
graphdef, special_params, params = split(bar, SpecialParam, nnx.Param) # correct!
print(f'{params = }')
print(f'{special_params = }')
```
```
params = State({
  'a': VariableState(
    type=Param,
    value=0
  )
})
special_params = State({
  'b': VariableState(
    type=SpecialParam,
    value=0
  )
})
```
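Another way to avoid depending on filter order is to make the filters disjoint, which is what the `All` and `Not` filters from the DSL table allow. The pure-Python sketch below uses stand-in `Param` and `SpecialParam` classes and hypothetical `of_type`, `all_of`, and `not_` helpers, none of which are the Flax APIs. Excluding `SpecialParam` from the general filter gives the same grouping in either order. The overlapping filters, by contrast, reproduce the empty `SpecialParam` group from the "wrong!" example.

```python
from typing import Any, Callable

Predicate = Callable[[tuple, Any], bool]


# Stand-ins for nnx.Param and the SpecialParam subclass from the example.
class Param:
    def __init__(self, value: Any):
        self.value = value


class SpecialParam(Param):
    pass


def of_type(t: type) -> Predicate:
    return lambda path, value: isinstance(value, t)


def all_of(*preds: Predicate) -> Predicate:
    return lambda path, value: all(p(path, value) for p in preds)


def not_(pred: Predicate) -> Predicate:
    return lambda path, value: not pred(path, value)


def split_flat(flat: dict, *preds: Predicate) -> list[dict]:
    """First-match grouping over a flat {path: value} dict."""
    groups: list[dict] = [{} for _ in preds]
    for path, value in flat.items():
        for group, pred in zip(groups, preds):
            if pred(path, value):
                group[path] = value
                break
        else:
            raise ValueError(f'No filter matched {path = }')
    return groups


flat = {('a',): Param(0), ('b',): SpecialParam(0)}
plain_param = all_of(of_type(Param), not_(of_type(SpecialParam)))
special = of_type(SpecialParam)

# With disjoint filters, both orders produce the same grouping.
params, specials = split_flat(flat, plain_param, special)
specials2, params2 = split_flat(flat, special, plain_param)
print(list(params), list(specials))    # [('a',)] [('b',)]
print(list(params2), list(specials2))  # [('a',)] [('b',)]

# The overlapping filters reproduce the "wrong!" case.
naive_params, naive_specials = split_flat(flat, of_type(Param), special)
print(list(naive_params), list(naive_specials))  # [('a',), ('b',)] []
```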