flax.serialization package

Serialization utilities for JAX.

All Flax classes that carry state (e.g., Optimizer) can be turned into a state dict of NumPy arrays for easy serialization.

State dicts

flax.serialization.from_state_dict(target, state, name='.')

Restores the state of the given target using a state dict.

This function takes the current target as an argument. This tells us the exact structure of the target and lets us add assertions that shapes and dtypes don’t change.

In practice, none of the leaf values in target are used; only its tree structure, shapes, and dtypes matter.

Parameters
  • target – the object whose state should be restored.

  • state – a dictionary generated by to_state_dict with the desired new state for target.

  • name – name of branch taken, used to improve deserialization error messages.

Returns

A copy of the object with the restored state.

flax.serialization.to_state_dict(target)

Returns a dictionary with the state of the given target.

flax.serialization.register_serialization_state(ty, ty_to_state_dict, ty_from_state_dict, override=False)

Register a type for serialization.

Parameters
  • ty – the type to be registered.

  • ty_to_state_dict – a function that takes an instance of ty and returns its state as a dictionary.

  • ty_from_state_dict – a function that takes an instance of ty and a state dict, and returns a copy of the instance with the restored state.

  • override – override a previously registered serialization handler (default: False).

Serialization with MessagePack

flax.serialization.msgpack_serialize(pytree, in_place=False)

Save data structure to bytes in msgpack format.

Low-level function that only supports Python trees with array leaves; for custom objects, use to_bytes. Arrays larger than MAX_CHUNK_SIZE are split into multiple chunks.

Parameters
  • pytree – Python tree of dicts, lists, and tuples with Python primitives and arrays as leaves.

  • in_place – boolean specifying whether pytree should be modified in place.

Returns

msgpack-encoded bytes of pytree.

flax.serialization.msgpack_restore(encoded_pytree)

Restore data structure from bytes in msgpack format.

Low-level function that only supports Python trees with array leaves; for custom objects, use from_bytes.

Parameters

encoded_pytree – msgpack-encoded bytes of a Python tree.

Returns

Python tree of dicts, lists, and tuples with Python primitives and arrays as leaves.

flax.serialization.to_bytes(target)

Save an optimizer or other object as a msgpack-serialized state dict.

Parameters

target – template object with state-dict registrations to be serialized to msgpack format. Typically a Flax model or optimizer.

Returns

Bytes of the msgpack-encoded state dict of the target object.

flax.serialization.from_bytes(target, encoded_bytes)

Restore an optimizer or other object from a msgpack-serialized state dict.

Parameters
  • target – template object with state-dict registrations that matches the structure being deserialized from encoded_bytes.

  • encoded_bytes – msgpack-serialized object structurally isomorphic to target. Typically a Flax model or optimizer.

Returns

A new object structurally isomorphic to target, containing the leaf data from the saved data.