flax.cursor package

The Cursor API provides a mutable-style interface for updating pytrees. Compared to writing many nested dataclasses.replace calls, it is a more ergonomic way to make partial updates to deeply nested immutable data structures.

To illustrate, consider the example below:

>>> from flax.cursor import cursor
>>> import dataclasses
>>> from typing import Any

>>> @dataclasses.dataclass(frozen=True)
... class A:
...   x: Any

>>> a = A(A(A(A(A(A(A(0)))))))

To replace the int 0 using dataclasses.replace, we would have to write many nested calls:

>>> a2 = dataclasses.replace(
...   a,
...   x=dataclasses.replace(
...     a.x,
...     x=dataclasses.replace(
...       a.x.x,
...       x=dataclasses.replace(
...         a.x.x.x,
...         x=dataclasses.replace(
...           a.x.x.x.x,
...           x=dataclasses.replace(
...             a.x.x.x.x.x,
...             x=dataclasses.replace(a.x.x.x.x.x.x, x=1),
...           ),
...         ),
...       ),
...     ),
...   ),
... )

The equivalent can be achieved much more simply using the Cursor API:

>>> a3 = cursor(a).x.x.x.x.x.x.x.set(1)
>>> assert a2 == a3
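
For comparison, the nested calls above follow a fixed pattern that a short recursive helper can generate. The sketch below is illustrative only: `replace_at` is a hypothetical name, not part of Flax, and it handles nothing but dataclass attributes. The Cursor API covers this case and also dict keys and sequence indices, without a hand-written helper.

```python
import dataclasses
from typing import Any, Sequence


@dataclasses.dataclass(frozen=True)
class A:
  x: Any


def replace_at(obj: Any, attrs: Sequence[str], value: Any) -> Any:
  """Hypothetical helper: copy `obj` with the attribute chain `attrs` set to `value`.

  Every level is copied with dataclasses.replace, so `obj` itself is never mutated.
  """
  if not attrs:
    return value
  head, *rest = attrs
  child = getattr(obj, head)
  return dataclasses.replace(obj, **{head: replace_at(child, rest, value)})


a = A(A(A(A(A(A(A(0)))))))
a2 = replace_at(a, ['x'] * 7, 1)  # same result as the seven nested replace calls
assert a2.x.x.x.x.x.x.x == 1
assert a.x.x.x.x.x.x.x == 0  # the original object is unchanged
```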

The Cursor object keeps track of changes made to it, and when .build is called, it generates a new object with the accumulated changes. Basic usage involves wrapping the object in a Cursor, making changes to the Cursor object, and generating a new copy of the original object with the accumulated changes.
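
The record-then-build idea can be sketched in plain Python. `MiniCursor` below is a hypothetical, heavily simplified stand-in, not Flax's implementation: it supports only item access on dicts, lists and tuples, stores child cursors and assigned values in a `_changes` dict, and copies nothing until `build` is called.

```python
from typing import Any, Dict


class MiniCursor:
  """Hypothetical, simplified cursor over dicts, lists and tuples."""

  def __init__(self, obj: Any):
    self._obj = obj
    self._changes: Dict[Any, Any] = {}  # key -> child MiniCursor or new value

  def __getitem__(self, key):
    # Reuse an existing child cursor so repeated access accumulates changes.
    if key not in self._changes:
      self._changes[key] = MiniCursor(self._obj[key])
    return self._changes[key]

  def __setitem__(self, key, value):
    # Only record the new value; the wrapped object is left untouched.
    self._changes[key] = value

  def build(self) -> Any:
    # Build child cursors first, then copy this level with the results.
    resolved = {
        k: v.build() if isinstance(v, MiniCursor) else v
        for k, v in self._changes.items()
    }
    if isinstance(self._obj, dict):
      return {**self._obj, **resolved}
    items = list(self._obj)
    for k, v in resolved.items():
      items[k] = v
    return type(self._obj)(items)


dict_obj = {'a': 1, 'b': (2, 3), 'c': [4, 5]}
c = MiniCursor(dict_obj)
c['b'][0] = 10
c['a'] = (100, 200)
assert c.build() == {'a': (100, 200), 'b': (10, 3), 'c': [4, 5]}
assert dict_obj == {'a': 1, 'b': (2, 3), 'c': [4, 5]}  # original unchanged
```

Because copying happens only inside build, the wrapped dict_obj is never modified, which mirrors the behavior described above.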

flax.cursor.cursor(obj)

Wrap a Cursor around obj and return it. Changes can then be applied to the Cursor object in the following ways:

  • single-line change via the .set method

  • multiple changes, and then calling the .build method

  • multiple changes conditioned on the pytree path and node value via the .apply_update method, and then calling the .build method

.set example:

>>> from flax.cursor import cursor

>>> dict_obj = {'a': 1, 'b': (2, 3), 'c': [4, 5]}
>>> modified_dict_obj = cursor(dict_obj)['b'][0].set(10)
>>> assert modified_dict_obj == {'a': 1, 'b': (10, 3), 'c': [4, 5]}

.build example:

>>> from flax.cursor import cursor

>>> dict_obj = {'a': 1, 'b': (2, 3), 'c': [4, 5]}
>>> c = cursor(dict_obj)
>>> c['b'][0] = 10
>>> c['a'] = (100, 200)
>>> modified_dict_obj = c.build()
>>> assert modified_dict_obj == {'a': (100, 200), 'b': (10, 3), 'c': [4, 5]}

.apply_update example:

>>> from flax.cursor import cursor
>>> from flax.training import train_state
>>> import optax

>>> def update_fn(path, value):
...   '''Replace params with empty dictionary.'''
...   if 'params' in path:
...     return {}
...   return value

>>> state = train_state.TrainState.create(
...     apply_fn=lambda x: x,
...     params={'a': 1, 'b': 2},
...     tx=optax.adam(1e-3),
... )
>>> c = cursor(state)
>>> state2 = c.apply_update(update_fn).build()
>>> assert state2.params == {}
>>> assert state.params == {'a': 1, 'b': 2} # make sure original params are unchanged

If the underlying obj is a list or tuple, you can also iterate over the Cursor object to get its child Cursors:

>>> from flax.cursor import cursor

>>> c = cursor(((1, 2), (3, 4)))
>>> for child_c in c:
...   child_c[1] *= -1
>>> assert c.build() == ((1, -2), (3, -4))

View the docstrings for each method to see more examples of their usage.

Parameters

obj – the object to wrap in a Cursor

Returns

A Cursor object wrapped around obj.

class flax.cursor.Cursor(obj, parent_key)
apply_update(update_fn)

Traverse the Cursor object and record conditional changes recursively via an update_fn. The changes are recorded in the Cursor object’s ._changes dictionary. To generate a copy of the original object with the accumulated changes, call the .build method after calling .apply_update.

The update_fn has a function signature of (str, Any) -> Any:

  • The input arguments are the current key path (in the form of a string delimited by '/') and value at that current key path

  • The output is the new value (either modified by the update_fn, or the unchanged input value if the condition was not met)

NOTES:

  • If the update_fn returns a modified value, this method will not recurse any further down that branch to record changes. For example, if we intend to replace an attribute that points to a dictionary with an int, there is no need to look for further changes inside the dictionary, since the dictionary will be replaced anyway.

  • The is operator is used to determine whether the return value is modified (by comparing it to the input value). Therefore, if the update_fn modifies a mutable container (e.g. lists, dicts, etc.) in place and returns that same container, .apply_update will treat the returned value as unmodified, because it has the same id. To avoid this, return a modified copy instead.

  • .apply_update WILL NOT call the update_fn on the value at the top-most level of the pytree (i.e. the root node). The update_fn is first called on the root node’s children, and the pytree traversal then continues recursively from there.
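
The three rules above can be illustrated with a small plain-Python sketch over nested dicts. `collect_updates` and `UpdateFn` are hypothetical names, not Flax APIs; the sketch returns a map from '/'-delimited paths to new values instead of recording into a Cursor, and it only descends into dicts.

```python
from typing import Any, Callable, Dict

UpdateFn = Callable[[str, Any], Any]  # (key path, value) -> new value


def collect_updates(tree: Dict[str, Any], update_fn: UpdateFn,
                    prefix: str = '') -> Dict[str, Any]:
  """Hypothetical sketch of the traversal rules; the root is never passed to update_fn."""
  changes = {}
  for key, value in tree.items():
    path = f'{prefix}/{key}' if prefix else str(key)
    new_value = update_fn(path, value)
    if new_value is not value:
      changes[path] = new_value  # modified: record it and stop descending here
    elif isinstance(value, dict):
      changes.update(collect_updates(value, update_fn, path))  # keep searching
  return changes


def make_params():
  return {'Dense_0': {'kernel': [1, 2], 'bias': [0]},
          'Dense_1': {'kernel': [3, 4], 'bias': [0]}}


visited = []

def drop_dense_1(path, value):
  visited.append(path)
  return {} if path == 'Dense_1' else value

assert collect_updates(make_params(), drop_dense_1) == {'Dense_1': {}}
# The root is never visited, and nothing below 'Dense_1' is visited.
assert visited == ['Dense_0', 'Dense_0/kernel', 'Dense_0/bias', 'Dense_1']


def append_in_place(path, value):
  if path.endswith('bias'):
    value.append(1)  # mutates the existing list...
  return value       # ...and returns the same object, so `is` sees no change

def append_copy(path, value):
  if path.endswith('bias'):
    return value + [1]  # a new list object, so the change is recorded
  return value

assert collect_updates(make_params(), append_in_place) == {}
assert collect_updates(make_params(), append_copy) == {
    'Dense_0/bias': [0, 1], 'Dense_1/bias': [0, 1]}
```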

Example:

>>> import flax.linen as nn
>>> from flax.cursor import cursor
>>> import jax, jax.numpy as jnp

>>> class Model(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     x = nn.Dense(3)(x)
...     x = nn.relu(x)
...     x = nn.Dense(3)(x)
...     x = nn.relu(x)
...     x = nn.Dense(3)(x)
...     x = nn.relu(x)
...     return x

>>> params = Model().init(jax.random.key(0), jnp.empty((1, 2)))['params']

>>> def update_fn(path, value):
...   '''Multiply all dense kernel params by 2 and add 1.
...   Subtract 1 from the Dense_1 bias param.'''
...   if 'kernel' in path:
...     return value * 2 + 1
...   elif 'Dense_1' in path and 'bias' in path:
...     return value - 1
...   return value

>>> c = cursor(params)
>>> new_params = c.apply_update(update_fn).build()
>>> for layer in ('Dense_0', 'Dense_1', 'Dense_2'):
...   assert (new_params[layer]['kernel'] == 2 * params[layer]['kernel'] + 1).all()
...   if layer == 'Dense_1':
...     assert (new_params[layer]['bias'] == params[layer]['bias'] - 1).all()
...   else:
...     assert (new_params[layer]['bias'] == params[layer]['bias']).all()

>>> assert jax.tree_util.tree_all(
...       jax.tree_util.tree_map(
...           lambda x, y: (x == y).all(),
...           params,
...           Model().init(jax.random.key(0), jnp.empty((1, 2)))[
...               'params'
...           ],
...       )
...   ) # make sure original params are unchanged

Parameters

update_fn – the function that will conditionally record changes to the Cursor object

Returns

The current Cursor object with the recorded conditional changes specified by the update_fn. To generate a copy of the original object with the accumulated changes, call the .build method after calling .apply_update.

build()

Create and return a copy of the original object with accumulated changes. This method is to be called after making changes to the Cursor object.

NOTE: The new object is built bottom-up: the changes are first applied to the leaf nodes, then to their parents, and so on all the way up to the root.
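
To make the bottom-up order concrete, here is a minimal sketch for trees of nested tuples. `rebuild` is a hypothetical helper, not Flax's build implementation; for simplicity it copies every tuple, including untouched ones, and logs each tuple's index path once all of its children are done.

```python
from typing import Any, Dict, List, Tuple

Path = Tuple[int, ...]


def rebuild(node: Any, changes: Dict[Path, Any], log: List[Path],
            path: Path = ()) -> Any:
  """Hypothetical sketch: return a copy of a nested-tuple tree with `changes` applied."""
  if path in changes:
    return changes[path]  # a recorded change replaces this whole subtree
  if not isinstance(node, tuple):
    return node  # unchanged leaf
  children = tuple(
      rebuild(child, changes, log, path + (i,))
      for i, child in enumerate(node))
  log.append(path)  # reached only after every child has been rebuilt
  return children


tree = ((1, 2), (3, (4, 5)))
log: List[Path] = []
new_tree = rebuild(tree, {(1, 1, 0): 40}, log)
assert new_tree == ((1, 2), (3, (40, 5)))
assert tree == ((1, 2), (3, (4, 5)))  # the original is unchanged
assert log == [(0,), (1, 1), (1,), ()]  # innermost tuples first, root () last
```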

Example:

>>> from flax.cursor import cursor
>>> from flax.training import train_state
>>> import optax

>>> dict_obj = {'a': 1, 'b': (2, 3), 'c': [4, 5]}
>>> c = cursor(dict_obj)
>>> c['b'][0] = 10
>>> c['a'] = (100, 200)
>>> modified_dict_obj = c.build()
>>> assert modified_dict_obj == {'a': (100, 200), 'b': (10, 3), 'c': [4, 5]}

>>> state = train_state.TrainState.create(
...     apply_fn=lambda x: x,
...     params=dict_obj,
...     tx=optax.adam(1e-3),
... )
>>> new_fn = lambda x: x + 1
>>> c = cursor(state)
>>> c.params['b'][1] = 10
>>> c.apply_fn = new_fn
>>> modified_state = c.build()
>>> assert modified_state.params == {'a': 1, 'b': (2, 10), 'c': [4, 5]}
>>> assert modified_state.apply_fn == new_fn

Returns

A copy of the original object with the accumulated changes.

find(cond_fn)

Traverse the Cursor object and return the child Cursor object that fulfills the condition in the cond_fn. The cond_fn has a function signature of (str, Any) -> bool:

  • The input arguments are the current key path (in the form of a string delimited by '/') and value at that current key path

  • The output is a boolean, denoting whether to return the child Cursor object at this path

Raises a CursorFindError if no object, or more than one object, fulfills the condition of the cond_fn. An error is raised because this method is expected to return the single object whose key path and value fulfill the condition of the cond_fn.

NOTES:

  • If the cond_fn evaluates to True at a particular key path, this method will not recurse any further down that branch; i.e. this method will find and return the “earliest” child node that fulfills the condition in cond_fn along a particular key path.

  • .find WILL NOT search the value at the top-most level of the pytree (i.e. the root node). The cond_fn will be evaluated recursively, starting at the root node’s children.
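
The search rules and the exactly-one-match requirement can be sketched over plain nested dicts. `find_unique`, `_matches` and `FindError` are hypothetical names; `FindError` merely stands in for CursorFindError, and the sketch returns (path, value) pairs rather than child Cursors.

```python
from typing import Any, Callable, Dict, Iterator, Tuple

CondFn = Callable[[str, Any], bool]


class FindError(Exception):
  """Hypothetical stand-in for CursorFindError."""


def _matches(tree: Dict[str, Any], cond_fn: CondFn,
             prefix: str = '') -> Iterator[Tuple[str, Any]]:
  # Visit the root's children (never the root), building '/'-delimited paths.
  for key, value in tree.items():
    path = f'{prefix}/{key}' if prefix else str(key)
    if cond_fn(path, value):
      yield path, value  # earliest match on this branch: stop descending
    elif isinstance(value, dict):
      yield from _matches(value, cond_fn, path)


def find_unique(tree: Dict[str, Any], cond_fn: CondFn) -> Tuple[str, Any]:
  found = list(_matches(tree, cond_fn))
  if len(found) != 1:
    raise FindError(f'expected exactly one match, found {len(found)}')
  return found[0]


params = {'Dense_0': {'kernel': 1, 'bias': 0},
          'Dense_1': {'kernel': 2, 'bias': 0}}

# 'Dense_1' matches, so its children are never passed to cond_fn.
path, value = find_unique(params, lambda p, v: 'Dense_1' in p)
assert path == 'Dense_1' and value is params['Dense_1']


def raises_find_error(cond_fn: CondFn) -> bool:
  try:
    find_unique(params, cond_fn)
  except FindError:
    return True
  return False


assert raises_find_error(lambda p, v: 'Dense' in p)  # two matches
assert raises_find_error(lambda p, v: 'Conv' in p)   # no match
```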

Example:

>>> import flax.linen as nn
>>> from flax.cursor import cursor
>>> import jax, jax.numpy as jnp

>>> class Model(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     x = nn.Dense(3)(x)
...     x = nn.relu(x)
...     x = nn.Dense(3)(x)
...     x = nn.relu(x)
...     x = nn.Dense(3)(x)
...     x = nn.relu(x)
...     return x

>>> params = Model().init(jax.random.PRNGKey(0), jnp.empty((1, 2)))['params']

>>> def cond_fn(path, value):
...   '''Find the second dense layer params.'''
...   return 'Dense_1' in path

>>> new_params = cursor(params).find(cond_fn)['bias'].set(params['Dense_1']['bias'] + 1)

>>> for layer in ('Dense_0', 'Dense_1', 'Dense_2'):
...   if layer == 'Dense_1':
...     assert (new_params[layer]['bias'] == params[layer]['bias'] + 1).all()
...   else:
...     assert (new_params[layer]['bias'] == params[layer]['bias']).all()

>>> c = cursor(params)
>>> c2 = c.find(cond_fn)
>>> c2['kernel'] += 2
>>> c2['bias'] += 2
>>> new_params = c.build()

>>> for layer in ('Dense_0', 'Dense_1', 'Dense_2'):
...   if layer == 'Dense_1':
...     assert (new_params[layer]['kernel'] == params[layer]['kernel'] + 2).all()
...     assert (new_params[layer]['bias'] == params[layer]['bias'] + 2).all()
...   else:
...     assert (new_params[layer]['kernel'] == params[layer]['kernel']).all()
...     assert (new_params[layer]['bias'] == params[layer]['bias']).all()

>>> assert jax.tree_util.tree_all(
...       jax.tree_util.tree_map(
...           lambda x, y: (x == y).all(),
...           params,
...           Model().init(jax.random.PRNGKey(0), jnp.empty((1, 2)))[
...               'params'
...           ],
...       )
...   ) # make sure original params are unchanged

Parameters

cond_fn – the function that will conditionally find child Cursor objects

Returns

A child Cursor object that fulfills the condition in the cond_fn.

find_all(cond_fn)

Traverse the Cursor object and return a generator of child Cursor objects that fulfill the condition in the cond_fn. The cond_fn has a function signature of (str, Any) -> bool:

  • The input arguments are the current key path (in the form of a string delimited by '/') and value at that current key path

  • The output is a boolean, denoting whether to return the child Cursor object at this path

NOTES:

  • If the cond_fn evaluates to True at a particular key path, this method will not recurse any further down that branch; i.e. this method will find and return the “earliest” child nodes that fulfill the condition in cond_fn along a particular key path.

  • .find_all WILL NOT search the value at the top-most level of the pytree (i.e. the root node). The cond_fn will be evaluated recursively, starting at the root node’s children.
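
The "earliest match" rule is easiest to see when one condition can match at two depths. In this plain-Python sketch (`find_all_paths` is a hypothetical generator over nested dicts, yielding (path, value) pairs instead of child Cursors), 'Dense_1' matches before its own 'bias' child is ever reached, while under 'Dense_0' the search continues down to 'bias'.

```python
import types
from typing import Any, Callable, Dict, Iterator, Tuple


def find_all_paths(tree: Dict[str, Any], cond_fn: Callable[[str, Any], bool],
                   prefix: str = '') -> Iterator[Tuple[str, Any]]:
  """Hypothetical sketch of the find_all traversal over nested dicts."""
  for key, value in tree.items():
    path = f'{prefix}/{key}' if prefix else str(key)
    if cond_fn(path, value):
      yield path, value  # earliest match on this branch: do not descend
    elif isinstance(value, dict):
      yield from find_all_paths(value, cond_fn, path)


params = {'Dense_0': {'kernel': 1, 'bias': 0},
          'Dense_1': {'kernel': 2, 'bias': 0}}

matches = find_all_paths(params, lambda p, v: p == 'Dense_1' or p.endswith('bias'))
assert isinstance(matches, types.GeneratorType)  # results are produced lazily
assert [p for p, _ in matches] == ['Dense_0/bias', 'Dense_1']
```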

Example:

>>> import flax.linen as nn
>>> from flax.cursor import cursor
>>> import jax, jax.numpy as jnp

>>> class Model(nn.Module):
...   @nn.compact
...   def __call__(self, x):
...     x = nn.Dense(3)(x)
...     x = nn.relu(x)
...     x = nn.Dense(3)(x)
...     x = nn.relu(x)
...     x = nn.Dense(3)(x)
...     x = nn.relu(x)
...     return x

>>> params = Model().init(jax.random.PRNGKey(0), jnp.empty((1, 2)))['params']

>>> def cond_fn(path, value):
...   '''Find all dense layer params.'''
...   return 'Dense' in path

>>> c = cursor(params)
>>> for dense_params in c.find_all(cond_fn):
...   dense_params['bias'] += 1
>>> new_params = c.build()

>>> for layer in ('Dense_0', 'Dense_1', 'Dense_2'):
...   assert (new_params[layer]['bias'] == params[layer]['bias'] + 1).all()

>>> assert jax.tree_util.tree_all(
...       jax.tree_util.tree_map(
...           lambda x, y: (x == y).all(),
...           params,
...           Model().init(jax.random.PRNGKey(0), jnp.empty((1, 2)))[
...               'params'
...           ],
...       )
...   ) # make sure original params are unchanged

Parameters

cond_fn – the function that will conditionally find child Cursor objects

Returns

A generator of child Cursor objects that fulfill the condition in the cond_fn.

set(value)

Set a new value for an attribute, property, element or entry in the Cursor object and return a copy of the original object containing the new value.
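
Conceptually, .set records a single change and builds immediately. The effect of one such update on dicts, lists and tuples can be sketched with a small recursive helper; `set_in` is a hypothetical name, not part of Flax, and the object sharing checked at the end is a property of this sketch, not a documented guarantee of the Cursor API.

```python
from typing import Any, Sequence


def set_in(obj: Any, keys: Sequence[Any], value: Any) -> Any:
  """Hypothetical helper: copy `obj` with obj[k0][k1]...[kn] replaced by `value`."""
  if not keys:
    return value
  head, *rest = keys
  new_child = set_in(obj[head], rest, value)
  if isinstance(obj, dict):
    return {**obj, head: new_child}
  items = list(obj)  # works for both lists and tuples
  items[head] = new_child
  return type(obj)(items)


dict_obj = {'a': 1, 'b': (2, 3), 'c': [4, 5]}
modified = set_in(dict_obj, ['b', 0], 10)
assert modified == {'a': 1, 'b': (10, 3), 'c': [4, 5]}
assert dict_obj['b'] == (2, 3)         # the original is untouched
assert modified['c'] is dict_obj['c']  # in this sketch, untouched branches are shared
```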

Example:

>>> from flax.cursor import cursor
>>> from flax.training import train_state
>>> import optax

>>> dict_obj = {'a': 1, 'b': (2, 3), 'c': [4, 5]}
>>> modified_dict_obj = cursor(dict_obj)['b'][0].set(10)
>>> assert modified_dict_obj == {'a': 1, 'b': (10, 3), 'c': [4, 5]}

>>> state = train_state.TrainState.create(
...     apply_fn=lambda x: x,
...     params=dict_obj,
...     tx=optax.adam(1e-3),
... )
>>> modified_state = cursor(state).params['b'][1].set(10)
>>> assert modified_state.params == {'a': 1, 'b': (2, 10), 'c': [4, 5]}

Parameters

value – the value used to set an attribute, property, element or entry in the Cursor object

Returns

A copy of the original object with the new set value.