flax.errors package#

Flax has the following classes of errors.

exception flax.errors.AlreadyExistsError(path)[source]#

Attempting to overwrite a file via copy. You can pass overwrite=True to disable this behavior and overwrite existing files.

exception flax.errors.ApplyModuleInvalidMethodError(method)[source]#

When calling Module.apply(), you can specify the method to apply using the method parameter. This error is thrown if the provided value is neither a method of the Module nor a function with at least one argument.

Learn more in the reference docs for Module.apply().

exception flax.errors.ApplyScopeInvalidVariablesStructureError(variables)[source]#

This error is thrown when the dict passed as variables to apply() has an extra 'params' layer, i.e. {'params': {'params': ...}}. For more explanation on variable dicts, please see flax.core.variables.

exception flax.errors.ApplyScopeInvalidVariablesTypeError[source]#

When calling Module.apply(), the first argument should be a variable dict. For more explanation on variable dicts, please see flax.core.variables.

exception flax.errors.AssignSubModuleError(cls)[source]#

You are only allowed to create submodules in two places:

  1. If your Module is noncompact: inside Module.setup().

  2. If your Module is compact: inside the method wrapped in nn.compact().

For instance, the following code throws this error, because nn.Conv is created in __call__, which is not marked as compact:

class Foo(nn.Module):
  def setup(self):
    pass

  def __call__(self, x):
    conv = nn.Conv(features=3, kernel_size=3)

Foo().init(random.PRNGKey(0), jnp.zeros((1,)))

Note that this error is also thrown if you partially defined a Module inside setup:

class Foo(nn.Module):
  def setup(self):
    self.conv = functools.partial(nn.Conv, features=3)

  def __call__(self, x):
    x = self.conv(kernel_size=4)(x)
    return x

Foo().init(random.PRNGKey(0), jnp.zeros((1,)))

In this case, self.conv(kernel_size=4) is called from __call__, which is disallowed because it’s neither within setup nor a method wrapped in nn.compact.

exception flax.errors.CallCompactUnboundModuleError[source]#

This error occurs when you are trying to call a Module directly, rather than through Module.apply(). For instance, the error will be raised when trying to run this code:

from flax import linen as nn
import jax.numpy as jnp

test_dense = nn.Dense(10)
test_dense(jnp.ones((5,5)))  # <-- ERROR: the Module is called directly

Instead, you should pass the variables (parameters and other state) via Module.apply() (or use Module.init() to get initial variables):

from jax import random
variables = test_dense.init(random.PRNGKey(0), jnp.ones((5,5)))

y = test_dense.apply(variables, jnp.ones((5,5)))

exception flax.errors.CallSetupUnboundModuleError[source]#

This error occurs when you are trying to call .setup() directly. For instance, the error will be raised when trying to run this code:

from flax import linen as nn
import jax.numpy as jnp

class MyModule(nn.Module):
  def setup(self):
    self.submodule = MySubModule()

module = MyModule()
module.setup() # <-- ERROR!
submodule = module.submodule

In general you shouldn’t call .setup() yourself. If you need access to a field or submodule defined inside setup, you can instead create a function that extracts it and pass that function to nn.apply:

# setup() will be called automatically by `nn.apply`
def get_submodule(module):
  return module.submodule.clone() # avoid leaking the Scope

empty_variables = {} # you can also use the real variables
submodule = nn.apply(get_submodule, module)(empty_variables)

exception flax.errors.IncorrectPostInitOverrideError[source]#

This error occurs when you overrode .__post_init__() without calling super().__post_init__(). For example, the error will be raised when trying to run this code:

from flax import linen as nn
import jax.numpy as jnp
import jax

class A(nn.Module):
  x: float
  def __post_init__(self):
    self.x_square = self.x ** 2
    # super().__post_init__() <-- forgot to add this line
  def __call__(self, input):
    return input + 3

r = A(x=3)
r.init(jax.random.PRNGKey(2), jnp.ones(3))

exception flax.errors.InvalidCheckpointError(path, step)[source]#

A checkpoint cannot be stored in a directory that already has a checkpoint at the current or a later step.

