flax.linen.with_logical_constraint

flax.linen.with_logical_constraint(x, logical_axis_resources, fallback=RulesFallback.AXIS_IS_UNSHARDED)

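A version of pjit's with_sharding_constraint that takes logical axis names instead of mesh axis names. Each logical name in logical_axis_resources is translated to a mesh axis using the logical axis rules that are currently active. The fallback argument controls what happens when a name has no matching rule. The default, RulesFallback.AXIS_IS_UNSHARDED, leaves that dimension unsharded.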
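The name-resolution step can be pictured with a small, framework-free sketch. This is a hypothetical illustration, not Flax's implementation. The helper resolve_logical_axes and the rule table are invented for this example. The real function also applies the resulting spec as a JAX sharding constraint, and this sketch leaves that step out. It assumes rules are ordered (logical name, mesh axis) pairs and that one mesh axis should not shard two dimensions of the same array.

```python
from typing import Optional, Sequence, Tuple

# Hypothetical rule table format: (logical axis name, mesh axis name or None).
Rules = Sequence[Tuple[str, Optional[str]]]


def resolve_logical_axes(
    logical_axes: Sequence[Optional[str]], rules: Rules
) -> Tuple[Optional[str], ...]:
    """Map logical axis names to mesh axes (illustrative sketch only).

    - A logical axis of None stays unsharded (None).
    - The first rule with a matching name whose mesh axis is not already
      used by this array wins. A rule mapping to None means "replicate".
    - A name with no usable rule falls back to None (unsharded), in the
      spirit of the AXIS_IS_UNSHARDED default.
    """
    used = set()  # mesh axes already assigned to earlier dimensions
    result = []
    for name in logical_axes:
        mesh_axis = None
        if name is not None:
            for rule_name, candidate in rules:
                if rule_name == name and (candidate is None or candidate not in used):
                    mesh_axis = candidate
                    break
        if mesh_axis is not None:
            used.add(mesh_axis)
        result.append(mesh_axis)
    return tuple(result)


# Hypothetical rules for a 2-D mesh with axes "data" and "model".
rules = [("batch", "data"), ("embed", None), ("mlp", "model"), ("heads", "model")]

# "length" has no rule, so it falls back to unsharded.
print(resolve_logical_axes(("batch", "length", "mlp"), rules))
# ('data', None, 'model')
```

Because "heads" claims the "model" axis first, an array annotated ("batch", "heads", "mlp") would resolve to ('data', 'model', None) in this sketch. Model code therefore only has to name its dimensions once, and the mapping to physical mesh axes lives entirely in the rule table.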