# Copyright 2024 The Flax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import dataclasses
import enum
from typing import (
Any,
Callable,
Dict,
Generator,
Generic,
Mapping,
Optional,
Protocol,
TypeVar,
runtime_checkable,
)
from flax.core import FrozenDict
from flax.errors import CursorFindError, TraverseTreeError
A = TypeVar('A')
Key = Any
@runtime_checkable
class Indexable(Protocol):
def __getitem__(self, key) -> Any:
...
class AccessType(enum.Enum):
ITEM = enum.auto()
ATTR = enum.auto()
@dataclasses.dataclass
class ParentKey(Generic[A]):
parent: 'Cursor[A]'
key: Key
access_type: AccessType
def is_named_tuple(obj):
return (
isinstance(obj, tuple)
and hasattr(obj, '_fields')
and hasattr(obj, '_asdict')
and hasattr(obj, '_replace')
)
def _traverse_tree(path, obj, *, update_fn=None, cond_fn=None):
"""Helper function for ``Cursor.apply_update`` and ``Cursor.find_all``.
Exactly one of ``update_fn`` and ``cond_fn`` must be not None.
- If ``update_fn`` is not None, then ``Cursor.apply_update`` is calling
this function and ``_traverse_tree`` will return a generator where
each generated element is of type Tuple[Tuple[Union[str, int], AccessType], Any].
The first element is a tuple of the key path and access type where the
change was applied from the ``update_fn``, and the second element is
the newly modified value. If the generator is non-empty, then the
tuple key path will always be non-empty as well.
- If ``cond_fn`` is not None, then ``Cursor.find_all`` is calling this
function and ``_traverse_tree`` will return a generator where each
generated element is of type Tuple[Union[str, int], AccessType]. The
tuple contains the key path and access type where the object was found
that fulfilled the conditions of the ``cond_fn``.
"""
if not (bool(update_fn) ^ bool(cond_fn)):
raise TraverseTreeError(update_fn, cond_fn)
if path:
str_path = '/'.join(str(key) for key, _ in path)
if update_fn:
new_obj = update_fn(str_path, obj)
if new_obj is not obj:
yield path, new_obj
return
elif cond_fn(str_path, obj): # type: ignore
yield path
return
if isinstance(obj, (FrozenDict, dict)):
items = obj.items()
access_type = AccessType.ITEM
elif is_named_tuple(obj):
items = ((name, getattr(obj, name)) for name in obj._fields) # type: ignore
access_type = AccessType.ATTR
elif isinstance(obj, (list, tuple)):
items = enumerate(obj)
access_type = AccessType.ITEM
elif dataclasses.is_dataclass(obj):
items = (
(f.name, getattr(obj, f.name)) for f in dataclasses.fields(obj) if f.init
)
access_type = AccessType.ATTR
else:
return
if update_fn:
for key, value in items:
yield from _traverse_tree(
path + ((key, access_type),), value, update_fn=update_fn
)
else:
for key, value in items:
yield from _traverse_tree(
path + ((key, access_type),), value, cond_fn=cond_fn
)
[docs]class Cursor(Generic[A]):
_obj: A
_parent_key: Optional[ParentKey[A]]
_changes: Dict[Any, 'Cursor[A]']
def __init__(self, obj: A, parent_key: Optional[ParentKey[A]]):
# NOTE: we use `vars` here to avoid calling `__setattr__`
# vars(self) = self.__dict__
vars(self)['_obj'] = obj
vars(self)['_parent_key'] = parent_key
vars(self)['_changes'] = {}
@property
def _root(self) -> 'Cursor[A]':
if self._parent_key is None:
return self
else:
return self._parent_key.parent._root # type: ignore
@property
def _path(self) -> str:
if self._parent_key is None:
return ''
if self._parent_key.access_type == AccessType.ITEM: # type: ignore
if isinstance(self._parent_key.key, str): # type: ignore
key = "'" + self._parent_key.key + "'" # type: ignore
else:
key = str(self._parent_key.key) # type: ignore
return self._parent_key.parent._path + '[' + key + ']' # type: ignore
# self.parent_key.access_type == AccessType.ATTR:
return self._parent_key.parent._path + '.' + self._parent_key.key # type: ignore
def __getitem__(self, key) -> 'Cursor[A]':
if key in self._changes:
return self._changes[key]
if not isinstance(self._obj, Indexable):
raise TypeError(f'Cannot index into {self._obj}')
if isinstance(self._obj, Mapping) and key not in self._obj:
raise KeyError(f'Key {key} not found in {self._obj}')
if is_named_tuple(self._obj):
return getattr(self, self._obj._fields[key]) # type: ignore
child = Cursor(self._obj[key], ParentKey(self, key, AccessType.ITEM))
self._changes[key] = child
return child
def __getattr__(self, name) -> 'Cursor[A]':
if name in self._changes:
return self._changes[name]
if not hasattr(self._obj, name):
raise AttributeError(f'Attribute {name} not found in {self._obj}')
child = Cursor(
getattr(self._obj, name), ParentKey(self, name, AccessType.ATTR)
)
self._changes[name] = child
return child
def __setitem__(self, key, value):
if is_named_tuple(self._obj):
return setattr(self, self._obj._fields[key], value) # type: ignore
self._changes[key] = Cursor(value, ParentKey(self, key, AccessType.ITEM))
def __setattr__(self, name, value):
self._changes[name] = Cursor(value, ParentKey(self, name, AccessType.ATTR))
[docs] def set(self, value) -> A:
"""Set a new value for an attribute, property, element or entry
in the Cursor object and return a copy of the original object,
containing the new set value.
Example::
>>> from flax.cursor import cursor
>>> from flax.training import train_state
>>> import optax
>>> dict_obj = {'a': 1, 'b': (2, 3), 'c': [4, 5]}
>>> modified_dict_obj = cursor(dict_obj)['b'][0].set(10)
>>> assert modified_dict_obj == {'a': 1, 'b': (10, 3), 'c': [4, 5]}
>>> state = train_state.TrainState.create(
... apply_fn=lambda x: x,
... params=dict_obj,
... tx=optax.adam(1e-3),
... )
>>> modified_state = cursor(state).params['b'][1].set(10)
>>> assert modified_state.params == {'a': 1, 'b': (2, 10), 'c': [4, 5]}
Args:
value: the value used to set an attribute, property, element or entry in the Cursor object
Returns:
A copy of the original object with the new set value.
"""
if self._parent_key is None:
return value
parent, key = self._parent_key.parent, self._parent_key.key # type: ignore
parent._changes[key] = Cursor(value, self._parent_key)
return parent._root.build()
[docs] def build(self) -> A:
"""Create and return a copy of the original object with accumulated changes.
This method is to be called after making changes to the Cursor object.
.. note::
The new object is built bottom-up, the changes will be first applied
to the leaf nodes, and then its parent, all the way up to the root.
Example::
>>> from flax.cursor import cursor
>>> from flax.training import train_state
>>> import optax
>>> dict_obj = {'a': 1, 'b': (2, 3), 'c': [4, 5]}
>>> c = cursor(dict_obj)
>>> c['b'][0] = 10
>>> c['a'] = (100, 200)
>>> modified_dict_obj = c.build()
>>> assert modified_dict_obj == {'a': (100, 200), 'b': (10, 3), 'c': [4, 5]}
>>> state = train_state.TrainState.create(
... apply_fn=lambda x: x,
... params=dict_obj,
... tx=optax.adam(1e-3),
... )
>>> new_fn = lambda x: x + 1
>>> c = cursor(state)
>>> c.params['b'][1] = 10
>>> c.apply_fn = new_fn
>>> modified_state = c.build()
>>> assert modified_state.params == {'a': 1, 'b': (2, 10), 'c': [4, 5]}
>>> assert modified_state.apply_fn == new_fn
Returns:
A copy of the original object with the accumulated changes.
"""
changes = {
key: child.build() if isinstance(child, Cursor) else child
for key, child in self._changes.items()
}
if isinstance(self._obj, FrozenDict):
obj = self._obj.copy(changes) # type: ignore
elif isinstance(self._obj, (dict, list)):
obj = self._obj.copy() # type: ignore
for key, value in changes.items():
obj[key] = value
elif is_named_tuple(self._obj):
obj = self._obj._replace(**changes) # type: ignore
elif isinstance(self._obj, tuple):
obj = list(self._obj) # type: ignore
for key, value in changes.items():
obj[key] = value
obj = tuple(obj) # type: ignore
elif dataclasses.is_dataclass(self._obj):
obj = dataclasses.replace(self._obj, **changes) # type: ignore
else:
obj = self._obj # type: ignore
return obj # type: ignore
[docs] def apply_update(
self,
update_fn: Callable[[str, Any], Any],
) -> 'Cursor[A]':
"""Traverse the Cursor object and record conditional changes recursively via an ``update_fn``.
The changes are recorded in the Cursor object's ``._changes`` dictionary. To generate a copy
of the original object with the accumulated changes, call the ``.build`` method after calling
``.apply_update``.
The ``update_fn`` has a function signature of ``(str, Any) -> Any``:
- The input arguments are the current key path (in the form of a string delimited
by ``'/'``) and value at that current key path
- The output is the new value (either modified by the ``update_fn`` or same as the
input value if the condition wasn't fulfilled)
.. note::
- If the ``update_fn`` returns a modified value, this method will not recurse any further
down that branch to record changes. For example, if we intend to replace an attribute that points
to a dictionary with an int, we don't need to look for further changes inside the dictionary,
since the dictionary will be replaced anyways.
- The ``is`` operator is used to determine whether the return value is modified (by comparing it
to the input value). Therefore if the ``update_fn`` modifies a mutable container (e.g. lists,
dicts, etc.) and returns the same container, ``.apply_update`` will treat the returned value as
unmodified as it contains the same ``id``. To avoid this, return a copy of the modified value.
- ``.apply_update`` WILL NOT call the ``update_fn`` to the value at the top-most level of
the pytree (i.e. the root node). The ``update_fn`` will first be called on the root node's
children, and then the pytree traversal will continue recursively from there.
Example::
>>> import flax.linen as nn
>>> from flax.cursor import cursor
>>> import jax, jax.numpy as jnp
>>> class Model(nn.Module):
... @nn.compact
... def __call__(self, x):
... x = nn.Dense(3)(x)
... x = nn.relu(x)
... x = nn.Dense(3)(x)
... x = nn.relu(x)
... x = nn.Dense(3)(x)
... x = nn.relu(x)
... return x
>>> params = Model().init(jax.random.key(0), jnp.empty((1, 2)))['params']
>>> def update_fn(path, value):
... '''Multiply all dense kernel params by 2 and add 1.
... Subtract the Dense_1 bias param by 1.'''
... if 'kernel' in path:
... return value * 2 + 1
... elif 'Dense_1' in path and 'bias' in path:
... return value - 1
... return value
>>> c = cursor(params)
>>> new_params = c.apply_update(update_fn).build()
>>> for layer in ('Dense_0', 'Dense_1', 'Dense_2'):
... assert (new_params[layer]['kernel'] == 2 * params[layer]['kernel'] + 1).all()
... if layer == 'Dense_1':
... assert (new_params[layer]['bias'] == params[layer]['bias'] - 1).all()
... else:
... assert (new_params[layer]['bias'] == params[layer]['bias']).all()
>>> assert jax.tree_util.tree_all(
... jax.tree_util.tree_map(
... lambda x, y: (x == y).all(),
... params,
... Model().init(jax.random.key(0), jnp.empty((1, 2)))[
... 'params'
... ],
... )
... ) # make sure original params are unchanged
Args:
update_fn: the function that will conditionally record changes to the Cursor object
Returns:
The current Cursor object with the recorded conditional changes specified by the
``update_fn``. To generate a copy of the original object with the accumulated
changes, call the ``.build`` method after calling ``.apply_update``.
"""
for path, value in _traverse_tree((), self._obj, update_fn=update_fn):
child = self
for key, access_type in path[:-1]:
if access_type is AccessType.ITEM:
child = child[key]
else: # access_type is AccessType.ATTR
child = getattr(child, key)
key, access_type = path[-1]
if access_type is AccessType.ITEM:
child[key] = value
else: # access_type is AccessType.ATTR
setattr(child, key, value)
return self
[docs] def find(self, cond_fn: Callable[[str, Any], bool]) -> 'Cursor[A]':
"""Traverse the Cursor object and return a child Cursor object that fulfill the
conditions in the ``cond_fn``. The ``cond_fn`` has a function signature of ``(str, Any) -> bool``:
- The input arguments are the current key path (in the form of a string delimited
by ``'/'``) and value at that current key path
- The output is a boolean, denoting whether to return the child Cursor object at this path
Raises a :meth:`CursorFindError <flax.errors.CursorFindError>` if no object or more
than one object is found that fulfills the condition of the ``cond_fn``. We raise an
error because the user should always expect this method to return the only object whose
corresponding key path and value fulfill the condition of the ``cond_fn``.
.. note::
- If the ``cond_fn`` evaluates to True at a particular key path, this method will not recurse
any further down that branch; i.e. this method will find and return the "earliest" child node
that fulfills the condition in ``cond_fn`` in a particular key path
- ``.find`` WILL NOT search the the value at the top-most level of the pytree (i.e. the root
node). The ``cond_fn`` will be evaluated recursively, starting at the root node's children.
Example::
>>> import flax.linen as nn
>>> from flax.cursor import cursor
>>> import jax, jax.numpy as jnp
>>> class Model(nn.Module):
... @nn.compact
... def __call__(self, x):
... x = nn.Dense(3)(x)
... x = nn.relu(x)
... x = nn.Dense(3)(x)
... x = nn.relu(x)
... x = nn.Dense(3)(x)
... x = nn.relu(x)
... return x
>>> params = Model().init(jax.random.PRNGKey(0), jnp.empty((1, 2)))['params']
>>> def cond_fn(path, value):
... '''Find the second dense layer params.'''
... return 'Dense_1' in path
>>> new_params = cursor(params).find(cond_fn)['bias'].set(params['Dense_1']['bias'] + 1)
>>> for layer in ('Dense_0', 'Dense_1', 'Dense_2'):
... if layer == 'Dense_1':
... assert (new_params[layer]['bias'] == params[layer]['bias'] + 1).all()
... else:
... assert (new_params[layer]['bias'] == params[layer]['bias']).all()
>>> c = cursor(params)
>>> c2 = c.find(cond_fn)
>>> c2['kernel'] += 2
>>> c2['bias'] += 2
>>> new_params = c.build()
>>> for layer in ('Dense_0', 'Dense_1', 'Dense_2'):
... if layer == 'Dense_1':
... assert (new_params[layer]['kernel'] == params[layer]['kernel'] + 2).all()
... assert (new_params[layer]['bias'] == params[layer]['bias'] + 2).all()
... else:
... assert (new_params[layer]['kernel'] == params[layer]['kernel']).all()
... assert (new_params[layer]['bias'] == params[layer]['bias']).all()
>>> assert jax.tree_util.tree_all(
... jax.tree_util.tree_map(
... lambda x, y: (x == y).all(),
... params,
... Model().init(jax.random.PRNGKey(0), jnp.empty((1, 2)))[
... 'params'
... ],
... )
... ) # make sure original params are unchanged
Args:
cond_fn: the function that will conditionally find child Cursor objects
Returns:
A child Cursor object that fulfills the condition in the ``cond_fn``.
"""
generator = self.find_all(cond_fn)
try:
cursor = next(generator)
except StopIteration:
raise CursorFindError()
try:
cursor2 = next(generator)
raise CursorFindError(cursor, cursor2)
except StopIteration:
return cursor
[docs] def find_all(
self, cond_fn: Callable[[str, Any], bool]
) -> Generator['Cursor[A]', None, None]:
"""Traverse the Cursor object and return a generator of child Cursor objects that fulfill the
conditions in the ``cond_fn``. The ``cond_fn`` has a function signature of ``(str, Any) -> bool``:
- The input arguments are the current key path (in the form of a string delimited
by ``'/'``) and value at that current key path
- The output is a boolean, denoting whether to return the child Cursor object at this path
.. note::
- If the ``cond_fn`` evaluates to True at a particular key path, this method will not recurse
any further down that branch; i.e. this method will find and return the "earliest" child nodes
that fulfill the condition in ``cond_fn`` in a particular key path
- ``.find_all`` WILL NOT search the the value at the top-most level of the pytree (i.e. the root
node). The ``cond_fn`` will be evaluated recursively, starting at the root node's children.
Example::
>>> import flax.linen as nn
>>> from flax.cursor import cursor
>>> import jax, jax.numpy as jnp
>>> class Model(nn.Module):
... @nn.compact
... def __call__(self, x):
... x = nn.Dense(3)(x)
... x = nn.relu(x)
... x = nn.Dense(3)(x)
... x = nn.relu(x)
... x = nn.Dense(3)(x)
... x = nn.relu(x)
... return x
>>> params = Model().init(jax.random.PRNGKey(0), jnp.empty((1, 2)))['params']
>>> def cond_fn(path, value):
... '''Find all dense layer params.'''
... return 'Dense' in path
>>> c = cursor(params)
>>> for dense_params in c.find_all(cond_fn):
... dense_params['bias'] += 1
>>> new_params = c.build()
>>> for layer in ('Dense_0', 'Dense_1', 'Dense_2'):
... assert (new_params[layer]['bias'] == params[layer]['bias'] + 1).all()
>>> assert jax.tree_util.tree_all(
... jax.tree_util.tree_map(
... lambda x, y: (x == y).all(),
... params,
... Model().init(jax.random.PRNGKey(0), jnp.empty((1, 2)))[
... 'params'
... ],
... )
... ) # make sure original params are unchanged
Args:
cond_fn: the function that will conditionally find child Cursor objects
Returns:
A generator of child Cursor objects that fulfill the condition in the ``cond_fn``.
"""
for path in _traverse_tree((), self._obj, cond_fn=cond_fn):
child = self
for key, access_type in path:
if access_type is AccessType.ITEM:
child = child[key]
else: # access_type is AccessType.ATTR
child = getattr(child, key)
yield child
def __str__(self):
return str(self._obj)
def __repr__(self):
return self._pretty_repr()
def _pretty_repr(self, indent=2, _prefix_indent=0):
s = 'Cursor(\n'
obj_str = repr(self._obj).replace(
'\n', '\n' + ' ' * (_prefix_indent + indent)
)
s += ' ' * (_prefix_indent + indent) + f'_obj={obj_str},\n'
s += ' ' * (_prefix_indent + indent) + '_changes={'
if self._changes:
s += '\n'
for key in self._changes:
str_key = repr(key)
prefix = ' ' * (_prefix_indent + 2 * indent) + str_key + ': '
s += (
prefix
+ self._changes[key]._pretty_repr(
indent=indent, _prefix_indent=len(prefix)
)
+ ',\n'
)
s = s[
:-2
] # remove comma and newline character for last element in self._changes
s += '\n' + ' ' * (_prefix_indent + indent) + '}\n'
else:
s += '}\n'
s += ' ' * _prefix_indent + ')'
return s
def __len__(self):
return len(self._obj)
def __iter__(self):
if isinstance(self._obj, (tuple, list)):
return (self[i] for i in range(len(self._obj)))
else:
raise NotImplementedError(
'__iter__ method only implemented for tuples and lists, not type'
f' {type(self._obj)}'
)
def __reversed__(self):
if isinstance(self._obj, (tuple, list)):
return (self[i] for i in range(len(self._obj) - 1, -1, -1))
else:
raise NotImplementedError(
'__reversed__ method only implemented for tuples and lists, not type'
f' {type(self._obj)}'
)
def __add__(self, other):
return self._obj + other
def __sub__(self, other):
return self._obj - other
def __mul__(self, other):
return self._obj * other
def __matmul__(self, other):
return self._obj @ other
def __truediv__(self, other):
return self._obj / other
def __floordiv__(self, other):
return self._obj // other
def __mod__(self, other):
return self._obj % other
def __divmod__(self, other):
return divmod(self._obj, other)
def __pow__(self, other):
return pow(self._obj, other)
def __lshift__(self, other):
return self._obj << other
def __rshift__(self, other):
return self._obj >> other
def __and__(self, other):
return self._obj & other
def __xor__(self, other):
return self._obj ^ other
def __or__(self, other):
return self._obj | other
def __radd__(self, other):
return other + self._obj
def __rsub__(self, other):
return other - self._obj
def __rmul__(self, other):
return other * self._obj
def __rmatmul__(self, other):
return other @ self._obj
def __rtruediv__(self, other):
return other / self._obj
def __rfloordiv__(self, other):
return other // self._obj
def __rmod__(self, other):
return other % self._obj
def __rdivmod__(self, other):
return divmod(other, self._obj)
def __rpow__(self, other):
return pow(other, self._obj)
def __rlshift__(self, other):
return other << self._obj
def __rrshift__(self, other):
return other >> self._obj
def __rand__(self, other):
return other & self._obj
def __rxor__(self, other):
return other ^ self._obj
def __ror__(self, other):
return other | self._obj
def __neg__(self):
return -self._obj
def __pos__(self):
return +self._obj
def __abs__(self):
return abs(self._obj)
def __invert__(self):
return ~self._obj
def __round__(self, ndigits=None):
return round(self._obj, ndigits)
def __lt__(self, other):
return self._obj < other
def __le__(self, other):
return self._obj <= other
def __eq__(self, other):
return self._obj == other
def __ne__(self, other):
return self._obj != other
def __gt__(self, other):
return self._obj > other
def __ge__(self, other):
return self._obj >= other
[docs]def cursor(obj: A) -> Cursor[A]:
"""Wrap :class:`Cursor <flax.cursor.Cursor>` over ``obj`` and return it.
Changes can then be applied to the Cursor object in the following ways:
- single-line change via the ``.set`` method
- multiple changes, and then calling the ``.build`` method
- multiple changes conditioned on the pytree path and node value via the
``.apply_update`` method, and then calling the ``.build`` method
``.set`` example::
>>> from flax.cursor import cursor
>>> dict_obj = {'a': 1, 'b': (2, 3), 'c': [4, 5]}
>>> modified_dict_obj = cursor(dict_obj)['b'][0].set(10)
>>> assert modified_dict_obj == {'a': 1, 'b': (10, 3), 'c': [4, 5]}
``.build`` example::
>>> from flax.cursor import cursor
>>> dict_obj = {'a': 1, 'b': (2, 3), 'c': [4, 5]}
>>> c = cursor(dict_obj)
>>> c['b'][0] = 10
>>> c['a'] = (100, 200)
>>> modified_dict_obj = c.build()
>>> assert modified_dict_obj == {'a': (100, 200), 'b': (10, 3), 'c': [4, 5]}
``.apply_update`` example::
>>> from flax.cursor import cursor
>>> from flax.training import train_state
>>> import optax
>>> def update_fn(path, value):
... '''Replace params with empty dictionary.'''
... if 'params' in path:
... return {}
... return value
>>> state = train_state.TrainState.create(
... apply_fn=lambda x: x,
... params={'a': 1, 'b': 2},
... tx=optax.adam(1e-3),
... )
>>> c = cursor(state)
>>> state2 = c.apply_update(update_fn).build()
>>> assert state2.params == {}
>>> assert state.params == {'a': 1, 'b': 2} # make sure original params are unchanged
If the underlying ``obj`` is a ``list`` or ``tuple``, iterating over the Cursor object
to get the child Cursors is also possible::
>>> from flax.cursor import cursor
>>> c = cursor(((1, 2), (3, 4)))
>>> for child_c in c:
... child_c[1] *= -1
>>> assert c.build() == ((1, -2), (3, -4))
View the docstrings for each method to see more examples of their usage.
Args:
obj: the object you want to wrap the Cursor in
Returns:
A Cursor object wrapped around obj.
"""
return Cursor(obj, None)