You can pass overwrite=True to disable this behavior and overwrite existing checkpoints in the target directory.

exception flax.errors.InvalidFilterError(filter_like)[source]#

A filter should be either a boolean, a string or a container object.

exception flax.errors.InvalidInstanceModuleError[source]#

This error occurs when you are trying to call .init(), .init_with_output(), .apply() or .bind() on the Module class itself, instead of an instance of the Module class. For example, the error will be raised when trying to run this code:

class B(nn.Module):
  def __call__(self, x):
    return x

k = random.PRNGKey(0)
x = random.uniform(random.PRNGKey(1), (2,))
B.init(k, x)   # B is module class, not B() a module instance
B.apply(vs, x)   # similar issue with apply called on class instead of instance.

exception flax.errors.InvalidRngError(msg)[source]#

All rngs used in a Module should be passed to Module.init() and Module.apply() appropriately. We explain both separately using the following example:

class Bar(nn.Module):
  @nn.compact
  def __call__(self, x):
    some_param = self.param('some_param', nn.initializers.zeros, (1, ))
    dropout_rng = self.make_rng('dropout')
    x = nn.Dense(features=4)(x)
    return x

class Foo(nn.Module):
  @nn.compact
  def __call__(self, x):
    x = Bar()(x)
    return x

PRNGs for Module.init()

In this example, two rngs are used:

  • params is used for initializing the parameters of the model. This rng is used to initialize the some_param parameter, and for initializing the weights of the Dense Module used in Bar.

  • dropout is used for the dropout rng that is used in Bar.

So, Foo is initialized as follows:

init_rngs = {'params': random.PRNGKey(0), 'dropout': random.PRNGKey(1)}
variables = Foo().init(init_rngs, init_inputs)

If a Module only requires an rng for params, you can use:

SomeModule().init(rng, ...)  # Shorthand for {'params': rng}

PRNGs for Module.apply()

When applying Foo, only the rng for dropout is needed, because params is only used for initializing the Module parameters:

Foo().apply(variables, inputs, rngs={'dropout': random.PRNGKey(2)})

If a Module only requires an rng for params, you don’t have to provide rngs for apply at all:

SomeModule().apply(variables, inputs)  # rngs=None

exception flax.errors.InvalidScopeError(scope_name)[source]#

A temporary Scope is only valid within the context in which it is created:

with Scope(variables, rngs=rngs).temporary() as root:
  y = fn(root, *args, **kwargs)
  # Here root is valid.
# Here root is invalid.

exception flax.errors.JaxTransformError[source]#

JAX transforms and Flax modules cannot be mixed.

JAX’s functional transformations expect pure functions. When you want to use JAX transformations inside Flax models, you should make use of the Flax transformation wrappers (e.g.: flax.linen.vmap, flax.linen.scan, etc.).

exception flax.errors.MPACheckpointingRequiredError(path, step)[source]#

To optimally save and restore a multiprocess array (GDA or JAX Array output by pjit), use GlobalAsyncCheckpointManager.

You can create a GlobalAsyncCheckpointManager at top-level and pass it as argument:

from jax.experimental.gda_serialization import serialization as gdas
gda_manager = gdas.GlobalAsyncCheckpointManager()
save_checkpoint(..., gda_manager=gda_manager)

exception flax.errors.MPARestoreDataCorruptedError(step, path)[source]#

A multiprocess array stored in Google Cloud Storage doesn’t contain a “commit_success.txt” file, which should be written at the end of the save.

Failure to find it may indicate that your saved GDA data is corrupted.

exception flax.errors.MPARestoreTargetRequiredError(path, step, key=None)[source]#

Provide a valid target when restoring a checkpoint with a multiprocess array.

Multiprocess arrays need a sharding (global meshes and partition specs) to be initialized. Therefore, to restore a checkpoint that contains a multiprocess array, make sure the target you passed contains valid multiprocess arrays at the corresponding tree structure location. If you cannot provide a full valid target, consider passing allow_partial_mpa_restoration=True.

exception flax.errors.MPARestoreTypeNotMatchError(step, gda_path)[source]#

