Initializers

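Initializers for Flax. Most of these functions are re-exported from jax.nn.initializers, which is why the examples call jax.nn.initializers directly; ones_init and zeros_init are defined by Flax itself.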

flax.linen.initializers.constant(value, dtype=<class 'jax.numpy.float64'>)

Builds an initializer that returns arrays full of a constant value.

Parameters
  • value – the constant value with which to fill the returned arrays.

  • dtype – optional; the initializer’s default dtype.

>>> import jax, jax.numpy as jnp
>>> initializer = jax.nn.initializers.constant(-7)
>>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32)
Array([[-7., -7., -7.],
       [-7., -7., -7.]], dtype=float32)
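
The dtype passed to the builder is only a default: the returned initializer still accepts a dtype at call time. A short sketch using jax.nn.initializers.constant directly (the values 0.5 and float16 are arbitrary choices for illustration):

```python
import jax
import jax.numpy as jnp

# Build the initializer once, with float16 as its default dtype.
init_fn = jax.nn.initializers.constant(0.5, dtype=jnp.float16)

key = jax.random.PRNGKey(0)
# No dtype at call time, so the builder's default (float16) is used.
default_arr = init_fn(key, (2, 2))
# An explicit dtype at call time overrides the builder's default.
override_arr = init_fn(key, (2, 2), jnp.float32)

print(default_arr.dtype, override_arr.dtype)  # float16 float32
```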

flax.linen.initializers.delta_orthogonal(scale=1.0, column_axis=-1, dtype=<class 'jax.numpy.float64'>)

Builds an initializer for delta orthogonal kernels.

Parameters
  • scale – a multiplicative factor applied to the orthogonal matrix placed at the kernel's center.

  • column_axis – the axis that contains the columns that should be orthogonal.

  • dtype – the default dtype of the weights.

Returns

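A delta orthogonal initializer. The shape passed to the initializer must be 3D, 4D, or 5D, and its last two dimensions (input and output channels) must satisfy shape[-2] <= shape[-1]. An orthogonal matrix is placed at the spatial center of the kernel; every other entry is zero.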

Example:

>>> import jax, jax.numpy as jnp
>>> initializer = jax.nn.initializers.delta_orthogonal()
>>> initializer(jax.random.PRNGKey(42), (3, 3, 3), jnp.float32)  
Array([[[ 0.        ,  0.        ,  0.        ],
        [ 0.        ,  0.        ,  0.        ],
        [ 0.        ,  0.        ,  0.        ]],

       [[ 0.27858758, -0.7949833 , -0.53887904],
        [ 0.9120717 ,  0.04322892,  0.40774566],
        [-0.30085585, -0.6050892 ,  0.73712474]],

       [[ 0.        ,  0.        ,  0.        ],
        [ 0.        ,  0.        ,  0.        ],
        [ 0.        ,  0.        ,  0.        ]]], dtype=float32)
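
The structure of a delta orthogonal kernel can be checked directly. This sketch assumes an HWIO-style 2D convolution kernel (spatial axes first, then input and output channels); only the center tap is nonzero, and because there are fewer input than output channels, that center slice has orthonormal rows:

```python
import jax
import jax.numpy as jnp

# A 3x3 conv kernel with 4 input and 8 output channels (HWIO layout assumed).
shape = (3, 3, 4, 8)
kernel = jax.nn.initializers.delta_orthogonal()(jax.random.PRNGKey(0), shape, jnp.float32)

center = kernel[1, 1]            # the (4, 8) slice at the spatial center
rest = kernel.at[1, 1].set(0.0)  # the kernel with the center tap zeroed out

# 4 rows <= 8 columns, so the rows are orthonormal: center @ center.T == I.
print(jnp.allclose(center @ center.T, jnp.eye(4), atol=1e-4))  # True
# Every non-center tap is exactly zero.
print(bool(jnp.all(rest == 0.0)))  # True
```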

flax.linen.initializers.glorot_normal(in_axis=-2, out_axis=-1, batch_axis=(), dtype=<class 'jax.numpy.float64'>)

Builds a Glorot normal initializer (aka Xavier normal initializer).

A Glorot normal initializer is a specialization of jax.nn.initializers.variance_scaling() where scale = 1.0, mode="fan_avg", and distribution="truncated_normal".

Parameters
  • in_axis – axis or sequence of axes of the input dimension in the weights array.

  • out_axis – axis or sequence of axes of the output dimension in the weights array.

  • batch_axis – axis or sequence of axes in the weight array that should be ignored.

  • dtype – the dtype of the weights.

Returns

An initializer.

Example:

>>> import jax, jax.numpy as jnp
>>> initializer = jax.nn.initializers.glorot_normal()
>>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32)  
Array([[ 0.41770416,  0.75262755,  0.7619329 ],
       [-0.5516644 , -0.6028657 ,  0.08661086]], dtype=float32)
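
Because glorot_normal is a specialization of variance_scaling, building the same distribution by hand with the documented arguments should give identical samples for the same key. A minimal check:

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(42)
shape = (2, 3)

glorot = jax.nn.initializers.glorot_normal()
# The same distribution spelled out through variance_scaling.
manual = jax.nn.initializers.variance_scaling(1.0, "fan_avg", "truncated_normal")

a = glorot(key, shape, jnp.float32)
b = manual(key, shape, jnp.float32)
print(bool(jnp.array_equal(a, b)))  # True
```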

flax.linen.initializers.glorot_uniform(in_axis=-2, out_axis=-1, batch_axis=(), dtype=<class 'jax.numpy.float64'>)

Builds a Glorot uniform initializer (aka Xavier uniform initializer).

A Glorot uniform initializer is a specialization of jax.nn.initializers.variance_scaling() where scale = 1.0, mode="fan_avg", and distribution="uniform".

Parameters
  • in_axis – axis or sequence of axes of the input dimension in the weights array.

  • out_axis – axis or sequence of axes of the output dimension in the weights array.

  • batch_axis – axis or sequence of axes in the weight array that should be ignored.

  • dtype – the dtype of the weights.

Returns

An initializer.

Example:

>>> import jax, jax.numpy as jnp
>>> initializer = jax.nn.initializers.glorot_uniform()
>>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32)  
Array([[ 0.50350785,  0.8088631 ,  0.81566876],
       [-0.6393332 , -0.6865721 ,  0.11003882]], dtype=float32)
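
With scale = 1.0 and n = (fan_in + fan_out) / 2, a uniform draw with standard deviation \(\sqrt{1/n}\) lives on [-limit, limit] with limit = sqrt(6 / (fan_in + fan_out)), the classic Glorot bound. A sketch checking both the bound and the empirical spread on a dense kernel (the sizes are arbitrary):

```python
import math

import jax
import jax.numpy as jnp

fan_in, fan_out = 256, 128
w = jax.nn.initializers.glorot_uniform()(jax.random.PRNGKey(0), (fan_in, fan_out), jnp.float32)

# Uniform on [-limit, limit] has std limit / sqrt(3).
limit = math.sqrt(6.0 / (fan_in + fan_out))       # 0.125 for these sizes
target_std = math.sqrt(2.0 / (fan_in + fan_out))  # sqrt(1 / n)

print(float(jnp.max(jnp.abs(w))) <= limit * (1 + 1e-6))  # True
print(round(float(jnp.std(w)), 4), round(target_std, 4))
```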

flax.linen.initializers.he_normal(in_axis=-2, out_axis=-1, batch_axis=(), dtype=<class 'jax.numpy.float64'>)

Builds a He normal initializer (aka Kaiming normal initializer).

A He normal initializer is a specialization of jax.nn.initializers.variance_scaling() where scale = 2.0, mode="fan_in", and distribution="truncated_normal".

Parameters
  • in_axis – axis or sequence of axes of the input dimension in the weights array.

  • out_axis – axis or sequence of axes of the output dimension in the weights array.

  • batch_axis – axis or sequence of axes in the weight array that should be ignored.

  • dtype – the dtype of the weights.

Returns

An initializer.

Example:

>>> import jax, jax.numpy as jnp
>>> initializer = jax.nn.initializers.he_normal()
>>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32)  
Array([[ 0.6604483 ,  1.1900088 ,  1.2047218 ],
       [-0.87225807, -0.95321447,  0.1369438 ]], dtype=float32)
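
The "after truncation" wording in variance_scaling means the samples' final standard deviation is sqrt(2 / fan_in) for He normal. The sketch below checks that, and also checks the truncation bound under an assumption about JAX's internals: that the pre-truncation scale is the target divided by about 0.8796 (the standard deviation of a unit normal truncated to [-2, 2]):

```python
import math

import jax
import jax.numpy as jnp

fan_in, fan_out = 512, 256
w = jax.nn.initializers.he_normal()(jax.random.PRNGKey(0), (fan_in, fan_out), jnp.float32)

# Target post-truncation std: sqrt(scale / n) with scale=2.0 and n=fan_in.
target_std = math.sqrt(2.0 / fan_in)
# Assumed: samples are cut at 2 units of the pre-truncation std (target / 0.8796...).
bound = 2.0 * target_std / 0.87962566103423978

print(round(float(jnp.std(w)), 4), round(target_std, 4))
print(float(jnp.max(jnp.abs(w))) <= bound * (1 + 1e-5))  # True
```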

flax.linen.initializers.he_uniform(in_axis=-2, out_axis=-1, batch_axis=(), dtype=<class 'jax.numpy.float64'>)

Builds a He uniform initializer (aka Kaiming uniform initializer).

A He uniform initializer is a specialization of jax.nn.initializers.variance_scaling() where scale = 2.0, mode="fan_in", and distribution="uniform".

Parameters
  • in_axis – axis or sequence of axes of the input dimension in the weights array.

  • out_axis – axis or sequence of axes of the output dimension in the weights array.

  • batch_axis – axis or sequence of axes in the weight array that should be ignored.

  • dtype – the dtype of the weights.

Returns

An initializer.

Example:

>>> import jax, jax.numpy as jnp
>>> initializer = jax.nn.initializers.he_uniform()
>>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32)  
Array([[ 0.79611576,  1.2789248 ,  1.2896855 ],
       [-1.0108745 , -1.0855657 ,  0.17398663]], dtype=float32)
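
For convolution kernels the spatial axes count toward the fan: with the default in_axis=-2 and out_axis=-1, a (3, 3, 16, 32) kernel has fan_in = 3 * 3 * 16 = 144. A sketch checking the resulting uniform bound sqrt(3 * 2 / fan_in), assuming an HWIO-style layout:

```python
import math

import jax
import jax.numpy as jnp

# 3x3 conv kernel, 16 input channels, 32 output channels (HWIO layout assumed).
shape = (3, 3, 16, 32)
w = jax.nn.initializers.he_uniform()(jax.random.PRNGKey(0), shape, jnp.float32)

# The spatial axes form the receptive field, so fan_in = 3 * 3 * 16.
fan_in = shape[0] * shape[1] * shape[2]
limit = math.sqrt(3.0 * 2.0 / fan_in)  # sqrt(3 * scale / n) with scale=2.0

print(fan_in)  # 144
print(float(jnp.max(jnp.abs(w))) <= limit * (1 + 1e-6))  # True
```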

flax.linen.initializers.kaiming_normal(in_axis=-2, out_axis=-1, batch_axis=(), dtype=<class 'jax.numpy.float64'>)

Builds a He normal initializer (aka Kaiming normal initializer). This is an alias of he_normal, which is why the example calls he_normal.

A He normal initializer is a specialization of jax.nn.initializers.variance_scaling() where scale = 2.0, mode="fan_in", and distribution="truncated_normal".

Parameters
  • in_axis – axis or sequence of axes of the input dimension in the weights array.

  • out_axis – axis or sequence of axes of the output dimension in the weights array.

  • batch_axis – axis or sequence of axes in the weight array that should be ignored.

  • dtype – the dtype of the weights.

Returns

An initializer.

Example:

>>> import jax, jax.numpy as jnp
>>> initializer = jax.nn.initializers.he_normal()
>>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32)  
Array([[ 0.6604483 ,  1.1900088 ,  1.2047218 ],
       [-0.87225807, -0.95321447,  0.1369438 ]], dtype=float32)
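
Since kaiming_normal is the same initializer as he_normal, and both specialize variance_scaling, all three spellings should produce identical samples for a given key:

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(42)
shape = (2, 3)

a = jax.nn.initializers.kaiming_normal()(key, shape, jnp.float32)
b = jax.nn.initializers.he_normal()(key, shape, jnp.float32)
c = jax.nn.initializers.variance_scaling(2.0, "fan_in", "truncated_normal")(key, shape, jnp.float32)

print(bool(jnp.array_equal(a, b)), bool(jnp.array_equal(a, c)))  # True True
```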

flax.linen.initializers.kaiming_uniform(in_axis=-2, out_axis=-1, batch_axis=(), dtype=<class 'jax.numpy.float64'>)

Builds a He uniform initializer (aka Kaiming uniform initializer). This is an alias of he_uniform, which is why the example calls he_uniform.

A He uniform initializer is a specialization of jax.nn.initializers.variance_scaling() where scale = 2.0, mode="fan_in", and distribution="uniform".

Parameters
  • in_axis – axis or sequence of axes of the input dimension in the weights array.

  • out_axis – axis or sequence of axes of the output dimension in the weights array.

  • batch_axis – axis or sequence of axes in the weight array that should be ignored.

  • dtype – the dtype of the weights.

Returns

An initializer.

Example:

>>> import jax, jax.numpy as jnp
>>> initializer = jax.nn.initializers.he_uniform()
>>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32)  
Array([[ 0.79611576,  1.2789248 ,  1.2896855 ],
       [-1.0108745 , -1.0855657 ,  0.17398663]], dtype=float32)
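
in_axis and out_axis matter whenever a kernel is not stored in the default (in, out) layout. This sketch assumes a hypothetical weight stored as (out_features, in_features); telling the initializer which axis is the input changes fan_in, and therefore the spread:

```python
import math

import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
# Hypothetical kernel stored as (out_features, in_features), the transpose of the default.
out_features, in_features = 64, 256
shape = (out_features, in_features)

w_t = jax.nn.initializers.kaiming_uniform(in_axis=-1, out_axis=-2)(key, shape, jnp.float32)
w_default = jax.nn.initializers.kaiming_uniform()(key, shape, jnp.float32)

# The std is sqrt(2.0 / fan_in); only which axis counts as fan_in differs.
print(round(float(jnp.std(w_t)), 4), round(math.sqrt(2.0 / in_features), 4))
print(round(float(jnp.std(w_default)), 4), round(math.sqrt(2.0 / out_features), 4))
```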

flax.linen.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(), dtype=<class 'jax.numpy.float64'>)

Builds a LeCun normal initializer.

A LeCun normal initializer is a specialization of jax.nn.initializers.variance_scaling() where scale = 1.0, mode="fan_in", and distribution="truncated_normal".

Parameters
  • in_axis – axis or sequence of axes of the input dimension in the weights array.

  • out_axis – axis or sequence of axes of the output dimension in the weights array.

  • batch_axis – axis or sequence of axes in the weight array that should be ignored.

  • dtype – the dtype of the weights.

Returns

An initializer.

Example:

>>> import jax, jax.numpy as jnp
>>> initializer = jax.nn.initializers.lecun_normal()
>>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32)  
Array([[ 0.46700746,  0.8414632 ,  0.8518669 ],
       [-0.61677957, -0.67402434,  0.09683388]], dtype=float32)
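
batch_axis removes an axis from the fan computation entirely. Without it, a leading "stack" axis is treated as part of the receptive field and inflates fan_in. A sketch assuming a hypothetical stack of four independent (128 -> 64) dense kernels:

```python
import math

import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
# Hypothetical stack of 4 independent (128 -> 64) dense kernels.
shape = (4, 128, 64)

stacked = jax.nn.initializers.lecun_normal(batch_axis=0)(key, shape, jnp.float32)
unbatched = jax.nn.initializers.lecun_normal()(key, shape, jnp.float32)

# batch_axis=0: fan_in = 128. Without it, axis 0 joins the receptive field: fan_in = 4 * 128.
print(round(float(jnp.std(stacked)), 4), round(math.sqrt(1 / 128), 4))
print(round(float(jnp.std(unbatched)), 4), round(math.sqrt(1 / 512), 4))
```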

flax.linen.initializers.lecun_uniform(in_axis=-2, out_axis=-1, batch_axis=(), dtype=<class 'jax.numpy.float64'>)

Builds a LeCun uniform initializer.

A LeCun uniform initializer is a specialization of jax.nn.initializers.variance_scaling() where scale = 1.0, mode="fan_in", and distribution="uniform".

Parameters
  • in_axis – axis or sequence of axes of the input dimension in the weights array.

  • out_axis – axis or sequence of axes of the output dimension in the weights array.

  • batch_axis – axis or sequence of axes in the weight array that should be ignored.

  • dtype – the dtype of the weights.

Returns

An initializer.

Example:

>>> import jax, jax.numpy as jnp
>>> initializer = jax.nn.initializers.lecun_uniform()
>>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32)  
Array([[ 0.56293887,  0.90433645,  0.9119454 ],
       [-0.71479625, -0.7676109 ,  0.12302713]], dtype=float32)
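
The point of scale = 1.0 with mode="fan_in" is that a linear layer roughly preserves the variance of unit-variance inputs. A sketch of that effect with random inputs (sizes are arbitrary; the result is approximate, not exact):

```python
import jax
import jax.numpy as jnp

fan_in, fan_out = 512, 256
k_w, k_x = jax.random.split(jax.random.PRNGKey(0))
w = jax.nn.initializers.lecun_uniform()(k_w, (fan_in, fan_out), jnp.float32)

# Unit-variance inputs; each output sums fan_in terms of variance 1/fan_in,
# so the pre-activations x @ w should have variance close to 1.
x = jax.random.normal(k_x, (2048, fan_in), jnp.float32)
y = x @ w
print(round(float(jnp.var(y)), 2))
```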

flax.linen.initializers.normal(stddev=0.01, dtype=<class 'jax.numpy.float64'>)

Builds an initializer that returns real normally-distributed random arrays.

Parameters
  • stddev – optional; the standard deviation of the distribution.

  • dtype – optional; the initializer’s default dtype.

Returns

An initializer that returns arrays whose values are normally distributed with mean 0 and standard deviation stddev.

>>> import jax, jax.numpy as jnp
>>> initializer = jax.nn.initializers.normal(5.0)
>>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32)  
Array([[ 3.0613258 ,  5.6129413 ,  5.6866574 ],
       [-4.063663  , -4.4520254 ,  0.63115686]], dtype=float32)
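
Unlike the variance_scaling family, normal uses a fixed standard deviation that does not depend on the array's shape. A quick sketch with two very different shapes (stddev=0.02 is an arbitrary choice):

```python
import jax
import jax.numpy as jnp

init_fn = jax.nn.initializers.normal(stddev=0.02)
key = jax.random.PRNGKey(0)

# Same stddev regardless of which axis is large.
wide = init_fn(key, (10, 10_000), jnp.float32)
tall = init_fn(key, (10_000, 10), jnp.float32)
print(round(float(jnp.std(wide)), 4), round(float(jnp.std(tall)), 4))
```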

flax.linen.initializers.ones(key, shape, dtype=<class 'jax.numpy.float64'>)

An initializer that returns a constant array full of ones.

The key argument is ignored.

>>> import jax, jax.numpy as jnp
>>> jax.nn.initializers.ones(jax.random.PRNGKey(42), (3, 2), jnp.float32)
Array([[1., 1.],
       [1., 1.],
       [1., 1.]], dtype=float32)
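
Note that ones is itself an initializer rather than a builder, so it is passed as-is (no call) wherever an initializer is expected. A minimal sketch:

```python
import jax
import jax.numpy as jnp

# `ones` has the (key, shape, dtype) signature already; the key is accepted but unused.
init_fn = jax.nn.initializers.ones
scale_param = init_fn(jax.random.PRNGKey(0), (4,), jnp.float32)
print(scale_param)  # [1. 1. 1. 1.]
```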

flax.linen.initializers.ones_init()

Builds an initializer that returns a constant array full of ones.

>>> import jax, jax.numpy as jnp
>>> from flax.linen.initializers import ones_init
>>> ones_initializer = ones_init()
>>> ones_initializer(jax.random.key(42), (3, 2), jnp.float32)
Array([[1., 1.],
       [1., 1.],
       [1., 1.]], dtype=float32)

flax.linen.initializers.orthogonal(scale=1.0, column_axis=-1, dtype=<class 'jax.numpy.float64'>)

Builds an initializer that returns uniformly distributed orthogonal matrices.

If the shape is not square, the matrices will have orthonormal rows or columns depending on which side is smaller.

Parameters
  • scale – a multiplicative factor applied to the orthogonal matrix.

  • column_axis – the axis that contains the columns that should be orthogonal.

  • dtype – the default dtype of the weights.

Returns

An orthogonal initializer.

Example:

>>> import jax, jax.numpy as jnp
>>> initializer = jax.nn.initializers.orthogonal()
>>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32)  
Array([[ 3.9026976e-01,  7.2495741e-01, -5.6756169e-01],
       [ 8.8047469e-01, -4.7409311e-01, -1.3157725e-04]],
      dtype=float32)
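
For a (2, 3) shape there are fewer rows than columns, so the result has orthonormal rows, and scale simply multiplies the matrix. A sketch checking both properties with the same key:

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(42)
q = jax.nn.initializers.orthogonal()(key, (2, 3), jnp.float32)
q2 = jax.nn.initializers.orthogonal(scale=2.0)(key, (2, 3), jnp.float32)

# Orthonormal rows: q @ q.T is the 2x2 identity.
print(jnp.allclose(q @ q.T, jnp.eye(2), atol=1e-4))  # True
# scale multiplies the matrix, so (2q)(2q)^T == 4I and q2 == 2q.
print(jnp.allclose(q2 @ q2.T, 4.0 * jnp.eye(2), atol=1e-4))  # True
print(jnp.allclose(q2, 2.0 * q))  # True
```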

flax.linen.initializers.uniform(scale=0.01, dtype=<class 'jax.numpy.float64'>)

Builds an initializer that returns real uniformly-distributed random arrays.

Parameters
  • scale – optional; the upper bound of the random distribution.

  • dtype – optional; the initializer’s default dtype.

Returns

An initializer that returns arrays whose values are uniformly distributed in the range [0, scale).

>>> import jax, jax.numpy as jnp
>>> initializer = jax.nn.initializers.uniform(10.0)
>>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32)  
Array([[7.298188 , 8.691938 , 8.7230015],
       [2.0818567, 1.8662417, 5.5022564]], dtype=float32)
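
uniform is one-sided, covering [0, scale). When a symmetric range is wanted, any callable with the (key, shape, dtype) signature can serve as an initializer. In this sketch, symmetric_uniform is a hypothetical helper, not part of Flax or JAX:

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
u = jax.nn.initializers.uniform(scale=10.0)(key, (1000,), jnp.float32)
print(float(u.min()) >= 0.0, float(u.max()) < 10.0)  # True True

# Hypothetical custom initializer: uniform on [-limit, limit).
def symmetric_uniform(limit, dtype=jnp.float32):
    def init(key, shape, dtype=dtype):
        return jax.random.uniform(key, shape, dtype, minval=-limit, maxval=limit)
    return init

s = symmetric_uniform(0.5)(key, (1000,))
print(float(s.min()) >= -0.5, float(s.max()) < 0.5)  # True True
```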

flax.linen.initializers.variance_scaling(scale, mode, distribution, in_axis=-2, out_axis=-1, batch_axis=(), dtype=<class 'jax.numpy.float64'>)

Initializer that adapts its scale to the shape of the weights tensor.

With distribution="truncated_normal" or distribution="normal", samples are drawn from a (truncated) normal distribution with a mean of zero and a standard deviation (after truncation, if applicable) of \(\sqrt{\frac{scale}{n}}\), where n is:

  • the number of input units in the weights tensor, if mode="fan_in",

  • the number of output units, if mode="fan_out", or

  • the average of the numbers of input and output units, if mode="fan_avg".

This initializer can be configured with in_axis, out_axis, and batch_axis to work with general convolutional or dense layers. Axes that are not in any of those arguments are assumed to be the “receptive field” (convolution kernel spatial axes), and their combined size multiplies both the number of input units and the number of output units.

With distribution="truncated_normal", the absolute values of the samples are truncated at 2 standard deviations before scaling.

With distribution="uniform", samples are drawn from:

  • a uniform interval, if dtype is real, or

  • a uniform disk, if dtype is complex,

with a mean of zero and a standard deviation of \(\sqrt{\frac{scale}{n}}\) where n is defined above.
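
Both distribution choices can be tied back to the target standard deviation \(\sigma = \sqrt{scale/n}\) with a short worked derivation, for real dtypes. The constant 0.8796 is the standard deviation of a unit normal truncated to [-2, 2]:

\[
\frac{a^2}{3} = \frac{scale}{n} \quad\Longrightarrow\quad a = \sqrt{\frac{3\,scale}{n}} \qquad \text{(uniform on } [-a, a]\text{)}
\]

\[
\sigma_0 = \frac{1}{0.8796}\sqrt{\frac{scale}{n}}, \qquad w \sim \mathcal{N}(0, \sigma_0^2) \text{ truncated to } [-2\sigma_0, 2\sigma_0] \qquad \text{(truncated normal)}
\]

For example, glorot_uniform on a (256, 128) kernel has n = (256 + 128)/2 = 192, so \(a = \sqrt{3/192} = 0.125\).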

Parameters
  • scale – scaling factor (positive float).

  • mode – one of "fan_in", "fan_out", and "fan_avg".

  • distribution – random distribution to use. One of "truncated_normal", "normal" and "uniform".

  • in_axis – axis or sequence of axes of the input dimension in the weights array.

  • out_axis – axis or sequence of axes of the output dimension in the weights array.

  • batch_axis – axis or sequence of axes in the weight array that should be ignored.

  • dtype – the dtype of the weights.
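
The fan rules above can be reproduced by hand and compared with the empirical spread of variance_scaling samples. In this sketch expected_std is a hypothetical helper that ignores batch_axis, and the conv shape is arbitrary; "normal" is used so that no truncation correction is involved:

```python
import math

import jax
import jax.numpy as jnp

def expected_std(shape, scale, mode, in_axis=-2, out_axis=-1):
    """Hypothetical helper: sqrt(scale / n), with n computed as described above."""
    in_size, out_size = shape[in_axis], shape[out_axis]
    # All remaining axes (no batch_axis here) form the receptive field.
    receptive_field = math.prod(shape) // (in_size * out_size)
    fan_in, fan_out = in_size * receptive_field, out_size * receptive_field
    n = {"fan_in": fan_in, "fan_out": fan_out, "fan_avg": (fan_in + fan_out) / 2}[mode]
    return math.sqrt(scale / n)

shape = (5, 5, 32, 64)  # 5x5 conv kernel, 32 input and 64 output channels
results = {}
for mode in ("fan_in", "fan_out", "fan_avg"):
    init_fn = jax.nn.initializers.variance_scaling(2.0, mode, "normal")
    w = init_fn(jax.random.PRNGKey(0), shape, jnp.float32)
    results[mode] = (float(jnp.std(w)), expected_std(shape, 2.0, mode))
    print(mode, round(results[mode][0], 4), round(results[mode][1], 4))
```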

flax.linen.initializers.xavier_normal(in_axis=-2, out_axis=-1, batch_axis=(), dtype=<class 'jax.numpy.float64'>)

Builds a Glorot normal initializer (aka Xavier normal initializer). This is an alias of glorot_normal, which is why the example calls glorot_normal.

A Glorot normal initializer is a specialization of jax.nn.initializers.variance_scaling() where scale = 1.0, mode="fan_avg", and distribution="truncated_normal".

Parameters
  • in_axis – axis or sequence of axes of the input dimension in the weights array.

  • out_axis – axis or sequence of axes of the output dimension in the weights array.

  • batch_axis – axis or sequence of axes in the weight array that should be ignored.

  • dtype – the dtype of the weights.

Returns

An initializer.

Example:

>>> import jax, jax.numpy as jnp
>>> initializer = jax.nn.initializers.glorot_normal()
>>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32)  
Array([[ 0.41770416,  0.75262755,  0.7619329 ],
       [-0.5516644 , -0.6028657 ,  0.08661086]], dtype=float32)
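
In practice a kernel initializer is usually paired with a bias initializer when building parameters for several layers. A sketch using a hypothetical init_mlp helper (not a Flax API) that gives each layer its own key:

```python
import jax
import jax.numpy as jnp

def init_mlp(key, sizes, dtype=jnp.float32):
    """Hypothetical helper: one (kernel, bias) pair per consecutive pair of sizes."""
    kernel_init = jax.nn.initializers.xavier_normal()
    bias_init = jax.nn.initializers.zeros
    params = []
    keys = jax.random.split(key, len(sizes) - 1)  # one key per layer
    for k, d_in, d_out in zip(keys, sizes[:-1], sizes[1:]):
        params.append((kernel_init(k, (d_in, d_out), dtype), bias_init(k, (d_out,), dtype)))
    return params

params = init_mlp(jax.random.PRNGKey(0), [784, 256, 10])
print([(w.shape, b.shape) for w, b in params])
# [((784, 256), (256,)), ((256, 10), (10,))]
```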

flax.linen.initializers.xavier_uniform(in_axis=-2, out_axis=-1, batch_axis=(), dtype=<class 'jax.numpy.float64'>)

Builds a Glorot uniform initializer (aka Xavier uniform initializer). This is an alias of glorot_uniform, which is why the example calls glorot_uniform.

A Glorot uniform initializer is a specialization of jax.nn.initializers.variance_scaling() where scale = 1.0, mode="fan_avg", and distribution="uniform".

Parameters
  • in_axis – axis or sequence of axes of the input dimension in the weights array.

  • out_axis – axis or sequence of axes of the output dimension in the weights array.

  • batch_axis – axis or sequence of axes in the weight array that should be ignored.

  • dtype – the dtype of the weights.

Returns

An initializer.

Example:

>>> import jax, jax.numpy as jnp
>>> initializer = jax.nn.initializers.glorot_uniform()
>>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32)  
Array([[ 0.50350785,  0.8088631 ,  0.81566876],
       [-0.6393332 , -0.6865721 ,  0.11003882]], dtype=float32)
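
Since xavier_uniform is the same initializer as glorot_uniform, and both specialize variance_scaling with a uniform distribution, all three spellings should agree for a given key:

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(42)
shape = (2, 3)

a = jax.nn.initializers.xavier_uniform()(key, shape, jnp.float32)
b = jax.nn.initializers.glorot_uniform()(key, shape, jnp.float32)
c = jax.nn.initializers.variance_scaling(1.0, "fan_avg", "uniform")(key, shape, jnp.float32)

print(bool(jnp.array_equal(a, b)), bool(jnp.array_equal(a, c)))  # True True
```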

flax.linen.initializers.zeros(key, shape, dtype=<class 'jax.numpy.float64'>)

An initializer that returns a constant array full of zeros.

The key argument is ignored.

>>> import jax, jax.numpy as jnp
>>> jax.nn.initializers.zeros(jax.random.PRNGKey(42), (2, 3), jnp.float32)
Array([[0., 0., 0.],
       [0., 0., 0.]], dtype=float32)
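
Because the key argument is ignored, zeros is deterministic: different keys give identical arrays. A minimal check:

```python
import jax
import jax.numpy as jnp

a = jax.nn.initializers.zeros(jax.random.PRNGKey(0), (2, 3), jnp.float32)
b = jax.nn.initializers.zeros(jax.random.PRNGKey(1), (2, 3), jnp.float32)
print(bool(jnp.array_equal(a, b)))  # True
```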

flax.linen.initializers.zeros_init()

Builds an initializer that returns a constant array full of zeros.

>>> import jax, jax.numpy as jnp
>>> from flax.linen.initializers import zeros_init
>>> zeros_initializer = zeros_init()
>>> zeros_initializer(jax.random.key(42), (2, 3), jnp.float32)
Array([[0., 0., 0.],
       [0., 0., 0.]], dtype=float32)

Summary

  • constant(value[, dtype]) – Builds an initializer that returns arrays full of a constant value.

  • delta_orthogonal([scale, column_axis, dtype]) – Builds an initializer for delta orthogonal kernels.

  • glorot_normal([in_axis, out_axis, ...]) – Builds a Glorot normal initializer (aka Xavier normal initializer).

  • glorot_uniform([in_axis, out_axis, ...]) – Builds a Glorot uniform initializer (aka Xavier uniform initializer).

  • he_normal([in_axis, out_axis, batch_axis, dtype]) – Builds a He normal initializer (aka Kaiming normal initializer).

  • he_uniform([in_axis, out_axis, batch_axis, ...]) – Builds a He uniform initializer (aka Kaiming uniform initializer).

  • kaiming_normal([in_axis, out_axis, ...]) – Builds a He normal initializer (aka Kaiming normal initializer).

  • kaiming_uniform([in_axis, out_axis, ...]) – Builds a He uniform initializer (aka Kaiming uniform initializer).

  • lecun_normal([in_axis, out_axis, ...]) – Builds a LeCun normal initializer.

  • lecun_uniform([in_axis, out_axis, ...]) – Builds a LeCun uniform initializer.

  • normal([stddev, dtype]) – Builds an initializer that returns real normally-distributed random arrays.

  • ones(key, shape[, dtype]) – An initializer that returns a constant array full of ones.

  • ones_init() – Builds an initializer that returns a constant array full of ones.

  • orthogonal([scale, column_axis, dtype]) – Builds an initializer that returns uniformly distributed orthogonal matrices.

  • uniform([scale, dtype]) – Builds an initializer that returns real uniformly-distributed random arrays.

  • variance_scaling(scale, mode, distribution) – Initializer that adapts its scale to the shape of the weights tensor.

  • xavier_normal([in_axis, out_axis, ...]) – Builds a Glorot normal initializer (aka Xavier normal initializer).

  • xavier_uniform([in_axis, out_axis, ...]) – Builds a Glorot uniform initializer (aka Xavier uniform initializer).

  • zeros(key, shape[, dtype]) – An initializer that returns a constant array full of zeros.

  • zeros_init() – Builds an initializer that returns a constant array full of zeros.