flax.linen.jit

flax.linen.jit(target, variables=True, rngs=True, static_argnums=(), donate_argnums=(), device=None, backend=None, methods=None)

Lifted version of jax.jit.

Parameters
  • target (flax.linen.transforms.Target) – a Module or a function taking a Module as its first argument.

  • variables (Union[bool, str, Collection[str], DenyList]) – The variable collections that are lifted. By default all collections are lifted.

  • rngs (Union[bool, str, Collection[str], DenyList]) – The PRNG sequences that are lifted. By default all PRNG sequences are lifted.

  • static_argnums (Union[int, Iterable[int]]) – An int or collection of ints specifying which positional arguments to treat as static (compile-time constant). Operations that depend only on static arguments are constant-folded in Python during tracing, so the corresponding argument values can be any Python object. Static arguments should be hashable (that is, implement both __hash__ and __eq__) and immutable. Calling the jitted function with different values for these constants triggers recompilation. If the jitted function is called with fewer positional arguments than static_argnums indicates, an error is raised. Arguments that are not arrays or containers thereof must be marked as static. Defaults to ().

  • donate_argnums (Union[int, Iterable[int]]) – Specifies which positional arguments are “donated” to the computation. It is safe to donate arguments if you no longer need them once the computation has finished. In some cases XLA can use donated buffers to reduce the memory needed to perform a computation, for example by recycling one of your input buffers to store a result. You should not reuse buffers that you donate to a computation; JAX will raise an error if you try to. Defaults to ().

  • device – Optional; the Device the jitted function will run on. Available devices can be retrieved via jax.devices(). The default is inherited from XLA’s DeviceAssignment logic and is usually jax.devices()[0]. This is an experimental feature and the API is likely to change.

  • backend (Optional[str]) – a string representing the XLA backend: 'cpu', 'gpu', or 'tpu'.

  • methods – If target is a Module class, the methods of the Module to jit. Defaults to None, in which case only __call__ is jitted.

Returns

A wrapped version of target, set up for just-in-time compilation.

Return type

flax.linen.transforms.Target