Variable dictionary



A variable dict is a normal Python dictionary that serves as a container for one or more “variable collections”, each of which is a nested dictionary whose leaves are jax.numpy arrays.

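The different variable collections follow the same nested tree structure (the module hierarchy), so a submodule's variables sit under the same key path in every collection that contains them.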
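As a minimal sketch of this layout, the snippet below builds a variable dict by hand from jax.numpy arrays and walks it with jax.tree_util. The module names, variable names, and shapes are hypothetical and chosen only for illustration; the point is that nothing beyond ordinary dicts and arrays is involved.

```python
import jax
import jax.numpy as jnp

# A variable dict: an ordinary Python dict mapping collection names to
# nested dicts whose leaves are arrays. All names and shapes here are
# hypothetical, for illustration only.
variables = {
    "params": {
        "Conv1": {"weight": jnp.ones((3, 3, 1, 8)), "bias": jnp.zeros((8,))},
        "BatchNorm1": {"scale": jnp.ones((8,)), "bias": jnp.zeros((8,))},
    },
    "batch_stats": {
        "BatchNorm1": {"mean": jnp.zeros((8,)), "var": jnp.ones((8,))},
    },
}

# The submodule key "BatchNorm1" appears in both collections.
shared = set(variables["params"]) & set(variables["batch_stats"])
print(shared)  # {'BatchNorm1'}

# Since it is a plain nested dict, jax.tree_util treats it as a pytree.
num_leaves = len(jax.tree_util.tree_leaves(variables))
shapes = jax.tree_util.tree_map(lambda leaf: leaf.shape, variables)
print(num_leaves)  # 6
print(shapes["batch_stats"]["BatchNorm1"]["mean"])  # (8,)
```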

For example, consider the following variable dictionary:

  {
    "params": {
      "Conv1": { "weight": ..., "bias": ... },
      "BatchNorm1": { "scale": ..., "bias": ... },
      "Conv2": {...}
    },
    "batch_stats": {
      "BatchNorm1": { "moving_mean": ..., "moving_average": ... }
    }
  }

In this case, the "BatchNorm1" key appears in both the "params" and "batch_stats" collections. This reflects the fact that the submodule named "BatchNorm1" has both trainable parameters (the "params" collection) and other, non-trainable variables (the "batch_stats" collection).


class flax.linen.Variable(scope, collection, name, unbox)

A Variable object allows mutable access to a variable in a VariableDict.

Variables are identified by a collection (e.g., “batch_stats”) and a name (e.g., “moving_mean”). The value property gives access to the variable’s content and can be assigned to in order to mutate it.