flax.linen.initializers.orthogonal
- flax.linen.initializers.orthogonal(scale=1.0, column_axis=-1, dtype=<class 'jax.numpy.float64'>)
Builds an initializer that returns uniformly distributed orthogonal matrices.
If the shape is not square, the vectors along the shorter side are orthonormal: the rows when there are fewer rows than columns, and the columns otherwise.
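For intuition, the following is a minimal sketch of one standard way to sample such matrices for a 2-D shape: draw a Gaussian matrix, take its QR decomposition, and flip each column's sign by the sign of the matching diagonal entry of R, which makes the result uniformly (Haar) distributed. The helper name `orthogonal_2d` is hypothetical, and the sketch is not claimed to match the library's implementation line for line.

```python
import jax
import jax.numpy as jnp


def orthogonal_2d(key, n_rows, n_cols, scale=1.0, dtype=jnp.float32):
    """Hypothetical sketch: sample `scale` times an orthogonal (n_rows, n_cols) matrix."""
    # Always factor a "tall" matrix so that reduced QR returns orthonormal columns.
    tall_shape = (max(n_rows, n_cols), min(n_rows, n_cols))
    a = jax.random.normal(key, tall_shape, dtype)
    q, r = jnp.linalg.qr(a)  # reduced QR: q has shape tall_shape
    # Multiplying each column by sign(diag(R)) removes the sign ambiguity of QR,
    # so q is uniformly distributed rather than biased by the factorization.
    q = q * jnp.sign(jnp.diag(r))
    # For a wide shape, transpose so that the (shorter) rows are orthonormal.
    if n_rows < n_cols:
        q = q.T
    return scale * q


key = jax.random.PRNGKey(0)
wide = orthogonal_2d(key, 2, 3)  # rows are orthonormal
tall = orthogonal_2d(key, 3, 2)  # columns are orthonormal
print(jnp.allclose(wide @ wide.T, jnp.eye(2), atol=1e-5))  # True
print(jnp.allclose(tall.T @ tall, jnp.eye(2), atol=1e-5))  # True
```

Factoring the tall orientation and transposing afterwards is what yields orthonormal rows for wide shapes and orthonormal columns for tall ones.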
- Parameters
scale – a multiplier applied to the sampled orthogonal matrix, so the orthogonal rows or columns have norm scale rather than 1.
column_axis – the axis that contains the columns that should be orthogonal.
dtype – the default dtype of the weights.
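A sketch of how scale and column_axis behave, using jax.nn.initializers.orthogonal (the function called in the example below). It assumes, consistent with the JAX implementation, that the result is scale times an orthogonal matrix and that, for shapes with more than two dimensions, every axis other than column_axis is flattened into the row dimension. The kernel shapes are arbitrary illustrative choices.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)

# scale=2.0 with the default column_axis=-1.
init = jax.nn.initializers.orthogonal(scale=2.0)
# A 3x3 conv-style kernel with 4 input and 8 output channels (illustrative shape).
kernel = init(key, (3, 3, 4, 8), jnp.float32)

# Flatten every axis except the last (column) axis: 36 rows x 8 columns.
mat = kernel.reshape(-1, kernel.shape[-1])
# 8 < 36, so the 8 columns are mutually orthogonal with norm scale = 2.
gram = mat.T @ mat
print(jnp.allclose(gram, 4.0 * jnp.eye(8), atol=1e-4))  # True

# With column_axis=0, axis 0 (size 3) holds the columns; moving it last
# and flattening the rest gives a 32 x 3 matrix with orthonormal columns.
init0 = jax.nn.initializers.orthogonal(column_axis=0)
kernel0 = init0(key, (3, 4, 8), jnp.float32)
mat0 = jnp.moveaxis(kernel0, 0, -1).reshape(-1, 3)
print(jnp.allclose(mat0.T @ mat0, jnp.eye(3), atol=1e-5))  # True
```

The Gram-matrix checks do not depend on the order in which the non-column axes are flattened, since reordering rows leaves mat.T @ mat unchanged.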
- Returns
An orthogonal initializer.
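Example (flax.linen.initializers.orthogonal is re-exported from jax.nn.initializers, so both names refer to the same function):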
>>> import jax, jax.numpy as jnp
>>> initializer = jax.nn.initializers.orthogonal()
>>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32)
Array([[ 3.9026976e-01,  7.2495741e-01, -5.6756169e-01],
       [ 8.8047469e-01, -4.7409311e-01, -1.3157725e-04]], dtype=float32)
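To check the property behind the example output, one can compute both Gram matrices of the (2, 3) result. The rows, being the shorter side, are orthonormal; the three columns cannot be, because three vectors in a 2-D space are linearly dependent. The checks below are a sketch and do not rely on the exact printed values.

```python
import jax
import jax.numpy as jnp

initializer = jax.nn.initializers.orthogonal()
w = initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32)

# Rows (the shorter side) are orthonormal: W @ W.T is the 2x2 identity.
print(jnp.allclose(w @ w.T, jnp.eye(2), atol=1e-5))  # True

# Columns are not: W.T @ W is a rank-2 projection, not the 3x3 identity.
print(jnp.allclose(w.T @ w, jnp.eye(3), atol=1e-5))  # False

# Its trace equals trace(W @ W.T) = 2, the number of orthonormal rows.
print(jnp.isclose(jnp.trace(w.T @ w), 2.0, atol=1e-5))  # True
```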