flax.linen.get_sharding

flax.linen.get_sharding(tree, mesh)

Extracts a jax.sharding tree from a PyTree containing Partitioned values and a mesh.