Make sure the multiprocess array type you use matches your configuration in jax.config.jax_array.

If you turned jax.config.jax_array on, you should use jax.experimental.array.Array everywhere, instead of using GlobalDeviceArray. Otherwise, avoid using jax.experimental.array to restore your checkpoint.

exception flax.errors.ModifyScopeVariableError(col, variable_name, scope_path)[source]#

You cannot update a variable if the collection it belongs to is immutable. When you are applying a Module, you should specify which variable collections are mutable:

class MyModule(nn.Module):
  @nn.compact
  def __call__(self, x):
    var = self.variable('batch_stats', 'mean', ...)
    var.value = ...

v = MyModule().init(...)
logits = MyModule().apply(v, batch)  # This throws an error.
logits = MyModule().apply(v, batch, mutable=['batch_stats'])  # This works.

exception flax.errors.MultipleMethodsCompactError[source]#

The @compact decorator may only be added to at most one method in a Flax module. In order to resolve this, you can:

  • Remove @compact and define submodules and variables using Module.setup().

  • Use two separate modules that both have a unique @compact method.

TODO(marcvanzee): Link to a design note explaining the motivation behind this. There is no need for an equivalent to hk.transparent and it makes submodules much more sane because there is no need to prefix the method names.

exception flax.errors.NameInUseError(key_type, value, module_name)[source]#

This error is raised when trying to create a submodule, param, or variable with an existing name. They are all considered to be in the same namespace.

Sharing Submodules

This is the wrong pattern for sharing submodules:

y = nn.Dense(features=3, name='bar')(x)
z = nn.Dense(features=3, name='bar')(x+epsilon)

Instead, modules should be shared by instance:

dense = nn.Dense(features=3, name='bar')
y = dense(x)
z = dense(x+epsilon)

If submodules are not provided with a name, a unique name will be given to them automatically:

class MyModule(nn.Module):
  @nn.compact
  def __call__(self, x):
    x = MySubModule()(x)
    x = MySubModule()(x)  # This is fine.
    return x

Parameters and Variables

A parameter name can collide with a submodule or variable, since they are all stored in the same variable dict:

class Foo(nn.Module):
  @nn.compact
  def __call__(self, x):
    bar = self.param('bar', nn.initializers.zeros, (1, ))
    embed = nn.Embed(num_embeddings=2, features=5, name='bar')  # <-- ERROR!

Variables should also have unique names, even if they have their own collection:

class Foo(nn.Module):
  @nn.compact
  def __call__(self, inputs):
    _ = self.param('mean', nn.initializers.lecun_normal(), (2, 2))
    _ = self.variable('stats', 'mean', nn.initializers.zeros, (2, 2))  # <-- ERROR!

exception flax.errors.ReservedModuleAttributeError(annotations)[source]#

This error is thrown when creating a Module that is using reserved attributes. The following attributes are reserved:

  • parent: The parent Module of this Module.

  • name: The name of this Module.

exception flax.errors.ScopeCollectionNotFound(col_name, var_name, scope_path)[source]#

This error is thrown when trying to access a variable from an empty collection.

There are two common causes:

  1. The collection was not passed to apply correctly. For example, you might have used module.apply(params, ...) instead of module.apply({'params': params}, ...).

  2. The collection is empty because the variables need to be initialized. In this case, you should have made the collection mutable during apply (e.g.: module.apply(variables, ..., mutable=['state'])).

exception flax.errors.ScopeParamNotFoundError(param_name, scope_path)[source]#

This error is thrown when trying to access a parameter that does not exist. For instance, in the code below, the initialized embedding name ‘embedding’ does not match the apply name ‘embed’:

class Embed(nn.Module):
  num_embeddings: int
  features: int

  @nn.compact
  def __call__(self, inputs, embed_name='embedding'):
    inputs = inputs.astype('int32')
    embedding = self.param(embed_name,
                           nn.initializers.lecun_normal(),
                           (self.num_embeddings, self.features))
    return embedding[inputs]

