spmd

flax.nnx.get_partition_spec(tree)

Extracts a PartitionSpec tree from a PyTree containing Variable values.

flax.nnx.get_named_sharding(tree, mesh)

flax.nnx.with_partitioning(initializer, sharding, mesh=None, get_value_hooks=(), create_value_hooks=(), **metadata)

flax.nnx.with_sharding_constraint(x, axis_resources, mesh=None)