Migrate to regular dicts

Flax will migrate from returning FrozenDicts to returning regular dicts when calling the .init, .init_with_output, and .apply Module methods.

The original issue is outlined here.

This guide shows some common upgrade patterns.

Utility functions

FrozenDicts are immutable dictionaries that implement four additional methods: copy, pop, pretty_repr, and unfreeze.

To accommodate the switch to regular dicts, replace calls to these FrozenDict methods with their equivalent utility functions from flax.core.frozen_dict (also available directly under flax.core, as used in the examples). Each utility function mimics the behavior of the corresponding FrozenDict method and can be called on either a FrozenDict or a regular dict. The utility functions and example upgrade patterns are:

copy

# Before: FrozenDict method
variables = variables.copy(add_or_replace={'other_variables': other_variables})
# After: utility function (works on FrozenDicts and regular dicts)
variables = flax.core.copy(variables, add_or_replace={'other_variables': other_variables})

pop

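# Before: FrozenDict method, returns (variables without 'params', params)
state, params = variables.pop('params')
# After: utility function with the same return value
state, params = flax.core.pop(variables, 'params')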

pretty_repr

# Before: FrozenDict method
str_repr = variables.pretty_repr()
# After: utility function (works on FrozenDicts and regular dicts)
str_repr = flax.core.pretty_repr(variables)

unfreeze

# Before: FrozenDict method
variables = variables.unfreeze()
# After: utility function (works on FrozenDicts and regular dicts)
variables = flax.core.unfreeze(variables)

Modifying config values

A temporary feature flag flax_return_frozendict is set up to help with the migration. To toggle behavior between returning FrozenDict and regular dict variables at runtime, run flax.config.update('flax_return_frozendict', <BOOLEAN_VALUE>) in your code.

For example:

import flax
import flax.linen as nn
import jax
import jax.numpy as jnp

x = jnp.empty((1, 3))

flax.config.update('flax_return_frozendict', True) # set Flax to return FrozenDicts
variables = nn.Dense(5).init(jax.random.key(0), x)

assert isinstance(variables, flax.core.FrozenDict)

flax.config.update('flax_return_frozendict', False) # set Flax to return regular dicts
variables = nn.Dense(5).init(jax.random.key(0), x)

assert isinstance(variables, dict)

Alternatively, the environment variable flax_return_frozendict (found here) can be directly modified in the Flax source code.

Migration status

As of July 19th, 2023, flax_return_frozendict is set to False (see #3193), meaning Flax defaults to returning regular dicts from version 0.7.1 onward. The flag can temporarily be set to True to make Flax return FrozenDicts. However, this feature flag will be removed in the future.