model = Embed(4, 8)
variables = model.init(random.PRNGKey(0), jnp.ones((5, 5, 1)))
_ = model.apply(variables, jnp.ones((5, 5, 1)), 'embed')

exception flax.errors.ScopeParamShapeError(param_name, scope_path, value_shape, init_shape)[source]#

This error is thrown when the shape of an existing parameter is different from the shape of the return value of the init_fn. This can happen when the shape provided during Module.apply() is different from the one used when initializing the module.

For instance, the following code throws this error because the apply shape ((5, 5)) is different from the init shape ((5, 5, 1)). As a result, the shape of the kernel during init is (1, 8), and the shape during apply is (5, 8), which results in this error:

class NoBiasDense(nn.Module):
  features: int = 8

  @nn.compact
  def __call__(self, x):
    kernel = self.param('kernel',
                        nn.initializers.lecun_normal(),
                        (x.shape[-1], self.features))  # <--- ERROR
    y = lax.dot_general(x, kernel,
                        (((x.ndim - 1,), (0,)), ((), ())))
    return y

variables = NoBiasDense().init(random.PRNGKey(0), jnp.ones((5, 5, 1)))
_ = NoBiasDense().apply(variables, jnp.ones((5, 5)))

exception flax.errors.ScopeVariableNotFoundError(name, col, scope_path)[source]#

This error is thrown when trying to use a variable that does not exist yet in a Scope, in a collection that is immutable. In order to create this variable, mark the collection as mutable explicitly using the mutable keyword in Module.apply().

exception flax.errors.SetAttributeFrozenModuleError(module_cls, attr_name, attr_val)[source]#

You can only assign Module attributes to self inside Module.setup(). Outside of that method, the Module instance is frozen (i.e., immutable). This behavior is similar to frozen Python dataclasses.

For instance, this error is raised in the following case:

class SomeModule(nn.Module):
  @nn.compact
  def __call__(self, x, num_features=10):
    self.num_features = num_features  # <-- ERROR!
    x = nn.Dense(self.num_features)(x)
    return x

s = SomeModule().init(random.PRNGKey(0), jnp.ones((5, 5)))

Similarly, the error is raised when trying to modify a submodule’s attributes after constructing it, even if this is done in the setup() method of the parent module:

class Foo(nn.Module):
  def setup(self):
    self.dense = nn.Dense(features=10)
    self.dense.features = 20  # <--- This is not allowed

  def __call__(self, x):
    return self.dense(x)

exception flax.errors.SetAttributeInModuleSetupError[source]#

You are not allowed to modify Module class attributes in Module.setup():

class Foo(nn.Module):
  features: int = 6

  def setup(self):
    self.features = 3  # <-- ERROR

  def __call__(self, x):
    return nn.Dense(self.features)(x)

variables = Foo().init(random.PRNGKey(0), jnp.ones((1, )))

Instead, these attributes should be set when initializing the Module:

class Foo(nn.Module):
  features: int = 6

  @nn.compact
  def __call__(self, x):
    return nn.Dense(self.features)(x)

variables = Foo(features=3).init(random.PRNGKey(0), jnp.ones((1, )))

TODO(marcvanzee): Link to a design note explaining why it’s necessary for modules to stay frozen (otherwise we can’t safely clone them, which we use for lifted transformations).

exception flax.errors.TransformTargetError(target)[source]#

Linen transformations must be applied to Module classes or functions taking a Module instance as the first argument.

This error occurs when passing an invalid target to a linen transform (nn.vmap, nn.scan, etc.). This occurs for example when trying to transform a Module instance:

nn.vmap(nn.Dense(features))(x)  # raises TransformTargetError

You can transform the nn.Dense class directly instead:

nn.vmap(nn.Dense)(features)(x)

Or you can create a function that takes the module instance as the first argument:

class BatchDense(nn.Module):
  @nn.compact
  def __call__(self, x):
    return nn.vmap(
        lambda mdl, x: mdl(x),
        variable_axes={'params': 0}, split_rngs={'params': True})(nn.Dense(3), x)

exception flax.errors.TransformedMethodReturnValueError(name)[source]#

Transformed Module methods cannot return other Modules or Variables.