Initializers

flax.nnx.initializers.constant(value, dtype=<class 'jax.numpy.float64'>)

Builds an initializer that returns arrays full of a constant value.

Parameters
  • value – the constant value with which to fill the array.

  • dtype – optional; the initializer’s default dtype.

>>> import jax, jax.numpy as jnp
>>> initializer = jax.nn.initializers.constant(-7)
>>> initializer(jax.random.key(42), (2, 3), jnp.float32)
Array([[-7., -7., -7.],
       [-7., -7., -7.]], dtype=float32)
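The dtype given to the builder is only a default: a dtype passed when the initializer is called takes precedence. A minimal sketch using the same jax.nn.initializers.constant as in the example above (the value 3 and the shapes are arbitrary):

```python
import jax
import jax.numpy as jnp

# Build a constant initializer whose default dtype is int32.
init = jax.nn.initializers.constant(3, dtype=jnp.int32)
key = jax.random.key(0)

# No dtype at call time: the builder's default (int32) is used.
a = init(key, (2, 2))
# An explicit dtype at call time overrides that default.
b = init(key, (2, 2), jnp.float32)

print(a.dtype, b.dtype)  # int32 float32
```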

flax.nnx.initializers.delta_orthogonal(scale=1.0, column_axis=-1, dtype=<class 'jax.numpy.float64'>)

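Builds an initializer for delta orthogonal convolution kernels: the spatial center of the kernel holds an orthogonal matrix and every other entry is zero.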

Parameters
  • scale – a multiplicative factor applied to the orthogonal matrix.

  • column_axis – the axis that contains the columns that should be orthogonal.

  • dtype – the default dtype of the weights.

Returns

A delta orthogonal initializer. The shape passed to the initializer must be 3D, 4D, or 5D.

Example:

>>> import jax, jax.numpy as jnp
>>> initializer = jax.nn.initializers.delta_orthogonal()
>>> initializer(jax.random.key(42), (3, 3, 3), jnp.float32)  
Array([[[ 0.        ,  0.        ,  0.        ],
        [ 0.        ,  0.        ,  0.        ],
        [ 0.        ,  0.        ,  0.        ]],

       [[ 0.27858758, -0.7949833 , -0.53887904],
        [ 0.9120717 ,  0.04322892,  0.40774566],
        [-0.30085585, -0.6050892 ,  0.73712474]],

       [[ 0.        ,  0.        ,  0.        ],
        [ 0.        ,  0.        ,  0.        ],
        [ 0.        ,  0.        ,  0.        ]]], dtype=float32)
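A quick check of the structure visible in the output above, written as a sketch: for a width-3 kernel only the center slice along the spatial axis is nonzero, and that slice is orthogonal. The square (3, 4, 4) shape is an arbitrary choice for illustration.

```python
import jax
import jax.numpy as jnp

init = jax.nn.initializers.delta_orthogonal()
# A 1D convolution kernel: (width, in_channels, out_channels).
w = init(jax.random.key(0), (3, 4, 4), jnp.float32)

center = w[1]  # the spatial center of a width-3 kernel
# The off-center taps are all zero ...
print(jnp.abs(w[0]).max(), jnp.abs(w[2]).max())  # 0.0 0.0
# ... and the center tap is an orthogonal matrix.
print(jnp.allclose(center.T @ center, jnp.eye(4), atol=1e-5))  # True
```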

flax.nnx.initializers.glorot_normal(in_axis=-2, out_axis=-1, batch_axis=(), dtype=<class 'jax.numpy.float64'>)

Builds a Glorot normal initializer (aka Xavier normal initializer).

A Glorot normal initializer is a specialization of jax.nn.initializers.variance_scaling() where scale = 1.0, mode="fan_avg", and distribution="truncated_normal".

Parameters
  • in_axis – axis or sequence of axes of the input dimension in the weights array.

  • out_axis – axis or sequence of axes of the output dimension in the weights array.

  • batch_axis – axis or sequence of axes in the weight array that should be ignored.

  • dtype – the dtype of the weights.

Returns

An initializer.

Example:

>>> import jax, jax.numpy as jnp
>>> initializer = jax.nn.initializers.glorot_normal()
>>> initializer(jax.random.key(42), (2, 3), jnp.float32)  
Array([[ 0.41770416,  0.75262755,  0.7619329 ],
       [-0.5516644 , -0.6028657 ,  0.08661086]], dtype=float32)
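As a sanity check of the variance_scaling parameters listed above (scale = 1.0, mode="fan_avg"), the empirical standard deviation of a large kernel should be close to sqrt(2 / (fan_in + fan_out)). A sketch; the (256, 512) shape is arbitrary:

```python
import jax
import jax.numpy as jnp

fan_in, fan_out = 256, 512
init = jax.nn.initializers.glorot_normal()
w = init(jax.random.key(0), (fan_in, fan_out), jnp.float32)

# mode="fan_avg" with scale=1.0 targets Var[w] = 2 / (fan_in + fan_out).
expected_std = (2.0 / (fan_in + fan_out)) ** 0.5
print(float(w.std()), expected_std)  # both approximately 0.051
```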

flax.nnx.initializers.glorot_uniform(in_axis=-2, out_axis=-1, batch_axis=(), dtype=<class 'jax.numpy.float64'>)

Builds a Glorot uniform initializer (aka Xavier uniform initializer).

A Glorot uniform initializer is a specialization of jax.nn.initializers.variance_scaling() where scale = 1.0, mode="fan_avg", and distribution="uniform".

Parameters
  • in_axis – axis or sequence of axes of the input dimension in the weights array.

  • out_axis – axis or sequence of axes of the output dimension in the weights array.

  • batch_axis – axis or sequence of axes in the weight array that should be ignored.

  • dtype – the dtype of the weights.

Returns

An initializer.

Example:

>>> import jax, jax.numpy as jnp
>>> initializer = jax.nn.initializers.glorot_uniform()
>>> initializer(jax.random.key(42), (2, 3), jnp.float32)  
Array([[ 0.50350785,  0.8088631 ,  0.81566876],
       [-0.6393332 , -0.6865721 ,  0.11003882]], dtype=float32)
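With distribution="uniform" the samples are bounded. A uniform distribution on [-limit, limit] has variance limit² / 3, so matching the Glorot variance 2 / (fan_in + fan_out) gives limit = sqrt(6 / (fan_in + fan_out)). A sketch checking that bound, with an arbitrary (64, 32) shape:

```python
import math
import jax
import jax.numpy as jnp

fan_in, fan_out = 64, 32
init = jax.nn.initializers.glorot_uniform()
w = init(jax.random.key(0), (fan_in, fan_out), jnp.float32)

# Uniform on [-limit, limit] has variance limit**2 / 3; setting that equal to
# 2 / (fan_in + fan_out) gives the bound below.
limit = math.sqrt(6.0 / (fan_in + fan_out))
max_abs = float(jnp.abs(w).max())
print(max_abs <= limit + 1e-6)  # True (the slack only absorbs float32 rounding)
```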

flax.nnx.initializers.he_normal(in_axis=-2, out_axis=-1, batch_axis=(), dtype=<class 'jax.numpy.float64'>)

Builds a He normal initializer (aka Kaiming normal initializer).

A He normal initializer is a specialization of jax.nn.initializers.variance_scaling() where scale = 2.0, mode="fan_in", and distribution="truncated_normal".

Parameters
  • in_axis – axis or sequence of axes of the input dimension in the weights array.

  • out_axis – axis or sequence of axes of the output dimension in the weights array.

  • batch_axis – axis or sequence of axes in the weight array that should be ignored.

  • dtype – the dtype of the weights.

Returns

An initializer.

Example:

>>> import jax, jax.numpy as jnp
>>> initializer = jax.nn.initializers.he_normal()
>>> initializer(jax.random.key(42), (2, 3), jnp.float32)  
Array([[ 0.6604483 ,  1.1900088 ,  1.2047218 ],
       [-0.87225807, -0.95321447,  0.1369438 ]], dtype=float32)
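Because he_normal() is described above as a specialization of variance_scaling(), building the general initializer with the same settings should reproduce it exactly for a fixed key. A minimal sketch:

```python
import jax
import jax.numpy as jnp

key = jax.random.key(42)
shape = (2, 3)

he = jax.nn.initializers.he_normal()
general = jax.nn.initializers.variance_scaling(2.0, "fan_in", "truncated_normal")

# Same key, shape and dtype: the specialization and the general form agree.
same = jnp.array_equal(he(key, shape, jnp.float32), general(key, shape, jnp.float32))
print(same)  # True
```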

flax.nnx.initializers.he_uniform(in_axis=-2, out_axis=-1, batch_axis=(), dtype=<class 'jax.numpy.float64'>)

Builds a He uniform initializer (aka Kaiming uniform initializer).

A He uniform initializer is a specialization of jax.nn.initializers.variance_scaling() where scale = 2.0, mode="fan_in", and distribution="uniform".

Parameters
  • in_axis – axis or sequence of axes of the input dimension in the weights array.

  • out_axis – axis or sequence of axes of the output dimension in the weights array.

  • batch_axis – axis or sequence of axes in the weight array that should be ignored.

  • dtype – the dtype of the weights.

Returns

An initializer.

Example:

>>> import jax, jax.numpy as jnp
>>> initializer = jax.nn.initializers.he_uniform()
>>> initializer(jax.random.key(42), (2, 3), jnp.float32)  
Array([[ 0.79611576,  1.2789248 ,  1.2896855 ],
       [-1.0108745 , -1.0855657 ,  0.17398663]], dtype=float32)
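For convolution kernels, the axes that are neither in_axis nor out_axis count as the receptive field and multiply the fans. A sketch for a 3 × 3 kernel with 16 input and 32 output channels, assuming a layout with the spatial axes first and the channel axes last (matching the default in_axis=-2, out_axis=-1):

```python
import math
import jax
import jax.numpy as jnp

# Conv kernel layout assumed here: (height, width, in_channels, out_channels).
shape = (3, 3, 16, 32)
init = jax.nn.initializers.he_uniform()  # defaults: in_axis=-2, out_axis=-1
w = init(jax.random.key(0), shape, jnp.float32)

# The two spatial axes form the receptive field (3 * 3 = 9), so fan_in = 16 * 9.
fan_in = 16 * 3 * 3
# A uniform distribution with variance 2 / fan_in has bound sqrt(3 * 2 / fan_in).
limit = math.sqrt(6.0 / fan_in)
max_abs = float(jnp.abs(w).max())
print(max_abs <= limit + 1e-6)  # True
```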

flax.nnx.initializers.kaiming_normal(in_axis=-2, out_axis=-1, batch_axis=(), dtype=<class 'jax.numpy.float64'>)

Alias of he_normal(). Builds a He normal initializer (aka Kaiming normal initializer).

A He normal initializer is a specialization of jax.nn.initializers.variance_scaling() where scale = 2.0, mode="fan_in", and distribution="truncated_normal".

Parameters
  • in_axis – axis or sequence of axes of the input dimension in the weights array.

  • out_axis – axis or sequence of axes of the output dimension in the weights array.

  • batch_axis – axis or sequence of axes in the weight array that should be ignored.

  • dtype – the dtype of the weights.

Returns

An initializer.

Example:

>>> import jax, jax.numpy as jnp
>>> initializer = jax.nn.initializers.kaiming_normal()
>>> initializer(jax.random.key(42), (2, 3), jnp.float32)  
Array([[ 0.6604483 ,  1.1900088 ,  1.2047218 ],
       [-0.87225807, -0.95321447,  0.1369438 ]], dtype=float32)
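kaiming_normal() and he_normal() share the documentation above; a minimal sketch confirming that they draw the same samples for a fixed key:

```python
import jax
import jax.numpy as jnp

key = jax.random.key(0)
shape = (4, 5)
a = jax.nn.initializers.kaiming_normal()(key, shape, jnp.float32)
b = jax.nn.initializers.he_normal()(key, shape, jnp.float32)
print(jnp.array_equal(a, b))  # True
```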

flax.nnx.initializers.kaiming_uniform(in_axis=-2, out_axis=-1, batch_axis=(), dtype=<class 'jax.numpy.float64'>)

Alias of he_uniform(). Builds a He uniform initializer (aka Kaiming uniform initializer).

A He uniform initializer is a specialization of jax.nn.initializers.variance_scaling() where scale = 2.0, mode="fan_in", and distribution="uniform".

Parameters
  • in_axis – axis or sequence of axes of the input dimension in the weights array.

  • out_axis – axis or sequence of axes of the output dimension in the weights array.

  • batch_axis – axis or sequence of axes in the weight array that should be ignored.

  • dtype – the dtype of the weights.

Returns

An initializer.

Example:

>>> import jax, jax.numpy as jnp
>>> initializer = jax.nn.initializers.kaiming_uniform()
>>> initializer(jax.random.key(42), (2, 3), jnp.float32)  
Array([[ 0.79611576,  1.2789248 ,  1.2896855 ],
       [-1.0108745 , -1.0855657 ,  0.17398663]], dtype=float32)
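batch_axis keeps stacked kernels from inflating the fans. In this sketch, a hypothetical stack of four independent (8, 16) kernels is initialized with and without batch_axis=0. Without it, axis 0 is treated as a receptive field, fan_in grows from 8 to 32, and the sampling bound shrinks by half:

```python
import math
import jax
import jax.numpy as jnp

# Hypothetical stack of 4 independent (8 -> 16) kernels; axis 0 indexes the stack.
shape = (4, 8, 16)
key = jax.random.key(0)

stacked = jax.nn.initializers.kaiming_uniform(batch_axis=0)(key, shape, jnp.float32)
naive = jax.nn.initializers.kaiming_uniform()(key, shape, jnp.float32)

# With batch_axis=0: fan_in = 8. Without it, axis 0 is treated as a
# receptive field, so fan_in = 4 * 8 = 32 and the bound is half as large.
print(float(jnp.abs(stacked).max()) <= math.sqrt(6.0 / 8) + 1e-6)  # True
print(float(jnp.abs(naive).max()) <= math.sqrt(6.0 / 32) + 1e-6)   # True
```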

flax.nnx.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(), dtype=<class 'jax.numpy.float64'>)

Builds a LeCun normal initializer.

A LeCun normal initializer is a specialization of jax.nn.initializers.variance_scaling() where scale = 1.0, mode="fan_in", and distribution="truncated_normal".

Parameters
  • in_axis – axis or sequence of axes of the input dimension in the weights array.

  • out_axis – axis or sequence of axes of the output dimension in the weights array.

  • batch_axis – axis or sequence of axes in the weight array that should be ignored.

  • dtype – the dtype of the weights.

Returns

An initializer.

Example:

>>> import jax, jax.numpy as jnp
>>> initializer = jax.nn.initializers.lecun_normal()
>>> initializer(jax.random.key(42), (2, 3), jnp.float32)  
Array([[ 0.46700746,  0.8414632 ,  0.8518669 ],
       [-0.61677957, -0.67402434,  0.09683388]], dtype=float32)
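lecun_normal() and he_normal() differ only in scale (1.0 versus 2.0), so with the same key the He samples are the LeCun samples multiplied by sqrt(2). The two (2, 3) example outputs on this page agree (0.46700746 × sqrt(2) ≈ 0.6604483). A sketch of the check:

```python
import jax
import jax.numpy as jnp

key = jax.random.key(42)
shape = (2, 3)
lecun = jax.nn.initializers.lecun_normal()(key, shape, jnp.float32)
he = jax.nn.initializers.he_normal()(key, shape, jnp.float32)

# Same mode ("fan_in") and distribution; scale 2.0 vs 1.0 means a sqrt(2) ratio.
print(jnp.allclose(he, jnp.sqrt(2.0) * lecun, rtol=1e-5))  # True
```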

flax.nnx.initializers.lecun_uniform(in_axis=-2, out_axis=-1, batch_axis=(), dtype=<class 'jax.numpy.float64'>)

Builds a LeCun uniform initializer.

A LeCun uniform initializer is a specialization of jax.nn.initializers.variance_scaling() where scale = 1.0, mode="fan_in", and distribution="uniform".

Parameters
  • in_axis – axis or sequence of axes of the input dimension in the weights array.

  • out_axis – axis or sequence of axes of the output dimension in the weights array.

  • batch_axis – axis or sequence of axes in the weight array that should be ignored.

  • dtype – the dtype of the weights.

Returns

An initializer.

Example:

>>> import jax, jax.numpy as jnp
>>> initializer = jax.nn.initializers.lecun_uniform()
>>> initializer(jax.random.key(42), (2, 3), jnp.float32)  
Array([[ 0.56293887,  0.90433645,  0.9119454 ],
       [-0.71479625, -0.7676109 ,  0.12302713]], dtype=float32)
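The point of mode="fan_in" with scale = 1.0 is that a linear layer x @ w with zero-mean, unit-variance inputs produces outputs of roughly unit variance, because each output sums fan_in terms of variance 1 / fan_in. A sketch with arbitrary sizes:

```python
import jax
import jax.numpy as jnp

fan_in, fan_out = 256, 256
k_x, k_w = jax.random.split(jax.random.key(0))

x = jax.random.normal(k_x, (1024, fan_in))  # zero-mean, unit-variance inputs
w = jax.nn.initializers.lecun_uniform()(k_w, (fan_in, fan_out), jnp.float32)

# Each output sums fan_in products, each with variance 1 / fan_in,
# so Var[x @ w] stays close to Var[x] = 1.
y = x @ w
print(float(y.var()))  # close to 1.0
```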

flax.nnx.initializers.normal(stddev=0.01, dtype=<class 'jax.numpy.float64'>)

Builds an initializer that returns real normally-distributed random arrays.

Parameters
  • stddev – optional; the standard deviation of the distribution.

  • dtype – optional; the initializer’s default dtype.

Returns

An initializer that returns arrays whose values are normally distributed with mean 0 and standard deviation stddev.

>>> import jax, jax.numpy as jnp
>>> initializer = jax.nn.initializers.normal(5.0)
>>> initializer(jax.random.key(42), (2, 3), jnp.float32)  
Array([[ 3.0613258 ,  5.6129413 ,  5.6866574 ],
       [-4.063663  , -4.4520254 ,  0.63115686]], dtype=float32)
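A sketch checking the empirical mean and standard deviation on a large array (the stddev of 0.5 and the shape are arbitrary):

```python
import jax
import jax.numpy as jnp

init = jax.nn.initializers.normal(stddev=0.5)
w = init(jax.random.key(0), (512, 512), jnp.float32)
print(float(w.mean()), float(w.std()))  # approximately 0.0 and 0.5
```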

flax.nnx.initializers.truncated_normal(stddev=0.01, dtype=<class 'jax.numpy.float64'>, lower=-2.0, upper=2.0)

Builds an initializer that returns truncated-normal random arrays.

Parameters
  • stddev – optional; the standard deviation of the untruncated distribution. Unlike the variance_scaling initializers, this function does not correct stddev for the variance lost to truncation; users who want that correction must apply it themselves through the stddev argument.

  • dtype – optional; the initializer’s default dtype.

  • lower – Float representing the lower bound for truncation. Applied before the output is multiplied by the stddev.

  • upper – Float representing the upper bound for truncation. Applied before the output is multiplied by the stddev.

Returns

An initializer that returns arrays drawn from a normal distribution with mean 0 and standard deviation stddev (before truncation), truncated to the range \(\mathrm{lower} \cdot \mathrm{stddev} < x < \mathrm{upper} \cdot \mathrm{stddev}\).

>>> import jax, jax.numpy as jnp
>>> initializer = jax.nn.initializers.truncated_normal(5.0)
>>> initializer(jax.random.key(42), (2, 3), jnp.float32)
Array([[ 2.9047365,  5.2338114,  5.29852  ],
       [-3.836303 , -4.192359 ,  0.6022964]], dtype=float32)
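lower and upper are in units of stddev, so lower=0.0 and upper=1.0 yield samples in (0, stddev). A sketch; the small slack on the upper bound only absorbs float32 rounding:

```python
import jax
import jax.numpy as jnp

stddev = 0.1
# Truncate the standard normal to (0, 1) before scaling by stddev.
init = jax.nn.initializers.truncated_normal(stddev=stddev, lower=0.0, upper=1.0)
w = init(jax.random.key(0), (1000,), jnp.float32)

# Samples lie in (0, stddev); the 1e-6 slack only absorbs float32 rounding.
print(float(w.min()) >= 0.0, float(w.max()) <= stddev + 1e-6)  # True True
```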

flax.nnx.initializers.ones(key, shape, dtype=<class 'jax.numpy.float64'>)

An initializer that returns a constant array full of ones.

The key argument is ignored.

>>> import jax, jax.numpy as jnp
>>> jax.nn.initializers.ones(jax.random.key(42), (3, 2), jnp.float32)
Array([[1., 1.],
       [1., 1.],
       [1., 1.]], dtype=float32)
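Because the key is ignored, different keys produce identical arrays. A minimal sketch:

```python
import jax
import jax.numpy as jnp

ones = jax.nn.initializers.ones
# The key is ignored: two different keys give the same array.
a = ones(jax.random.key(0), (2, 2), jnp.float32)
b = ones(jax.random.key(1), (2, 2), jnp.float32)
print(jnp.array_equal(a, b))  # True
```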

flax.nnx.initializers.ones_init()

Builds an initializer that returns a constant array full of ones.

>>> import jax, jax.numpy as jnp
>>> from flax.nnx import initializers
>>> ones_initializer = initializers.ones_init()
>>> ones_initializer(jax.random.key(42), (3, 2), jnp.float32)
Array([[1., 1.],
       [1., 1.],
       [1., 1.]], dtype=float32)

flax.nnx.initializers.orthogonal(scale=1.0, column_axis=-1, dtype=<class 'jax.numpy.float64'>)

Builds an initializer that returns uniformly distributed orthogonal matrices.

If the shape is not square, the matrices will have orthonormal rows or columns depending on which side is smaller.

Parameters
  • scale – a multiplicative factor applied to the orthogonal matrix.

  • column_axis – the axis that contains the columns that should be orthogonal.

  • dtype – the default dtype of the weights.

Returns

An orthogonal initializer.

Example:

>>> import jax, jax.numpy as jnp
>>> initializer = jax.nn.initializers.orthogonal()
>>> initializer(jax.random.key(42), (2, 3), jnp.float32)  
Array([[ 3.9026976e-01,  7.2495741e-01, -5.6756169e-01],
       [ 8.8047469e-01, -4.7409311e-01, -1.3157725e-04]],
      dtype=float32)
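For the wide (2, 3) shape above, the rows are the smaller side, so the rows are orthonormal; scale multiplies the whole matrix. A sketch checking both:

```python
import jax
import jax.numpy as jnp

key = jax.random.key(42)
# Wide (2, 3) matrix: the rows are the smaller side, so they are orthonormal.
q = jax.nn.initializers.orthogonal()(key, (2, 3), jnp.float32)
print(jnp.allclose(q @ q.T, jnp.eye(2), atol=1e-5))  # True

# scale multiplies the matrix, so q2 @ q2.T equals scale**2 * I.
q2 = jax.nn.initializers.orthogonal(scale=2.0)(key, (2, 3), jnp.float32)
print(jnp.allclose(q2 @ q2.T, 4.0 * jnp.eye(2), atol=1e-4))  # True
```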

flax.nnx.initializers.uniform(scale=0.01, dtype=<class 'jax.numpy.float64'>)

Builds an initializer that returns real uniformly-distributed random arrays.

Parameters
  • scale – optional; the upper bound of the random distribution.

  • dtype – optional; the initializer’s default dtype.

Returns

An initializer that returns arrays whose values are uniformly distributed in the range [0, scale).

>>> import jax, jax.numpy as jnp
>>> initializer = jax.nn.initializers.uniform(10.0)
>>> initializer(jax.random.key(42), (2, 3), jnp.float32)  
Array([[7.298188 , 8.691938 , 8.7230015],
       [2.0818567, 1.8662417, 5.5022564]], dtype=float32)
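A sketch checking the [0, scale) range and the mean of scale / 2 (the shape is arbitrary):

```python
import jax
import jax.numpy as jnp

scale = 10.0
init = jax.nn.initializers.uniform(scale)
w = init(jax.random.key(0), (100_000,), jnp.float32)

# Values lie in [0, scale) and average about scale / 2.
print(float(w.min()) >= 0.0, float(w.max()) < scale)  # True True
print(float(w.mean()))  # approximately 5.0
```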

flax.nnx.initializers.variance_scaling(scale, mode, distribution, in_axis=-2, out_axis=-1, batch_axis=(), dtype=<class 'jax.numpy.float64'>)

Builds an initializer that adapts its scale to the shape of the weights tensor.

With distribution="truncated_normal" or distribution="normal", samples are drawn from a (truncated) normal distribution with a mean of zero and a standard deviation (after truncation, if applicable) of \(\sqrt{\frac{scale}{n}}\), where n is:

  • the number of input units in the weights tensor, if mode="fan_in",

  • the number of output units, if mode="fan_out", or

  • the average of the numbers of input and output units, if mode="fan_avg".

This initializer can be configured with in_axis, out_axis, and batch_axis to work with general convolutional or dense layers; axes that are not in any of those arguments are assumed to be the “receptive field” (convolution kernel spatial axes), and the receptive field size multiplies both the input and output unit counts.

With distribution="truncated_normal", the absolute values of the samples are truncated at 2 standard deviations before scaling.

With distribution="uniform", samples are drawn from:

  • a uniform interval, if dtype is real, or

  • a uniform disk, if dtype is complex,

with a mean of zero and a standard deviation of \(\sqrt{\frac{scale}{n}}\) where n is defined above.

Parameters
  • scale – scaling factor (positive float).

  • mode – one of "fan_in", "fan_out", and "fan_avg".

  • distribution – random distribution to use. One of "truncated_normal", "normal" and "uniform".

  • in_axis – axis or sequence of axes of the input dimension in the weights array.

  • out_axis – axis or sequence of axes of the output dimension in the weights array.

  • batch_axis – axis or sequence of axes in the weight array that should be ignored.

  • dtype – the dtype of the weights.
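A worked example of the fan computation for a convolution kernel of shape (3, 3, 64, 128) with the default in_axis=-2 and out_axis=-1: the receptive field is 3 × 3 = 9, so fan_in = 64 × 9 = 576 and fan_out = 128 × 9 = 1152, and mode="fan_avg" gives n = 864. With scale = 1.0 and distribution="normal", the standard deviation should be close to sqrt(1 / 864) ≈ 0.034. A sketch:

```python
import math
import jax
import jax.numpy as jnp

shape = (3, 3, 64, 128)           # spatial axes first, then in/out channels
receptive_field = 3 * 3
fan_in = 64 * receptive_field     # 576
fan_out = 128 * receptive_field   # 1152
n = (fan_in + fan_out) / 2        # mode="fan_avg": 864

init = jax.nn.initializers.variance_scaling(1.0, "fan_avg", "normal")
w = init(jax.random.key(0), shape, jnp.float32)
print(float(w.std()), math.sqrt(1.0 / n))  # both approximately 0.034
```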

flax.nnx.initializers.xavier_normal(in_axis=-2, out_axis=-1, batch_axis=(), dtype=<class 'jax.numpy.float64'>)

Alias of glorot_normal(). Builds a Glorot normal initializer (aka Xavier normal initializer).

A Glorot normal initializer is a specialization of jax.nn.initializers.variance_scaling() where scale = 1.0, mode="fan_avg", and distribution="truncated_normal".

Parameters
  • in_axis – axis or sequence of axes of the input dimension in the weights array.

  • out_axis – axis or sequence of axes of the output dimension in the weights array.

  • batch_axis – axis or sequence of axes in the weight array that should be ignored.

  • dtype – the dtype of the weights.

Returns

An initializer.

Example:

>>> import jax, jax.numpy as jnp
>>> initializer = jax.nn.initializers.xavier_normal()
>>> initializer(jax.random.key(42), (2, 3), jnp.float32)  
Array([[ 0.41770416,  0.75262755,  0.7619329 ],
       [-0.5516644 , -0.6028657 ,  0.08661086]], dtype=float32)
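in_axis and out_axis also accept sequences of axes. In this sketch, a hypothetical projection kernel of shape (64, 8, 32) spreads its output dimension over the last two axes, so fan_in = 64 and fan_out = 8 × 32 = 256:

```python
import math
import jax
import jax.numpy as jnp

# Hypothetical projection kernel (embed=64, heads=8, head_dim=32); the
# output dimension spans the last two axes.
shape = (64, 8, 32)
init = jax.nn.initializers.xavier_normal(in_axis=0, out_axis=(1, 2))
w = init(jax.random.key(0), shape, jnp.float32)

fan_in, fan_out = 64, 8 * 32
print(float(w.std()), math.sqrt(2.0 / (fan_in + fan_out)))  # both approximately 0.079
```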

flax.nnx.initializers.xavier_uniform(in_axis=-2, out_axis=-1, batch_axis=(), dtype=<class 'jax.numpy.float64'>)

Alias of glorot_uniform(). Builds a Glorot uniform initializer (aka Xavier uniform initializer).

A Glorot uniform initializer is a specialization of jax.nn.initializers.variance_scaling() where scale = 1.0, mode="fan_avg", and distribution="uniform".

Parameters
  • in_axis – axis or sequence of axes of the input dimension in the weights array.

  • out_axis – axis or sequence of axes of the output dimension in the weights array.

  • batch_axis – axis or sequence of axes in the weight array that should be ignored.

  • dtype – the dtype of the weights.

Returns

An initializer.

Example:

>>> import jax, jax.numpy as jnp
>>> initializer = jax.nn.initializers.xavier_uniform()
>>> initializer(jax.random.key(42), (2, 3), jnp.float32)  
Array([[ 0.50350785,  0.8088631 ,  0.81566876],
       [-0.6393332 , -0.6865721 ,  0.11003882]], dtype=float32)
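For complex dtypes, the variance_scaling() description above says uniform samples come from a disk rather than an interval, with the same variance target. A sketch, assuming complex64 is accepted as the dtype at call time; for complex values the variance is measured as the mean squared magnitude:

```python
import jax
import jax.numpy as jnp

fan_in, fan_out = 128, 128
init = jax.nn.initializers.xavier_uniform()
w = init(jax.random.key(0), (fan_in, fan_out), jnp.complex64)

# Same variance target as the real case: E|w|^2 = 2 / (fan_in + fan_out).
print(w.dtype)  # complex64
print(float(jnp.mean(jnp.abs(w) ** 2)), 2.0 / (fan_in + fan_out))  # both approximately 0.0078
```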

flax.nnx.initializers.zeros(key, shape, dtype=<class 'jax.numpy.float64'>)

An initializer that returns a constant array full of zeros.

The key argument is ignored.

>>> import jax, jax.numpy as jnp
>>> jax.nn.initializers.zeros(jax.random.key(42), (2, 3), jnp.float32)
Array([[0., 0., 0.],
       [0., 0., 0.]], dtype=float32)

flax.nnx.initializers.zeros_init()

Builds an initializer that returns a constant array full of zeros.

>>> import jax, jax.numpy as jnp
>>> from flax.nnx import initializers
>>> zeros_initializer = initializers.zeros_init()
>>> zeros_initializer(jax.random.key(42), (2, 3), jnp.float32)
Array([[0., 0., 0.],
       [0., 0., 0.]], dtype=float32)