flax.linen.get_partition_spec

flax.linen.get_partition_spec(tree)

Extracts a PartitionSpec tree from a PyTree containing Partitioned values.