flax.traverse_util package

A utility for traversing immutable data structures.

A Traversal can be used to iterate and update complex data structures. Traversals take in an object and return a subset of its contents. For example, a Traversal could select an attribute of an object:

x = Foo(foo=1)  # Foo: any object with a `foo` attribute
list(traverse_util.TraverseAttr('foo').iterate(x))  # [1]
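To make the attribute case concrete, here is a minimal pure-Python sketch of what an attribute traversal does. `AttrTraversal` is a hypothetical stand-in, not Flax's TraverseAttr, and it assumes the input is a dataclass so that `dataclasses.replace` can build the updated copy.

```python
import dataclasses
from typing import Any, Callable, Iterator


@dataclasses.dataclass(frozen=True)
class Foo:
    foo: int


class AttrTraversal:
    """Hypothetical sketch: focuses on a single attribute of a dataclass."""

    def __init__(self, attr: str):
        self.attr = attr

    def iterate(self, inputs: Any) -> Iterator[Any]:
        # Yield the single selected value.
        yield getattr(inputs, self.attr)

    def update(self, fn: Callable[[Any], Any], inputs: Any) -> Any:
        # Build a new dataclass instance; `inputs` itself is left untouched.
        new_value = fn(getattr(inputs, self.attr))
        return dataclasses.replace(inputs, **{self.attr: new_value})


x = Foo(foo=1)
print(list(AttrTraversal('foo').iterate(x)))            # [1]
print(AttrTraversal('foo').update(lambda v: v + 1, x))  # Foo(foo=2)
```

Because `iterate` is a generator, wrapping it in `list()` is what turns the selection into the `[1]` shown in the comment.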

More complex traversals can be constructed using composition. It is often useful to start from the identity traversal and use a method chain to construct the intended Traversal:

data = [{'foo': 1, 'bar': 2}, {'foo': 3, 'bar': 4}]
traversal = traverse_util.t_identity.each()['foo']
list(traversal.iterate(data))  # [1, 3]
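The method chain above composes simpler traversals: `each()` selects every element of the list, and `['foo']` then selects a key inside each element. The sketch below models that composition with plain generator functions; `each`, `item`, and `compose` are hypothetical helpers for illustration, not part of flax.traverse_util.

```python
from typing import Any, Callable, Iterable, Iterator

# A "selector" maps an input to an iterator over the values it focuses on.
Selector = Callable[[Any], Iterator[Any]]


def each(inputs: Iterable[Any]) -> Iterator[Any]:
    # Select every element of a container.
    yield from inputs


def item(key: Any) -> Selector:
    # Select a single key (or index) of a container.
    def select(inputs: Any) -> Iterator[Any]:
        yield inputs[key]
    return select


def compose(outer: Selector, inner: Selector) -> Selector:
    # Apply `inner` to every value selected by `outer`.
    def select(inputs: Any) -> Iterator[Any]:
        for value in outer(inputs):
            yield from inner(value)
    return select


data = [{'foo': 1, 'bar': 2}, {'foo': 3, 'bar': 4}]
print(list(compose(each, item('foo'))(data)))  # [1, 3]
```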

Traversals can also be used to make changes using the update method:

data = {'foo': Foo(bar=2)}
traversal = traverse_util.t_identity['foo'].bar
traversal.update(lambda x: x + x, data) # {'foo': Foo(bar=4)}

Traversals never mutate the original data. Instead, an update returns a new copy of the data with the provided updates applied.
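The copy-on-update behaviour can be illustrated with a small sketch that updates one dictionary key. `update_item` is a hypothetical helper; it shallow-copies only the container being changed, which is one common way to obtain the no-mutation guarantee described above, not necessarily how Flax implements it.

```python
from typing import Any, Callable, Dict


def update_item(key: Any, fn: Callable[[Any], Any],
                inputs: Dict[Any, Any]) -> Dict[Any, Any]:
    # Shallow-copy the dict, then replace a single entry in the copy.
    updated = dict(inputs)
    updated[key] = fn(inputs[key])
    return updated


data = {'foo': 2, 'bar': [1, 2]}
new_data = update_item('foo', lambda x: x + x, data)
print(new_data)  # {'foo': 4, 'bar': [1, 2]}
print(data)      # {'foo': 2, 'bar': [1, 2]}
# Untouched values are shared between the two dicts, not duplicated.
print(new_data['bar'] is data['bar'])  # True
```

Sharing the untouched values keeps updates cheap, which is why such values should themselves be treated as immutable.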

Traversal objects

class flax.traverse_util.Traversal(*args, **kwargs)

Base class for all traversals.

compose(other)

Compose two traversals.

each()

Traverse each item in the selected containers.

filter(fn)

Filter the selected values.

abstract iterate(inputs)

Iterate over the values selected by this Traversal.

Parameters

inputs – the object that should be traversed.

Returns

An iterator over the traversed values.

merge(*traversals)

Compose an arbitrary number of traversals and merge the results.

set(values, inputs)

Overrides the values selected by the Traversal.

Parameters
  • values – a list containing the new values.

  • inputs – the object that should be traversed.

Returns

A new object with the updated values.
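One way to think about `set` is as an `update` whose callback ignores the old value and pulls the next replacement from `values`, in the order the traversal visits the selected items. The sketch below assumes a traversal that selects every element of a list; `update_each` and `set_each` are hypothetical names, and the length checks are an assumption about sensible behaviour rather than documented Flax semantics.

```python
from typing import Any, Callable, List, Sequence

_MISSING = object()  # Sentinel marking an exhausted iterator.


def update_each(fn: Callable[[Any], Any], inputs: List[Any]) -> List[Any]:
    # Apply `fn` to every selected element, returning a new list.
    return [fn(x) for x in inputs]


def set_each(values: Sequence[Any], inputs: List[Any]) -> List[Any]:
    it = iter(values)

    def next_value(_old: Any) -> Any:
        # Ignore the old value and consume the next replacement.
        value = next(it, _MISSING)
        if value is _MISSING:
            raise ValueError('fewer values than selected items')
        return value

    result = update_each(next_value, inputs)
    # Hypothetical guard: every provided value should have been used.
    if next(it, _MISSING) is not _MISSING:
        raise ValueError('more values than selected items')
    return result


print(set_each([10, 20, 30], [1, 2, 3]))  # [10, 20, 30]
```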

tree()

Traverse each item in a pytree.

abstract update(fn, inputs)

Update the focused items.

Parameters
  • fn – the callback function that maps each traversed item to its updated value.

  • inputs – the object that should be traversed.

Returns

A new object with the updated values.

class flax.traverse_util.TraverseId(*args, **kwargs)

The identity Traversal.

iterate(inputs)

Iterate over the values selected by this Traversal.

Parameters

inputs – the object that should be traversed.

Returns

An iterator over the traversed values.

update(fn, inputs)

Update the focused items.

Parameters
  • fn – the callback function that maps each traversed item to its updated value.

  • inputs – the object that should be traversed.

Returns

A new object with the updated values.

class flax.traverse_util.TraverseMerge(*args, **kwargs)

Merges the selection from a set of traversals.
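Merging can be sketched as running several traversals over the same input: `iterate` chains their selections, and `update` threads the input through each traversal's update in turn. `KeyTraversal`, `merge_iterate`, and `merge_update` below are hypothetical illustrations, not Flax's TraverseMerge.

```python
from typing import Any, Callable, Dict, Iterator, List


class KeyTraversal:
    """Hypothetical traversal that focuses on one key of a dict."""

    def __init__(self, key: str):
        self.key = key

    def iterate(self, inputs: Dict[str, Any]) -> Iterator[Any]:
        yield inputs[self.key]

    def update(self, fn: Callable[[Any], Any],
               inputs: Dict[str, Any]) -> Dict[str, Any]:
        # Return a new dict with only this key replaced.
        return {**inputs, self.key: fn(inputs[self.key])}


def merge_iterate(traversals: List[KeyTraversal], inputs: Any) -> Iterator[Any]:
    # Yield the selections of each traversal, one traversal after another.
    for t in traversals:
        yield from t.iterate(inputs)


def merge_update(traversals: List[KeyTraversal],
                 fn: Callable[[Any], Any], inputs: Any) -> Any:
    # Feed the output of one traversal's update into the next.
    for t in traversals:
        inputs = t.update(fn, inputs)
    return inputs


params = {'a': 1, 'b': 2, 'c': 3}
both = [KeyTraversal('a'), KeyTraversal('c')]
print(list(merge_iterate(both, params)))             # [1, 3]
print(merge_update(both, lambda v: v * 10, params))  # {'a': 10, 'b': 2, 'c': 30}
```

In this sketch, an item selected by two of the merged traversals would have `fn` applied twice, so overlapping selections deserve care.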

iterate(inputs)

Iterate over the values selected by this Traversal.

Parameters

inputs – the object that should be traversed.

Returns

An iterator over the traversed values.

update(fn, inputs)

Update the focused items.

Parameters
  • fn – the callback function that maps each traversed item to its updated value.

  • inputs – the object that should be traversed.

Returns

A new object with the updated values.

class flax.traverse_util.TraverseCompose(*args, **kwargs)

Compose two traversals.

iterate(inputs)

Iterate over the values selected by this Traversal.

Parameters

inputs – the object that should be traversed.

Returns

An iterator over the traversed values.

update(fn, inputs)

Update the focused items.

Parameters
  • fn – the callback function that maps each traversed item to its updated value.

  • inputs – the object that should be traversed.

Returns

A new object with the updated values.

class flax.traverse_util.TraverseFilter(*args, **kwargs)

Filter selected values based on a predicate.
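A filter narrows a selection with a predicate. In the sketch below (hypothetical helpers, not Flax's TraverseFilter, and assuming a list input), `filter_iterate` drops values that fail the predicate, while `filter_update` applies `fn` only to the values that pass and leaves the others unchanged.

```python
from typing import Any, Callable, Iterator, List


def filter_iterate(pred: Callable[[Any], bool], inputs: List[Any]) -> Iterator[Any]:
    # Keep only the values for which the predicate holds.
    return (x for x in inputs if pred(x))


def filter_update(pred: Callable[[Any], bool], fn: Callable[[Any], Any],
                  inputs: List[Any]) -> List[Any]:
    # Apply `fn` to matching values; pass the rest through unchanged.
    return [fn(x) if pred(x) else x for x in inputs]


values = [1, -2, 3, -4]
print(list(filter_iterate(lambda x: x > 0, values)))              # [1, 3]
print(filter_update(lambda x: x > 0, lambda x: x * 100, values))  # [100, -2, 300, -4]
```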

iterate(inputs)

Iterate over the values selected by this Traversal.

Parameters

inputs – the object that should be traversed.

Returns

An iterator over the traversed values.

update(fn, inputs)

Update the focused items.

Parameters
  • fn – the callback function that maps each traversed item to its updated value.

  • inputs – the object that should be traversed.

Returns

A new object with the updated values.

class flax.traverse_util.TraverseAttr(*args, **kwargs)

Traverse an attribute of an object.

iterate(inputs)

Iterate over the values selected by this Traversal.

Parameters

inputs – the object that should be traversed.

Returns

An iterator over the traversed values.

update(fn, inputs)

Update the focused items.

Parameters
  • fn – the callback function that maps each traversed item to its updated value.

  • inputs – the object that should be traversed.

Returns

A new object with the updated values.

class flax.traverse_util.TraverseItem(*args, **kwargs)

Traverse an item of an object (selected by key or index).

iterate(inputs)

Iterate over the values selected by this Traversal.

Parameters

inputs – the object that should be traversed.

Returns

An iterator over the traversed values.

update(fn, inputs)

Update the focused items.

Parameters
  • fn – the callback function that maps each traversed item to its updated value.

  • inputs – the object that should be traversed.

Returns

A new object with the updated values.

class flax.traverse_util.TraverseEach(*args, **kwargs)

Traverse each item of a container.

iterate(inputs)

Iterate over the values selected by this Traversal.

Parameters

inputs – the object that should be traversed.

Returns

An iterator over the traversed values.

update(fn, inputs)

Update the focused items.

Parameters
  • fn – the callback function that maps each traversed item to its updated value.

  • inputs – the object that should be traversed.

Returns

A new object with the updated values.

class flax.traverse_util.TraverseTree(*args, **kwargs)

Traverse every item in a pytree.
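Traversing a pytree means visiting every leaf of a nested container structure. JAX defines the real pytree registry; the sketch below is a simplified pure-Python stand-in that only treats dicts, lists, and tuples as containers (an assumption) and visits dict keys in sorted order, as JAX does for dicts. `tree_iterate` and `tree_update` are hypothetical helpers.

```python
from typing import Any, Callable, Iterator


def tree_iterate(tree: Any) -> Iterator[Any]:
    # Recurse into containers and yield each leaf in order.
    if isinstance(tree, dict):
        for key in sorted(tree):
            yield from tree_iterate(tree[key])
    elif isinstance(tree, (list, tuple)):
        for child in tree:
            yield from tree_iterate(child)
    else:
        yield tree


def tree_update(fn: Callable[[Any], Any], tree: Any) -> Any:
    # Rebuild the same structure with `fn` applied to every leaf.
    if isinstance(tree, dict):
        return {key: tree_update(fn, value) for key, value in tree.items()}
    if isinstance(tree, list):
        return [tree_update(fn, child) for child in tree]
    if isinstance(tree, tuple):
        return tuple(tree_update(fn, child) for child in tree)
    return fn(tree)


tree = {'b': [1, 2], 'a': (3, {'c': 4})}
print(list(tree_iterate(tree)))          # [3, 4, 1, 2]
print(tree_update(lambda x: x * 2, tree))  # {'b': [2, 4], 'a': (6, {'c': 8})}
```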

iterate(inputs)

Iterate over the values selected by this Traversal.

Parameters

inputs – the object that should be traversed.

Returns

An iterator over the traversed values.

update(fn, inputs)

Update the focused items.

Parameters
  • fn – the callback function that maps each traversed item to its updated value.

  • inputs – the object that should be traversed.

Returns

A new object with the updated values.

Dict utils

flax.traverse_util.flatten_dict(xs, keep_empty_nodes=False, is_leaf=None, sep=None)

Flatten a nested dictionary.

The nested keys are flattened to a tuple. See unflatten_dict for how to restore the nested dictionary structure.

Example:

from flax.traverse_util import flatten_dict

xs = {'foo': 1, 'bar': {'a': 2, 'b': {}}}
flat_xs = flatten_dict(xs)
print(flat_xs)
# {
#   ('foo',): 1,
#   ('bar', 'a'): 2,
# }

Note that empty dictionaries are dropped by default (see keep_empty_nodes) and will therefore not be restored by unflatten_dict.

Parameters
  • xs – a nested dictionary

  • keep_empty_nodes – replaces empty dictionaries with traverse_util.empty_node. This must be set to True for unflatten_dict to correctly restore empty dictionaries.

  • is_leaf – an optional function that takes the next nested dictionary and nested keys and returns True if the nested dictionary is a leaf (i.e., should not be flattened further).

  • sep – if specified, then the keys of the returned dictionary will be sep-joined strings (if None, then keys will be tuples).

Returns

The flattened dictionary.
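The flattening rules above (tuple keys, `sep`-joined string keys, and empty-dictionary handling controlled by `keep_empty_nodes`) can be sketched in a few lines of pure Python. The `flatten` function and its `EMPTY` sentinel are hypothetical stand-ins for flatten_dict and traverse_util.empty_node; the sketch omits `is_leaf` and FrozenDict support.

```python
from typing import Any, Dict, Optional, Tuple

EMPTY = object()  # Hypothetical stand-in for traverse_util.empty_node.


def flatten(xs: Dict[str, Any], keep_empty_nodes: bool = False,
            sep: Optional[str] = None) -> Dict[Any, Any]:
    def key(path: Tuple[str, ...]) -> Any:
        # Tuple keys by default; sep-joined strings if `sep` is given.
        return path if sep is None else sep.join(path)

    def walk(node: Any, prefix: Tuple[str, ...]) -> Dict[Any, Any]:
        if not isinstance(node, dict):
            return {key(prefix): node}
        if not node:
            # Empty sub-dicts are dropped unless kept; in this sketch a
            # top-level empty dict always flattens to {}.
            return {key(prefix): EMPTY} if keep_empty_nodes and prefix else {}
        result: Dict[Any, Any] = {}
        for k, v in node.items():
            result.update(walk(v, prefix + (k,)))
        return result

    return walk(xs, ())


xs = {'foo': 1, 'bar': {'a': 2, 'b': {}}}
print(flatten(xs))           # {('foo',): 1, ('bar', 'a'): 2}
print(flatten(xs, sep='/'))  # {'foo': 1, 'bar/a': 2}
print(flatten(xs, keep_empty_nodes=True)[('bar', 'b')] is EMPTY)  # True
```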

flax.traverse_util.unflatten_dict(xs, sep=None)

Unflatten a dictionary.

See flatten_dict.

Example:

from flax.traverse_util import unflatten_dict

flat_xs = {
  ('foo',): 1,
  ('bar', 'a'): 2,
}
xs = unflatten_dict(flat_xs)
print(xs)
# {
#   'foo': 1,
#   'bar': {'a': 2}
# }

Parameters
  • xs – a flattened dictionary

  • sep – separator (same as used with flatten_dict()).

Returns

The nested dictionary.
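Unflattening reverses the process by walking each key path and creating intermediate dictionaries as needed. The hypothetical `unflatten` below mirrors the behaviour described above for tuple keys and for a `sep`; it is a sketch and does not handle traverse_util.empty_node.

```python
from typing import Any, Dict, Optional


def unflatten(flat: Dict[Any, Any], sep: Optional[str] = None) -> Dict[str, Any]:
    result: Dict[str, Any] = {}
    for path, value in flat.items():
        if sep is not None:
            # String keys are split back into a path with the same separator.
            path = tuple(path.split(sep))
        node = result
        for k in path[:-1]:
            # Create intermediate dicts on the way down.
            node = node.setdefault(k, {})
        node[path[-1]] = value
    return result


flat_xs = {('foo',): 1, ('bar', 'a'): 2}
print(unflatten(flat_xs))                          # {'foo': 1, 'bar': {'a': 2}}
print(unflatten({'foo': 1, 'bar/a': 2}, sep='/'))  # {'foo': 1, 'bar': {'a': 2}}
```

With `sep`, the round trip is only unambiguous if no original key contains the separator itself, which is one reason tuple keys are the default.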

Model parameter traversal

class flax.traverse_util.ModelParamTraversal(*args, **kwargs)

Select model parameters using a name filter.

This traversal operates on a nested dictionary of parameters and selects a subset based on the filter_fn argument.

See flax.optim.MultiOptimizer for an example of how to use ModelParamTraversal to update subsets of the parameter tree with a specific optimizer.

__init__(filter_fn)

Constructs a new ModelParamTraversal.

Parameters

filter_fn – a function that takes a parameter’s full name and its value and returns whether this parameter should be selected or not. The name of a parameter is determined by the module hierarchy and the parameter name (for example: ‘/module/sub_module/parameter_name’).
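To illustrate how a name-based filter selects parameters, the sketch below builds '/'-prefixed names from the nesting of a parameter dictionary, following the '/module/sub_module/parameter_name' format described above, and keeps the entries for which `filter_fn(name, value)` returns True. `named_leaves` and `select_params` are hypothetical helpers, not the ModelParamTraversal implementation, and the parameter tree is made up for the example.

```python
from typing import Any, Callable, Dict, Iterator, Tuple


def named_leaves(params: Dict[str, Any], prefix: str = '') -> Iterator[Tuple[str, Any]]:
    # Yield ('/module/sub_module/parameter_name', value) pairs.
    for key, value in params.items():
        name = f'{prefix}/{key}'
        if isinstance(value, dict):
            yield from named_leaves(value, name)
        else:
            yield name, value


def select_params(filter_fn: Callable[[str, Any], bool],
                  params: Dict[str, Any]) -> Dict[str, Any]:
    # Keep only the parameters whose (name, value) pass the filter.
    return {name: value for name, value in named_leaves(params)
            if filter_fn(name, value)}


params = {
    'Dense_0': {'kernel': [[1.0]], 'bias': [0.0]},
    'Dense_1': {'kernel': [[2.0]], 'bias': [0.5]},
}
# Select only weight matrices, e.g. to apply weight decay to them alone.
kernels = select_params(lambda name, _: name.endswith('/kernel'), params)
print(sorted(kernels))  # ['/Dense_0/kernel', '/Dense_1/kernel']
```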