LoRA#
NNX LoRA classes.
- class flax.nnx.LoRA(self, in_features, lora_rank, out_features, *, base_module=None, dtype=None, param_dtype=<class 'jax.numpy.float32'>, a_initializer=<function variance_scaling.<locals>.init>, b_initializer=<function zeros>, lora_param_type=<class 'flax.nnx.nn.lora.LoRAParam'>, rngs)[source]#
A standalone LoRA layer.
Example usage:
>>> from flax import nnx
>>> import jax, jax.numpy as jnp

>>> layer = nnx.LoRA(3, 2, 4, rngs=nnx.Rngs(0))
>>> layer.lora_a.value.shape
(3, 2)
>>> layer.lora_b.value.shape
(2, 4)

>>> # Wrap around an existing layer
>>> linear = nnx.Linear(3, 4, rngs=nnx.Rngs(0))
>>> wrapper = nnx.LoRA(3, 2, 4, base_module=linear, rngs=nnx.Rngs(1))
>>> assert wrapper.base_module == linear
>>> wrapper.lora_a.value.shape
(3, 2)
>>> wrapper.lora_b.value.shape
(2, 4)

>>> y = layer(jnp.ones((16, 3)))
>>> y.shape
(16, 4)
- Parameters
in_features – the number of input features.
lora_rank – the rank of the LoRA dimension.
out_features – the number of output features.
base_module – a base module to call and substitute, if possible.
dtype – the dtype of the computation (default: infer from input and params).
param_dtype – the dtype passed to parameter initializers (default: float32).
precision – numerical precision of the computation; see jax.lax.Precision for details.
a_initializer – initializer function for the fan-in matrices. Defaults to he_uniform.
b_initializer – initializer function for the fan-out matrices. Defaults to the zeros initializer.
lora_param_type – the type of the LoRA params.
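The forward pass is the standard low-rank update x @ lora_a @ lora_b, added to the base module's output when one is wrapped. A minimal sketch of that computation in plain numpy, independent of Flax (shapes follow the example above; variable names here are illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)
in_features, lora_rank, out_features = 3, 2, 4

# lora_a projects inputs into the low-rank space; lora_b projects back out.
lora_a = rng.standard_normal((in_features, lora_rank))
lora_b = np.zeros((lora_rank, out_features))  # zeros init => update is a no-op at first

x = np.ones((16, in_features))
delta = x @ lora_a @ lora_b  # the LoRA update, shape (16, 4)

# With a wrapped base module, the update is added to its output.
base_kernel = rng.standard_normal((in_features, out_features))
y = x @ base_kernel + delta
assert y.shape == (16, out_features)
```

Because b_initializer defaults to zeros, the update contributes nothing until training moves lora_b away from zero, so wrapping a base module does not change its initial behavior.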
- class flax.nnx.LoRALinear(self, in_features, out_features, *, lora_rank, lora_dtype=None, lora_param_dtype=<class 'jax.numpy.float32'>, a_initializer=<function variance_scaling.<locals>.init>, b_initializer=<function zeros>, lora_param_type=<class 'flax.nnx.nn.lora.LoRAParam'>, rngs, **kwargs)[source]#
An nnx.Linear layer in which the output will be LoRAified.
The model state structure will be compatible with that of Linear.
Example usage:
>>> from flax import nnx
>>> import jax, jax.numpy as jnp

>>> linear = nnx.Linear(3, 4, rngs=nnx.Rngs(0))
>>> lora_linear = nnx.LoRALinear(3, 4, lora_rank=2, rngs=nnx.Rngs(0))
>>> linear.kernel.value.shape
(3, 4)
>>> lora_linear.kernel.value.shape
(3, 4)
>>> lora_linear.lora.lora_a.value.shape
(3, 2)
>>> jnp.allclose(linear.kernel.value, lora_linear.kernel.value)
Array(True, dtype=bool)

>>> y = lora_linear(jnp.ones((16, 3)))
>>> y.shape
(16, 4)
- Parameters
in_features – the number of input features.
out_features – the number of output features.
lora_rank – the rank of the LoRA dimension.
base_module – a base module to call and substitute, if possible.
dtype – the dtype of the computation (default: infer from input and params).
param_dtype – the dtype passed to parameter initializers (default: float32).
precision – numerical precision of the computation; see jax.lax.Precision for details.
a_initializer – initializer function for the fan-in matrices. Defaults to he_uniform.
b_initializer – initializer function for the fan-out matrices. Defaults to the zeros initializer.
lora_param_type – the type of the LoRA params.
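Since the fan-out matrix is zero-initialized, a freshly constructed LoRALinear computes exactly what the equivalent Linear computes; only fine-tuning of the LoRA params changes the output. A sketch of this equivalence in plain numpy (variable names illustrative, assuming the layer computes x @ kernel + bias + x @ lora_a @ lora_b):

```python
import numpy as np

in_features, out_features, lora_rank = 3, 4, 2
rng = np.random.default_rng(0)

kernel = rng.standard_normal((in_features, out_features))
bias = np.zeros(out_features)
lora_a = rng.standard_normal((in_features, lora_rank))
lora_b = np.zeros((lora_rank, out_features))  # zeros at init

x = np.ones((16, in_features))
y_linear = x @ kernel + bias
y_lora = x @ kernel + bias + x @ lora_a @ lora_b

# With lora_b at its zeros init, the two outputs coincide exactly.
assert np.allclose(y_linear, y_lora)
```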
- __call__(x)[source]#
Applies a linear transformation to the inputs along the last dimension.
- Parameters
inputs – The nd-array to be transformed.
- Returns
The transformed input.
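After fine-tuning, the low-rank update can be folded back into the kernel so that inference runs as a plain dense layer with no extra matmuls. A sketch of the merge in plain numpy (names illustrative; this is the standard LoRA merge identity, not a specific Flax API):

```python
import numpy as np

rng = np.random.default_rng(0)
kernel = rng.standard_normal((3, 4))
lora_a = rng.standard_normal((3, 2))
lora_b = rng.standard_normal((2, 4))  # nonzero after fine-tuning

# Fold the rank-2 update into the kernel: W' = W + A @ B.
merged = kernel + lora_a @ lora_b

x = rng.standard_normal((16, 3))
# The merged kernel reproduces the LoRA-augmented output exactly.
assert np.allclose(x @ merged, x @ kernel + x @ lora_a @ lora_b)
```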