flax.linen.with_logical_constraint

flax.linen.with_logical_constraint(x, logical_axis_resources, rules=None, mesh=None, fallback=RulesFallback.AXIS_IS_UNSHARDED)

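Version of pjit’s with_sharding_constraint that takes logical axis names instead of mesh axis names. The logical names in logical_axis_resources are translated into mesh axes using rules (when rules is None, the logical axis rules currently in effect are used), and the resulting sharding constraint is applied to x. Logical names that match no rule are handled according to fallback; the default, RulesFallback.AXIS_IS_UNSHARDED, leaves those dimensions unsharded.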