flax.errors package

Flax has the following classes of errors.

exception flax.errors.ApplyModuleInvalidMethodError(method)[source]

When calling Module.apply(), you can specify the method to apply using the method parameter. This error is thrown if the provided value is neither a method of the Module nor a function with at least one argument.

Learn more on the reference docs for Module.apply().

exception flax.errors.ApplyScopeInvalidVariablesError[source]

When calling Module.apply(), the first argument should be a variable dict. For more explanation on variable dicts, please see flax.core.variables.

exception flax.errors.AssignSubModuleError(cls)[source]

You are only allowed to create submodules in two places:

  1. If your Module is noncompact: inside Module.setup().

  2. If your Module is compact: inside the method wrapped in nn.compact().

For instance, the following code throws this error, because nn.Conv is created in __call__, which is not marked as compact:

class Foo(nn.Module):
  def setup(self):
    pass

  def __call__(self, x):
    conv = nn.Conv(features=3, kernel_size=3)  # <-- ERROR

Foo().init(random.PRNGKey(0), jnp.zeros((1,)))

Note that this error is also thrown if you only partially define a Module inside setup:

class Foo(nn.Module):
  def setup(self):
    self.conv = functools.partial(nn.Conv, features=3)

  def __call__(self, x):
    x = self.conv(kernel_size=4)(x)
    return x

Foo().init(random.PRNGKey(0), jnp.zeros((1,)))

In this case, self.conv(kernel_size=4) is called from __call__, which is disallowed because it’s neither within setup nor a method wrapped in nn.compact.

exception flax.errors.CallCompactUnboundModuleError[source]

This error occurs when you are trying to call a Module directly, rather than through Module.apply(). For instance, the error will be raised when trying to run this code:

from flax import linen as nn
import jax.numpy as jnp

test_dense = nn.Dense(10)
test_dense(jnp.ones((5, 5)))  # <-- ERROR: the Module is not bound to variables

Instead, you should pass the variables (parameters and other state) via Module.apply() (or use Module.init() to get initial variables):

from jax import random
variables = test_dense.init(random.PRNGKey(0), jnp.ones((5,5)))

y = test_dense.apply(variables, jnp.ones((5,5)))

exception flax.errors.InvalidCheckpointError(path, step)[source]

A checkpoint cannot be stored in a directory that already has a checkpoint at the current or a later step.

You can pass overwrite=True to disable this behavior and overwrite existing checkpoints in the target directory.

exception flax.errors.InvalidFilterError(filter_like)[source]

A filter should be either a boolean, a string or a container object.

exception flax.errors.InvalidRngError(msg)[source]

All rngs used in a Module should be passed to Module.init() and Module.apply() appropriately. We explain both separately using the following example:

class Bar(nn.Module):
  @nn.compact
  def __call__(self, x):
    some_param = self.param('some_param', nn.initializers.zeros, (1, ))
    dropout_rng = self.make_rng('dropout')
    x = nn.Dense(features=4)(x)
    return x

class Foo(nn.Module):
  @nn.compact
  def __call__(self, x):
    x = Bar()(x)
    return x

PRNGs for Module.init()

In this example, two rngs are used:

  • params is used for initializing the parameters of the model. This rng is used to initialize the some_param parameter and the weights of the Dense Module used in Bar.

  • dropout is used for the dropout rng that is used in Bar.

So, Foo is initialized as follows:

init_rngs = {'params': random.PRNGKey(0), 'dropout': random.PRNGKey(1)}
variables = Foo().init(init_rngs, init_inputs)

If a Module only requires an rng for params, you can use:

SomeModule().init(rng, ...)  # Shorthand for {'params': rng}

PRNGs for Module.apply()

When applying Foo, only the rng for dropout is needed, because params is only used for initializing the Module parameters:

Foo().apply(variables, inputs, rngs={'dropout': random.PRNGKey(2)})

If a Module only requires an rng for params, you don’t have to provide rngs for apply at all:

SomeModule().apply(variables, inputs)  # rngs=None

exception flax.errors.InvalidScopeError(scope_name)[source]

A temporary Scope is only valid within the context in which it is created:

with Scope(variables, rngs=rngs).temporary() as root:
  y = fn(root, *args, **kwargs)
  # Here root is valid.
# Here root is invalid.

exception flax.errors.JaxTransformError[source]

JAX transforms and Flax modules cannot be mixed.

JAX’s functional transformations expect pure functions. When you want to use JAX transformations inside Flax models, you should make use of the Flax transformation wrappers (e.g.: flax.linen.vmap, flax.linen.scan, etc.).

exception flax.errors.ModifyScopeVariableError(col, variable_name, scope_path)[source]

You cannot update a variable if the collection it belongs to is immutable. When you are applying a Module, you should specify which variable collections are mutable:

class MyModule(nn.Module):
  @nn.compact
  def __call__(self, x):
    var = self.variable('batch_stats', 'mean', ...)
    var.value = ...

v = MyModule().init(...)
logits = MyModule().apply(v, batch)  # This throws an error.
logits, updates = MyModule().apply(v, batch, mutable=['batch_stats'])  # This works.

exception flax.errors.MultipleMethodsCompactError[source]

The @compact decorator may be added to at most one method in a Flax module. To resolve this, you can:

  • Remove @compact and define submodules and variables using Module.setup().

  • Use two separate modules that both have a unique @compact method.

TODO(marcvanzee): Link to a design note explaining the motivation behind this. There is no need for an equivalent to hk.transparent and it makes submodules much more sane because there is no need to prefix the method names.

exception flax.errors.NameInUseError(key_type, value, module_name)[source]

