flax.linen.get_sharding

flax.linen.get_sharding(tree, mesh)

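Extracts a tree of jax.sharding objects from a PyTree containing Partitioned values, using the given mesh. The result follows the structure of tree, with one sharding in place of each Partitioned value, so it can be used, for example, as the out_shardings argument of jax.jit.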