LoRA#

NNX LoRA classes.

class flax.nnx.LoRA(self, in_features, lora_rank, out_features, *, base_module=None, dtype=None, param_dtype=<class 'jax.numpy.float32'>, a_initializer=<function variance_scaling.<locals>.init>, b_initializer=<function zeros>, lora_param_type=<class 'flax.nnx.nn.lora.LoRAParam'>, rngs)[source]#

A standalone LoRA layer.

Example usage:

>>> from flax import nnx
>>> import jax, jax.numpy as jnp
>>> layer = nnx.LoRA(3, 2, 4, rngs=nnx.Rngs(0))
>>> layer.lora_a.value.shape
(3, 2)
>>> layer.lora_b.value.shape
(2, 4)
>>> # Wrap around existing layer
>>> linear = nnx.Linear(3, 4, rngs=nnx.Rngs(0))
>>> wrapper = nnx.LoRA(3, 2, 4, base_module=linear, rngs=nnx.Rngs(1))
>>> assert wrapper.base_module == linear
>>> wrapper.lora_a.value.shape
(3, 2)
>>> wrapper.lora_b.value.shape
(2, 4)
>>> y = layer(jnp.ones((16, 3)))
>>> y.shape
(16, 4)
Parameters
  • in_features – the number of input features.

  • lora_rank – the rank of the LoRA dimension.

  • out_features – the number of output features.

  • base_module – an optional module to wrap; when provided, its output is added to the low-rank update on each call.

  • dtype – the dtype of the computation (default: infer from input and params).

  • param_dtype – the dtype passed to parameter initializers (default: float32).

  • a_initializer – initializer function for the fan-in matrices. Defaults to he_uniform.

  • b_initializer – initializer function for the fan-out matrices. Defaults to the zeros initializer.

  • lora_param_type – the Variable type to use for the LoRA params (default: nnx.LoRAParam).

  • rngs – rng key(s) used to initialize the LoRA parameters.

__call__(x)[source]#

Computes the low-rank update x @ lora_a @ lora_b and, when base_module is set, adds base_module(x) to the result.
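
A minimal sketch checking this behavior (shapes and seeds are illustrative; with the default zeros b_initializer the low-rank update is zero at initialization):

>>> from flax import nnx
>>> import jax.numpy as jnp
>>> layer = nnx.LoRA(3, 2, 4, rngs=nnx.Rngs(0))
>>> x = jnp.ones((16, 3))
>>> # The standalone layer is just the low-rank product.
>>> jnp.allclose(layer(x), x @ layer.lora_a.value @ layer.lora_b.value)
Array(True, dtype=bool)
>>> # With a base module, its output is added on top of the update.
>>> linear = nnx.Linear(3, 4, rngs=nnx.Rngs(0))
>>> wrapper = nnx.LoRA(3, 2, 4, base_module=linear, rngs=nnx.Rngs(1))
>>> jnp.allclose(wrapper(x), linear(x) + x @ wrapper.lora_a.value @ wrapper.lora_b.value)
Array(True, dtype=bool)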

class flax.nnx.LoRALinear(self, in_features, out_features, *, lora_rank, lora_dtype=None, lora_param_dtype=<class 'jax.numpy.float32'>, a_initializer=<function variance_scaling.<locals>.init>, b_initializer=<function zeros>, lora_param_type=<class 'flax.nnx.nn.lora.LoRAParam'>, rngs, **kwargs)[source]#

An nnx.Linear layer in which the output will be LoRAified.

The model state structure will be compatible with that of Linear.

Example usage:

>>> from flax import nnx
>>> import jax, jax.numpy as jnp
>>> linear = nnx.Linear(3, 4, rngs=nnx.Rngs(0))
>>> lora_linear = nnx.LoRALinear(3, 4, lora_rank=2, rngs=nnx.Rngs(0))
>>> linear.kernel.value.shape
(3, 4)
>>> lora_linear.kernel.value.shape
(3, 4)
>>> lora_linear.lora.lora_a.value.shape
(3, 2)
>>> jnp.allclose(linear.kernel.value, lora_linear.kernel.value)
Array(True, dtype=bool)
>>> y = lora_linear(jnp.ones((16, 3)))
>>> y.shape
(16, 4)
Parameters
  • in_features – the number of input features.

  • out_features – the number of output features.

  • lora_rank – the rank of the LoRA dimension.

  • lora_dtype – the dtype of the LoRA computation (default: infer from input and params).

  • lora_param_dtype – the dtype passed to the LoRA parameter initializers (default: float32).

  • a_initializer – initializer function for the fan-in matrices. Defaults to he_uniform.

  • b_initializer – initializer function for the fan-out matrices. Defaults to the zeros initializer.

  • lora_param_type – the Variable type to use for the LoRA params (default: nnx.LoRAParam).

  • rngs – rng key(s) used to initialize the parameters.

  • **kwargs – additional keyword arguments (for example precision) forwarded to the underlying nnx.Linear.
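
Since the adapter weights are stored as lora_param_type (nnx.LoRAParam by default), the usual nnx filters can separate them from the pretrained weights, which is the typical setup for training only the adapter. A minimal sketch, assuming the default nnx.LoRAParam type; the exact leaf ordering shown is illustrative:

>>> from flax import nnx
>>> import jax
>>> lora_linear = nnx.LoRALinear(3, 4, lora_rank=2, rngs=nnx.Rngs(0))
>>> # Split the LoRA params from everything else; only lora_params
>>> # would be handed to the optimizer.
>>> graphdef, lora_params, rest = nnx.split(lora_linear, nnx.LoRAParam, ...)
>>> [leaf.shape for leaf in jax.tree_util.tree_leaves(lora_params)]
[(3, 2), (2, 4)]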

__call__(x)[source]#

Applies a linear transformation to the inputs along the last dimension, then adds the LoRA low-rank update to the result.

Parameters

x – the nd-array to be transformed.

Returns

The transformed input.
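
Because the state structure matches Linear, a pretrained Linear's parameters can be grafted onto a LoRALinear before adaptation. A minimal sketch, assuming nnx.update accepts a partial state (the LoRA matrices keep their initial values; since lora_b starts at zero, the LoRAified layer initially reproduces the base layer):

>>> from flax import nnx
>>> import jax.numpy as jnp
>>> linear = nnx.Linear(3, 4, rngs=nnx.Rngs(0))
>>> lora_linear = nnx.LoRALinear(3, 4, lora_rank=2, rngs=nnx.Rngs(1))
>>> nnx.update(lora_linear, nnx.state(linear))
>>> x = jnp.ones((16, 3))
>>> jnp.allclose(lora_linear(x), linear(x))
Array(True, dtype=bool)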