This error is raised when trying to create a submodule, param, or variable with an existing name. They are all considered to be in the same namespace.

Sharing Submodules

This is the wrong pattern for sharing submodules:

y = nn.Dense(features=3, name='bar')(x)
z = nn.Dense(features=3, name='bar')(x+epsilon)

Instead, modules should be shared by instance:

dense = nn.Dense(features=3, name='bar')
y = dense(x)
z = dense(x+epsilon)

If submodules are not provided with a name, a unique name will be given to them automatically:

class MyModule(nn.Module):
  def __call__(self, x):
    x = MySubModule()(x)
    x = MySubModule()(x)  # This is fine.
    return x

Parameters and Variables

A parameter name can collide with a submodule or variable, since they are all stored in the same variable dict:

class Foo(nn.Module):
  @nn.compact
  def __call__(self, x):
    bar = self.param('bar', nn.initializers.zeros, (1, ))
    embed = nn.Embed(num_embeddings=2, features=5, name='bar')  # <-- ERROR!

Variables should also have unique names, even if they have their own collection:

class Foo(nn.Module):
  @nn.compact
  def __call__(self, inputs):
    _ = self.param('mean', nn.initializers.lecun_normal(), (2, 2))
    _ = self.variable('stats', 'mean', nn.initializers.zeros, (2, 2))

exception flax.errors.ReservedModuleAttributeError(annotations)[source]

This error is thrown when creating a Module that is using reserved attributes. The following attributes are reserved:

  • parent: The parent Module of this Module.

  • name: The name of this Module.

exception flax.errors.ScopeParamNotFoundError(param_name, scope_path)[source]

This error is thrown when trying to access a parameter that does not exist. For instance, in the code below, the initialized embedding name ‘embedding’ does not match the apply name ‘embed’:

class Embed(nn.Module):
  num_embeddings: int
  features: int

  @nn.compact
  def __call__(self, inputs, embed_name='embedding'):
    inputs = inputs.astype('int32')
    embedding = self.param(embed_name,
                           nn.initializers.lecun_normal(),
                           (self.num_embeddings, self.features))
    return embedding[inputs]

variables = Embed(4, 8).init(random.PRNGKey(0), jnp.ones((5, 5, 1)))
_ = Embed(4, 8).apply(variables, jnp.ones((5, 5, 1)), 'embed')  # <-- ERROR

exception flax.errors.ScopeParamShapeError(param_name, scope_path, value_shape, init_shape)[source]

This error is thrown when the shape of an existing parameter is different from the shape of the return value of the init_fn. This can happen when the shape provided during Module.apply() is different from the one used when initializing the module.

For instance, the following code throws this error because the apply shape ((5, 5)) is different from the init shape ((5, 5, 1)). As a result, the shape of the kernel during init is (1, 8), and the shape during apply is (5, 8), which results in this error:

class NoBiasDense(nn.Module):
  features: int = 8

  @nn.compact
  def __call__(self, x):
    kernel = self.param('kernel',
                        nn.initializers.lecun_normal(),
                        (x.shape[-1], self.features))  # <--- ERROR
    y = lax.dot_general(x, kernel,
                        (((x.ndim - 1,), (0,)), ((), ())))
    return y

variables = NoBiasDense().init(random.PRNGKey(0), jnp.ones((5, 5, 1)))
_ = NoBiasDense().apply(variables, jnp.ones((5, 5)))

exception flax.errors.ScopeVariableNotFoundError(name, col, scope_path)[source]

This error is thrown when trying to use a variable that does not exist in a Scope, in a collection that is immutable. To create the variable, mark its collection as mutable explicitly using the mutable keyword in Module.apply().

exception flax.errors.SetAttributeFrozenModuleError(module_cls, attr_name, attr_val)[source]

You can only assign Module attributes to self inside Module.setup(). Outside of that method, the Module instance is frozen (i.e., immutable). This behavior is similar to frozen Python dataclasses.

For instance, this error is raised in the following case:

class SomeModule(nn.Module):
  @nn.compact
  def __call__(self, x, num_features=10):
    self.num_features = num_features  # <-- ERROR!
    x = nn.Dense(self.num_features)(x)
    return x

s = SomeModule().init(random.PRNGKey(0), jnp.ones((5, 5)))

Similarly, the error is raised when trying to modify a submodule’s attributes after constructing it, even if this is done in the setup() method of the parent module:

class Foo(nn.Module):
  def setup(self):
    self.dense = nn.Dense(features=10)
    self.dense.features = 20  # <--- This is not allowed

  def __call__(self, x):
    return self.dense(x)

exception flax.errors.SetAttributeInModuleSetupError[source]

You are not allowed to modify Module class attributes in Module.setup():

class Foo(nn.Module):
  features: int = 6

  def setup(self):
    self.features = 3  # <-- ERROR

  def __call__(self, x):
    return nn.Dense(self.features)(x)

variables = Foo().init(random.PRNGKey(0), jnp.ones((1, )))

Instead, these attributes should be set when initializing the Module:

class Foo(nn.Module):
  features: int = 6

  @nn.compact
  def __call__(self, x):
    return nn.Dense(self.features)(x)

variables = Foo(features=3).init(random.PRNGKey(0), jnp.ones((1, )))

TODO(marcvanzee): Link to a design note explaining why it’s necessary for modules to stay frozen (otherwise we can’t safely clone them, which we use for lifted transformations).

exception flax.errors.TransformedMethodReturnValueError(name)[source]

Transformed Module methods cannot return other Modules or Variables.

This commonly occurs when named_call is automatically applied to helper constructor methods when profiling is enabled, and can be mitigated by using the @nn.nowrap decorator to prevent automatic wrapping.