flax.cursor package

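The Cursor API provides a mutable-style interface for updating pytrees. You write changes as if the pytree were mutable, and the Cursor produces a new, updated copy while leaving the original untouched. This is a more ergonomic way to make partial updates to deeply nested immutable data structures than writing many nested dataclasses.replace calls.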

To illustrate, consider the example below:

from flax.cursor import cursor
import dataclasses
from typing import Any

@dataclasses.dataclass
class A:
  x: Any

a = A(A(A(A(A(A(A(0)))))))

To replace the int 0 with 1 using dataclasses.replace, we would have to write seven nested calls:

a2 = dataclasses.replace(
  a,
  x=dataclasses.replace(
    a.x,
    x=dataclasses.replace(
      a.x.x,
      x=dataclasses.replace(
        a.x.x.x,
        x=dataclasses.replace(
          a.x.x.x.x,
          x=dataclasses.replace(
            a.x.x.x.x.x,
            x=dataclasses.replace(a.x.x.x.x.x.x, x=1),
          ),
        ),
      ),
    ),
  ),
)

The equivalent can be achieved much more simply using the Cursor API:

a3 = cursor(a).x.x.x.x.x.x.x.set(1)
assert a2 == a3

The Cursor object keeps track of the changes made to it; when .build is called, it generates a new object with the accumulated changes applied. Basic usage therefore has three steps: wrap the object in a Cursor, make changes to the Cursor object, and generate a new copy of the original object with the accumulated changes.
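The nested dataclasses.replace chain above follows a simple recursive pattern. The sketch below writes that recursion out once, for dataclass attributes only, to show what a single cursor(...).set call saves you from writing. replace_at is a hypothetical helper for illustration, not part of flax.cursor; the Cursor examples in this document also cover dicts, lists and tuples, which this sketch does not.

```python
import dataclasses
from typing import Any, Sequence


@dataclasses.dataclass
class A:
  x: Any


def replace_at(obj: Any, path: Sequence[str], value: Any) -> Any:
  """Hypothetical helper: copy `obj` with the attribute at `path` set to `value`.

  Only dataclass attributes are supported. Every node along the path is
  copied with dataclasses.replace, so the original object is never mutated.
  """
  if not path:
    return value  # reached the target position: this is the new value
  head, *rest = path
  new_child = replace_at(getattr(obj, head), rest, value)
  return dataclasses.replace(obj, **{head: new_child})


a = A(A(A(A(A(A(A(0)))))))
a2 = replace_at(a, ['x'] * 7, 1)  # same result as the seven nested calls
```

Like the Cursor version, the helper rebuilds only the nodes on the path from the root to the changed value.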

flax.cursor.cursor(obj)

Wrap a Cursor around obj and return it. Changes can then be applied to the Cursor object in the following ways:

  • single-line change via the .set method

  • multiple changes, and then calling the .build method

  • multiple changes conditioned on the pytree path and node value via the .apply_update method, and then calling the .build method

.set example:

from flax.cursor import cursor

dict_obj = {'a': 1, 'b': (2, 3), 'c': [4, 5]}
modified_dict_obj = cursor(dict_obj)['b'][0].set(10)
assert modified_dict_obj == {'a': 1, 'b': (10, 3), 'c': [4, 5]}

.build example:

from flax.cursor import cursor

dict_obj = {'a': 1, 'b': (2, 3), 'c': [4, 5]}
c = cursor(dict_obj)
c['b'][0] = 10
c['a'] = (100, 200)
modified_dict_obj = c.build()
assert modified_dict_obj == {'a': (100, 200), 'b': (10, 3), 'c': [4, 5]}

.apply_update example:

from flax.cursor import cursor
from flax.training import train_state
import optax

def update_fn(path, value):
  '''Replace params with empty dictionary.'''
  if 'params' in path:
    return {}
  return value

state = train_state.TrainState.create(
    apply_fn=lambda x: x,
    params={'a': 1, 'b': 2},
    tx=optax.adam(1e-3),
)
c = cursor(state)
state2 = c.apply_update(update_fn).build()
assert state2.params == {}
assert state.params == {'a': 1, 'b': 2} # make sure original params are unchanged

If the underlying obj is a list or tuple, you can also iterate over the Cursor object to get its child Cursors:

from flax.cursor import cursor

c = cursor(((1, 2), (3, 4)))
for child_c in c:
  child_c[1] *= -1
assert c.build() == ((1, -2), (3, -4))

See the documentation of each method below for more usage examples.

Parameters

obj – the object to wrap in a Cursor

Returns

A Cursor object wrapped around obj.

class flax.cursor.Cursor(obj, parent_key)
apply_update(update_fn)

Traverse the Cursor object and record conditional changes recursively via an update_fn. The changes are recorded in the Cursor object’s ._changes dictionary. To generate a copy of the original object with the accumulated changes, call the .build method after calling .apply_update.

The update_fn has a function signature of (str, Any) -> Any:

  • The input arguments are the current key path (a string delimited by '/') and the value at that key path

  • The output is the new value (either modified by the update_fn, or the same as the input value if the condition wasn’t fulfilled)

NOTES:

  • If the update_fn returns a modified value, this method will not recurse any further down that branch to record changes. For example, if we intend to replace an attribute that points to a dictionary with an int, we don’t need to look for further changes inside the dictionary, since the dictionary will be replaced anyway.

  • The is operator is used to determine whether the return value is modified (by comparing it to the input value). Therefore, if the update_fn mutates a mutable container (e.g. a list or dict) in place and returns that same container, .apply_update will treat the returned value as unmodified, since it is the same object (same id). To avoid this, return a modified copy instead.

  • .apply_update WILL NOT call the update_fn on the value at the top-most level of the pytree (i.e. the root node). The update_fn will first be called on the root node’s children, and then the pytree traversal will continue recursively from there.
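The three notes above can be modelled with a short traversal over nested dicts. This is a simplified sketch under stated assumptions, not flax's implementation: apply_update_sketch is a hypothetical name, it handles plain nested dicts only, and it returns the new tree directly instead of recording changes on a Cursor. It shows the '/'-joined paths, the skipped root, the identity (is) test, and the absence of recursion below a modified node.

```python
from typing import Any, Callable, Dict


def apply_update_sketch(
    tree: Dict[str, Any],
    update_fn: Callable[[str, Any], Any],
    prefix: str = '',
) -> Dict[str, Any]:
  """Hypothetical, simplified model of .apply_update for nested dicts.

  The root is never passed to update_fn: traversal starts at its children,
  and each child's path is its chain of keys joined with '/'.
  """
  new_tree = dict(tree)  # shallow copy; the input tree is never mutated here
  for key, value in tree.items():
    path = f'{prefix}/{key}' if prefix else str(key)
    new_value = update_fn(path, value)
    if new_value is not value:
      # A different object counts as "modified": keep it and stop recursing.
      # A container mutated in place and returned would fail this test.
      new_tree[key] = new_value
    elif isinstance(value, dict):
      # Unmodified container: keep looking for changes further down.
      new_tree[key] = apply_update_sketch(value, update_fn, path)
  return new_tree


params = {
    'Dense_0': {'kernel': 1.0, 'bias': 0.0},
    'Dense_1': {'kernel': 2.0, 'bias': 0.5},
}
visited = []


def update_fn(path, value):
  visited.append(path)
  if path == 'Dense_1':
    return {**value, 'bias': 0.0}  # a new dict, so it is detected as a change
  if path.endswith('kernel'):
    return value * 2
  return value


new_params = apply_update_sketch(params, update_fn)
```

Because update_fn replaced Dense_1 as a whole, the traversal never visits Dense_1/kernel, so that kernel is not doubled; the root itself never appears in visited.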

Example:

import flax.linen as nn
from flax.cursor import cursor
import jax
import jax.numpy as jnp

class Model(nn.Module):
  @nn.compact
  def __call__(self, x):
    x = nn.Dense(3)(x)
    x = nn.relu(x)
    x = nn.Dense(3)(x)
    x = nn.relu(x)
    x = nn.Dense(3)(x)
    x = nn.relu(x)
    return x

params = Model().init(jax.random.key(0), jnp.empty((1, 2)))['params']

def update_fn(path, value):
  '''Multiply all dense kernel params by 2 and add 1.
  Subtract 1 from the Dense_1 bias param.'''
  if 'kernel' in path:
    return value * 2 + 1
  elif 'Dense_1' in path and 'bias' in path:
    return value - 1
  return value

c = cursor(params)
new_params = c.apply_update(update_fn).build()
for layer in ('Dense_0', 'Dense_1', 'Dense_2'):
  assert (new_params[layer]['kernel'] == 2 * params[layer]['kernel'] + 1).all()
  if layer == 'Dense_1':
    assert (new_params[layer]['bias'] == params[layer]['bias'] - 1).all()
  else:
    assert (new_params[layer]['bias'] == params[layer]['bias']).all()

# make sure original params are unchanged
assert jax.tree_util.tree_all(
    jax.tree_util.tree_map(
        lambda x, y: (x == y).all(),
        params,
        Model().init(jax.random.key(0), jnp.empty((1, 2)))['params'],
    )
)

Parameters

update_fn – the function that will conditionally record changes to the Cursor object

Returns

The current Cursor object with the recorded conditional changes specified by the update_fn. To generate a copy of the original object with the accumulated changes, call the .build method after calling .apply_update.

build()

Create and return a copy of the original object with accumulated changes. This method is to be called after making changes to the Cursor object.

NOTE: The new object is built bottom-up: the changes are first applied to the leaf nodes, then to their parents, and so on all the way up to the root.
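A minimal sketch of that bottom-up order, assuming the recorded changes are stored as a nested mapping from child keys to either a new value or further nested changes. Changes and build_sketch are hypothetical names; the sketch handles only dict, list and tuple nodes, whereas the Cursor examples in this document also rebuild dataclass-like objects such as TrainState.

```python
from typing import Any


class Changes(dict):
  """Hypothetical marker type: maps child keys to a new value or to nested Changes."""


def build_sketch(obj: Any, changes: Changes) -> Any:
  """Hypothetical, simplified bottom-up rebuild for dict, list and tuple nodes."""
  # 1. Build every changed child first; the recursion reaches the leaves
  #    before any parent is reconstructed.
  built = {
      key: build_sketch(obj[key], change) if isinstance(change, Changes) else change
      for key, change in changes.items()
  }
  # 2. Rebuild this node from its (possibly new) children.
  if isinstance(obj, tuple):
    return tuple(built.get(i, item) for i, item in enumerate(obj))
  if isinstance(obj, list):
    return [built.get(i, item) for i, item in enumerate(obj)]
  if isinstance(obj, dict):
    new_obj = dict(obj)  # shallow copy: unchanged children are shared
    new_obj.update(built)
    return new_obj
  raise TypeError(f'unsupported node type: {type(obj).__name__}')


original = {'a': 1, 'b': (2, 3), 'c': [4, 5]}
# Mirrors c['b'][0] = 10 and c['a'] = (100, 200) from the example below.
changes = Changes({'b': Changes({0: 10}), 'a': (100, 200)})
result = build_sketch(original, changes)
```

In this sketch, subtrees without recorded changes (here the list under 'c') are reused rather than copied, so only the nodes on a changed path are newly created.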

Example:

from flax.cursor import cursor
from flax.training import train_state
import optax

dict_obj = {'a': 1, 'b': (2, 3), 'c': [4, 5]}
c = cursor(dict_obj)
c['b'][0] = 10
c['a'] = (100, 200)
modified_dict_obj = c.build()
assert modified_dict_obj == {'a': (100, 200), 'b': (10, 3), 'c': [4, 5]}

state = train_state.TrainState.create(
    apply_fn=lambda x: x,
    params=dict_obj,
    tx=optax.adam(1e-3),
)
new_fn = lambda x: x + 1
c = cursor(state)
c.params['b'][1] = 10
c.apply_fn = new_fn
modified_state = c.build()
assert modified_state.params == {'a': 1, 'b': (2, 10), 'c': [4, 5]}
assert modified_state.apply_fn == new_fn

Returns

A copy of the original object with the accumulated changes.

find(cond_fn)

Traverse the Cursor object and return the child Cursor object that fulfills the condition in cond_fn. The cond_fn has a function signature of (str, Any) -> bool:

  • The input arguments are the current key path (a string delimited by '/') and the value at that key path

  • The output is a boolean, denoting whether to return the child Cursor object at this path

Raises a CursorFindError if no object, or more than one object, fulfills the condition in cond_fn. An error is raised in both cases because .find is meant to return the only object whose key path and value fulfill the condition; use .find_all to retrieve multiple matches.

NOTES:

  • If the cond_fn evaluates to True at a particular key path, this method will not recurse any further down that branch; in other words, it finds and returns the “earliest” child node along a key path that fulfills the condition in cond_fn

  • .find WILL NOT search the value at the top-most level of the pytree (i.e. the root node). The cond_fn will be evaluated recursively, starting at the root node’s children.

Example:

import flax.linen as nn
from flax.cursor import cursor
import jax
import jax.numpy as jnp

class Model(nn.Module):
  @nn.compact
  def __call__(self, x):
    x = nn.Dense(3)(x)
    x = nn.relu(x)
    x = nn.Dense(3)(x)
    x = nn.relu(x)
    x = nn.Dense(3)(x)
    x = nn.relu(x)
    return x

params = Model().init(jax.random.PRNGKey(0), jnp.empty((1, 2)))['params']

def cond_fn(path, value):
  '''Find the second dense layer params.'''
  return 'Dense_1' in path

new_params = cursor(params).find(cond_fn)['bias'].set(params['Dense_1']['bias'] + 1)

for layer in ('Dense_0', 'Dense_1', 'Dense_2'):
  if layer == 'Dense_1':
    assert (new_params[layer]['bias'] == params[layer]['bias'] + 1).all()
  else:
    assert (new_params[layer]['bias'] == params[layer]['bias']).all()

c = cursor(params)
c2 = c.find(cond_fn)
c2['kernel'] += 2
c2['bias'] += 2
new_params = c.build()

for layer in ('Dense_0', 'Dense_1', 'Dense_2'):
  if layer == 'Dense_1':
    assert (new_params[layer]['kernel'] == params[layer]['kernel'] + 2).all()
    assert (new_params[layer]['bias'] == params[layer]['bias'] + 2).all()
  else:
    assert (new_params[layer]['kernel'] == params[layer]['kernel']).all()
    assert (new_params[layer]['bias'] == params[layer]['bias']).all()

# make sure original params are unchanged
assert jax.tree_util.tree_all(
    jax.tree_util.tree_map(
        lambda x, y: (x == y).all(),
        params,
        Model().init(jax.random.PRNGKey(0), jnp.empty((1, 2)))['params'],
    )
)

Parameters

cond_fn – the function that will conditionally find child Cursor objects

Returns

A child Cursor object that fulfills the condition in the cond_fn.

find_all(cond_fn)

Traverse the Cursor object and return a generator of child Cursor objects that fulfill the condition in cond_fn. The cond_fn has a function signature of (str, Any) -> bool:

  • The input arguments are the current key path (a string delimited by '/') and the value at that key path

  • The output is a boolean, denoting whether to return the child Cursor object at this path

NOTES:

  • If the cond_fn evaluates to True at a particular key path, this method will not recurse any further down that branch; in other words, it finds and returns the “earliest” child nodes along each key path that fulfill the condition in cond_fn

  • .find_all WILL NOT search the value at the top-most level of the pytree (i.e. the root node). The cond_fn will be evaluated recursively, starting at the root node’s children.

Example:

import flax.linen as nn
from flax.cursor import cursor
import jax
import jax.numpy as jnp

class Model(nn.Module):
  @nn.compact
  def __call__(self, x):
    x = nn.Dense(3)(x)
    x = nn.relu(x)
    x = nn.Dense(3)(x)
    x = nn.relu(x)
    x = nn.Dense(3)(x)
    x = nn.relu(x)
    return x

params = Model().init(jax.random.PRNGKey(0), jnp.empty((1, 2)))['params']

def cond_fn(path, value):
  '''Find all dense layer params.'''
  return 'Dense' in path

c = cursor(params)
for dense_params in c.find_all(cond_fn):
  dense_params['bias'] += 1
new_params = c.build()

for layer in ('Dense_0', 'Dense_1', 'Dense_2'):
  assert (new_params[layer]['bias'] == params[layer]['bias'] + 1).all()

# make sure original params are unchanged
assert jax.tree_util.tree_all(
    jax.tree_util.tree_map(
        lambda x, y: (x == y).all(),
        params,
        Model().init(jax.random.PRNGKey(0), jnp.empty((1, 2)))['params'],
    )
)

Parameters

cond_fn – the function that will conditionally find child Cursor objects

Returns

A generator of child Cursor objects that fulfill the condition in the cond_fn.

set(value)

Set a new value for an attribute, property, element or entry in the Cursor object, and return a copy of the original (root) object containing the new value.

Example:

from flax.cursor import cursor
from flax.training import train_state
import optax

dict_obj = {'a': 1, 'b': (2, 3), 'c': [4, 5]}
modified_dict_obj = cursor(dict_obj)['b'][0].set(10)
assert modified_dict_obj == {'a': 1, 'b': (10, 3), 'c': [4, 5]}

state = train_state.TrainState.create(
    apply_fn=lambda x: x,
    params=dict_obj,
    tx=optax.adam(1e-3),
)
modified_state = cursor(state).params['b'][1].set(10)
assert modified_state.params == {'a': 1, 'b': (2, 10), 'c': [4, 5]}

Parameters

value – the value used to set an attribute, property, element or entry in the Cursor object

Returns

A copy of the original object with the new set value.