Initializers#
Initializers for Flax.
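These initializers are typically passed to Linen modules through arguments such as kernel_init and bias_init. A minimal sketch of that usage, assuming the standard flax.linen.Dense API with flax.linen imported as nn:
>>> import jax, jax.numpy as jnp
>>> import flax.linen as nn
>>> # Use he_normal for the kernel of a Dense layer with 4 output features.
>>> layer = nn.Dense(features=4, kernel_init=nn.initializers.he_normal())
>>> variables = layer.init(jax.random.PRNGKey(0), jnp.ones((1, 3)))
>>> variables['params']['kernel'].shape
(3, 4)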
- flax.linen.initializers.constant(value, dtype=<class 'jax.numpy.float64'>)#
Builds an initializer that returns arrays full of a constant value.
- Parameters
value – the constant value with which to fill the initializer.
dtype – optional; the initializer’s default dtype.
>>> import jax, jax.numpy as jnp
>>> initializer = jax.nn.initializers.constant(-7)
>>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32)
Array([[-7., -7., -7.],
       [-7., -7., -7.]], dtype=float32)
- flax.linen.initializers.delta_orthogonal(scale=1.0, column_axis=-1, dtype=<class 'jax.numpy.float64'>)#
Builds an initializer for delta orthogonal kernels.
- Parameters
scale – the upper bound of the uniform distribution.
column_axis – the axis that contains the columns that should be orthogonal.
dtype – the default dtype of the weights.
- Returns
A delta orthogonal initializer. The shape passed to the initializer must be 3D, 4D, or 5D.
Example:
>>> import jax, jax.numpy as jnp
>>> initializer = jax.nn.initializers.delta_orthogonal()
>>> initializer(jax.random.PRNGKey(42), (3, 3, 3), jnp.float32)
Array([[[ 0.        ,  0.        ,  0.        ],
        [ 0.        ,  0.        ,  0.        ],
        [ 0.        ,  0.        ,  0.        ]],
       [[ 0.27858758, -0.7949833 , -0.53887904],
        [ 0.9120717 ,  0.04322892,  0.40774566],
        [-0.30085585, -0.6050892 ,  0.73712474]],
       [[ 0.        ,  0.        ,  0.        ],
        [ 0.        ,  0.        ,  0.        ],
        [ 0.        ,  0.        ,  0.        ]]], dtype=float32)
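As the output suggests, only the central spatial slice carries weight. A short self-contained check (a sketch reusing the same key and shape as above) that the outer slices are zero and the central slice is itself orthogonal:
>>> import jax, jax.numpy as jnp
>>> k = jax.nn.initializers.delta_orthogonal()(jax.random.PRNGKey(42), (3, 3, 3), jnp.float32)
>>> # Outer spatial slices are zero; the middle slice is an orthogonal matrix.
>>> bool(jnp.allclose(k[0], 0)) and bool(jnp.allclose(k[2], 0))
True
>>> bool(jnp.allclose(k[1] @ k[1].T, jnp.eye(3), atol=1e-5))
True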
- flax.linen.initializers.glorot_normal(in_axis=-2, out_axis=-1, batch_axis=(), dtype=<class 'jax.numpy.float64'>)#
Builds a Glorot normal initializer (aka Xavier normal initializer).
A Glorot normal initializer is a specialization of jax.nn.initializers.variance_scaling() where scale = 1.0, mode="fan_avg", and distribution="truncated_normal".
- Parameters
in_axis – axis or sequence of axes of the input dimension in the weights array.
out_axis – axis or sequence of axes of the output dimension in the weights array.
batch_axis – axis or sequence of axes in the weight array that should be ignored.
dtype – the dtype of the weights.
- Returns
An initializer.
Example:
>>> import jax, jax.numpy as jnp
>>> initializer = jax.nn.initializers.glorot_normal()
>>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32)
Array([[ 0.41770416,  0.75262755,  0.7619329 ],
       [-0.5516644 , -0.6028657 ,  0.08661086]], dtype=float32)
- flax.linen.initializers.glorot_uniform(in_axis=-2, out_axis=-1, batch_axis=(), dtype=<class 'jax.numpy.float64'>)#
Builds a Glorot uniform initializer (aka Xavier uniform initializer).
A Glorot uniform initializer is a specialization of jax.nn.initializers.variance_scaling() where scale = 1.0, mode="fan_avg", and distribution="uniform".
- Parameters
in_axis – axis or sequence of axes of the input dimension in the weights array.
out_axis – axis or sequence of axes of the output dimension in the weights array.
batch_axis – axis or sequence of axes in the weight array that should be ignored.
dtype – the dtype of the weights.
- Returns
An initializer.
Example:
>>> import jax, jax.numpy as jnp
>>> initializer = jax.nn.initializers.glorot_uniform()
>>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32)
Array([[ 0.50350785,  0.8088631 ,  0.81566876],
       [-0.6393332 , -0.6865721 ,  0.11003882]], dtype=float32)
- flax.linen.initializers.he_normal(in_axis=-2, out_axis=-1, batch_axis=(), dtype=<class 'jax.numpy.float64'>)#
Builds a He normal initializer (aka Kaiming normal initializer).
A He normal initializer is a specialization of jax.nn.initializers.variance_scaling() where scale = 2.0, mode="fan_in", and distribution="truncated_normal".
- Parameters
in_axis – axis or sequence of axes of the input dimension in the weights array.
out_axis – axis or sequence of axes of the output dimension in the weights array.
batch_axis – axis or sequence of axes in the weight array that should be ignored.
dtype – the dtype of the weights.
- Returns
An initializer.
Example:
>>> import jax, jax.numpy as jnp
>>> initializer = jax.nn.initializers.he_normal()
>>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32)
Array([[ 0.6604483 ,  1.1900088 ,  1.2047218 ],
       [-0.87225807, -0.95321447,  0.1369438 ]], dtype=float32)
- flax.linen.initializers.he_uniform(in_axis=-2, out_axis=-1, batch_axis=(), dtype=<class 'jax.numpy.float64'>)#
Builds a He uniform initializer (aka Kaiming uniform initializer).
A He uniform initializer is a specialization of jax.nn.initializers.variance_scaling() where scale = 2.0, mode="fan_in", and distribution="uniform".
- Parameters
in_axis – axis or sequence of axes of the input dimension in the weights array.
out_axis – axis or sequence of axes of the output dimension in the weights array.
batch_axis – axis or sequence of axes in the weight array that should be ignored.
dtype – the dtype of the weights.
- Returns
An initializer.
Example:
>>> import jax, jax.numpy as jnp
>>> initializer = jax.nn.initializers.he_uniform()
>>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32)
Array([[ 0.79611576,  1.2789248 ,  1.2896855 ],
       [-1.0108745 , -1.0855657 ,  0.17398663]], dtype=float32)
- flax.linen.initializers.kaiming_normal(in_axis=-2, out_axis=-1, batch_axis=(), dtype=<class 'jax.numpy.float64'>)#
Builds a He normal initializer (aka Kaiming normal initializer).
A He normal initializer is a specialization of jax.nn.initializers.variance_scaling() where scale = 2.0, mode="fan_in", and distribution="truncated_normal".
- Parameters
in_axis – axis or sequence of axes of the input dimension in the weights array.
out_axis – axis or sequence of axes of the output dimension in the weights array.
batch_axis – axis or sequence of axes in the weight array that should be ignored.
dtype – the dtype of the weights.
- Returns
An initializer.
Example:
>>> import jax, jax.numpy as jnp
>>> initializer = jax.nn.initializers.he_normal()
>>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32)
Array([[ 0.6604483 ,  1.1900088 ,  1.2047218 ],
       [-0.87225807, -0.95321447,  0.1369438 ]], dtype=float32)
- flax.linen.initializers.kaiming_uniform(in_axis=-2, out_axis=-1, batch_axis=(), dtype=<class 'jax.numpy.float64'>)#
Builds a He uniform initializer (aka Kaiming uniform initializer).
A He uniform initializer is a specialization of jax.nn.initializers.variance_scaling() where scale = 2.0, mode="fan_in", and distribution="uniform".
- Parameters
in_axis – axis or sequence of axes of the input dimension in the weights array.
out_axis – axis or sequence of axes of the output dimension in the weights array.
batch_axis – axis or sequence of axes in the weight array that should be ignored.
dtype – the dtype of the weights.
- Returns
An initializer.
Example:
>>> import jax, jax.numpy as jnp
>>> initializer = jax.nn.initializers.he_uniform()
>>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32)
Array([[ 0.79611576,  1.2789248 ,  1.2896855 ],
       [-1.0108745 , -1.0855657 ,  0.17398663]], dtype=float32)
- flax.linen.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(), dtype=<class 'jax.numpy.float64'>)#
Builds a Lecun normal initializer.
A Lecun normal initializer is a specialization of jax.nn.initializers.variance_scaling() where scale = 1.0, mode="fan_in", and distribution="truncated_normal".
- Parameters
in_axis – axis or sequence of axes of the input dimension in the weights array.
out_axis – axis or sequence of axes of the output dimension in the weights array.
batch_axis – axis or sequence of axes in the weight array that should be ignored.
dtype – the dtype of the weights.
- Returns
An initializer.
Example:
>>> import jax, jax.numpy as jnp
>>> initializer = jax.nn.initializers.lecun_normal()
>>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32)
Array([[ 0.46700746,  0.8414632 ,  0.8518669 ],
       [-0.61677957, -0.67402434,  0.09683388]], dtype=float32)
- flax.linen.initializers.lecun_uniform(in_axis=-2, out_axis=-1, batch_axis=(), dtype=<class 'jax.numpy.float64'>)#
Builds a Lecun uniform initializer.
A Lecun uniform initializer is a specialization of jax.nn.initializers.variance_scaling() where scale = 1.0, mode="fan_in", and distribution="uniform".
- Parameters
in_axis – axis or sequence of axes of the input dimension in the weights array.
out_axis – axis or sequence of axes of the output dimension in the weights array.
batch_axis – axis or sequence of axes in the weight array that should be ignored.
dtype – the dtype of the weights.
- Returns
An initializer.
Example:
>>> import jax, jax.numpy as jnp
>>> initializer = jax.nn.initializers.lecun_uniform()
>>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32)
Array([[ 0.56293887,  0.90433645,  0.9119454 ],
       [-0.71479625, -0.7676109 ,  0.12302713]], dtype=float32)
- flax.linen.initializers.normal(stddev=0.01, dtype=<class 'jax.numpy.float64'>)#
Builds an initializer that returns real normally-distributed random arrays.
- Parameters
stddev – optional; the standard deviation of the distribution.
dtype – optional; the initializer’s default dtype.
- Returns
An initializer that returns arrays whose values are normally distributed with mean 0 and standard deviation stddev.
>>> import jax, jax.numpy as jnp
>>> initializer = jax.nn.initializers.normal(5.0)
>>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32)
Array([[ 3.0613258 ,  5.6129413 ,  5.6866574 ],
       [-4.063663  , -4.4520254 ,  0.63115686]], dtype=float32)
- flax.linen.initializers.ones(key, shape, dtype=<class 'jax.numpy.float64'>)#
An initializer that returns a constant array full of ones.
The key argument is ignored.
>>> import jax, jax.numpy as jnp
>>> jax.nn.initializers.ones(jax.random.PRNGKey(42), (3, 2), jnp.float32)
Array([[1., 1.],
       [1., 1.],
       [1., 1.]], dtype=float32)
- flax.linen.initializers.ones_init()#
Builds an initializer that returns a constant array full of ones.
>>> import jax, jax.numpy as jnp
>>> from flax.linen.initializers import ones_init
>>> ones_initializer = ones_init()
>>> ones_initializer(jax.random.key(42), (3, 2), jnp.float32)
Array([[1., 1.],
       [1., 1.],
       [1., 1.]], dtype=float32)
- flax.linen.initializers.orthogonal(scale=1.0, column_axis=-1, dtype=<class 'jax.numpy.float64'>)#
Builds an initializer that returns uniformly distributed orthogonal matrices.
If the shape is not square, the matrices will have orthonormal rows or columns depending on which side is smaller.
- Parameters
scale – the upper bound of the uniform distribution.
column_axis – the axis that contains the columns that should be orthogonal.
dtype – the default dtype of the weights.
- Returns
An orthogonal initializer.
Example:
>>> import jax, jax.numpy as jnp
>>> initializer = jax.nn.initializers.orthogonal()
>>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32)
Array([[ 3.9026976e-01,  7.2495741e-01, -5.6756169e-01],
       [ 8.8047469e-01, -4.7409311e-01, -1.3157725e-04]], dtype=float32)
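Since the 2x3 shape above is wider than it is tall, the rows rather than the columns come out orthonormal. A short self-contained check of that property (a sketch reusing the same key and shape):
>>> import jax, jax.numpy as jnp
>>> w = jax.nn.initializers.orthogonal()(jax.random.PRNGKey(42), (2, 3), jnp.float32)
>>> # Rows are orthonormal, so w @ w.T is (close to) the 2x2 identity.
>>> bool(jnp.allclose(w @ w.T, jnp.eye(2), atol=1e-5))
True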
- flax.linen.initializers.uniform(scale=0.01, dtype=<class 'jax.numpy.float64'>)#
Builds an initializer that returns real uniformly-distributed random arrays.
- Parameters
scale – optional; the upper bound of the random distribution.
dtype – optional; the initializer’s default dtype.
- Returns
An initializer that returns arrays whose values are uniformly distributed in the range [0, scale).
>>> import jax, jax.numpy as jnp
>>> initializer = jax.nn.initializers.uniform(10.0)
>>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32)
Array([[7.298188 , 8.691938 , 8.7230015],
       [2.0818567, 1.8662417, 5.5022564]], dtype=float32)
- flax.linen.initializers.variance_scaling(scale, mode, distribution, in_axis=-2, out_axis=-1, batch_axis=(), dtype=<class 'jax.numpy.float64'>)#
Initializer that adapts its scale to the shape of the weights tensor.
With distribution="truncated_normal" or distribution="normal", samples are drawn from a (truncated) normal distribution with a mean of zero and a standard deviation (after truncation, if applicable) of \(\sqrt{\frac{scale}{n}}\), where n is:
- the number of input units in the weights tensor, if mode="fan_in",
- the number of output units, if mode="fan_out", or
- the average of the numbers of input and output units, if mode="fan_avg".
This initializer can be configured with in_axis, out_axis, and batch_axis to work with general convolutional or dense layers; axes that are not in any of those arguments are assumed to be the "receptive field" (convolution kernel spatial axes).
With distribution="truncated_normal", the absolute values of the samples are truncated at 2 standard deviations before scaling.
With distribution="uniform", samples are drawn from:
- a uniform interval, if dtype is real, or
- a uniform disk, if dtype is complex,
with a mean of zero and a standard deviation of \(\sqrt{\frac{scale}{n}}\) where n is defined above.
- Parameters
scale – scaling factor (positive float).
mode – one of "fan_in", "fan_out", and "fan_avg".
distribution – random distribution to use. One of "truncated_normal", "normal", and "uniform".
in_axis – axis or sequence of axes of the input dimension in the weights array.
out_axis – axis or sequence of axes of the output dimension in the weights array.
batch_axis – axis or sequence of axes in the weight array that should be ignored.
dtype – the dtype of the weights.
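Unlike the other entries, this one ships no example above. A minimal doctest-style sketch follows; the arguments chosen here reproduce he_normal() per the entries above, and only the shape of the result is shown:
>>> import jax, jax.numpy as jnp
>>> # scale=2.0, mode="fan_in", distribution="truncated_normal" matches he_normal.
>>> initializer = jax.nn.initializers.variance_scaling(2.0, "fan_in", "truncated_normal")
>>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32).shape
(2, 3)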
- flax.linen.initializers.xavier_normal(in_axis=-2, out_axis=-1, batch_axis=(), dtype=<class 'jax.numpy.float64'>)#
Builds a Glorot normal initializer (aka Xavier normal initializer).
A Glorot normal initializer is a specialization of jax.nn.initializers.variance_scaling() where scale = 1.0, mode="fan_avg", and distribution="truncated_normal".
- Parameters
in_axis – axis or sequence of axes of the input dimension in the weights array.
out_axis – axis or sequence of axes of the output dimension in the weights array.
batch_axis – axis or sequence of axes in the weight array that should be ignored.
dtype – the dtype of the weights.
- Returns
An initializer.
Example:
>>> import jax, jax.numpy as jnp
>>> initializer = jax.nn.initializers.glorot_normal()
>>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32)
Array([[ 0.41770416,  0.75262755,  0.7619329 ],
       [-0.5516644 , -0.6028657 ,  0.08661086]], dtype=float32)
- flax.linen.initializers.xavier_uniform(in_axis=-2, out_axis=-1, batch_axis=(), dtype=<class 'jax.numpy.float64'>)#
Builds a Glorot uniform initializer (aka Xavier uniform initializer).
A Glorot uniform initializer is a specialization of jax.nn.initializers.variance_scaling() where scale = 1.0, mode="fan_avg", and distribution="uniform".
- Parameters
in_axis – axis or sequence of axes of the input dimension in the weights array.
out_axis – axis or sequence of axes of the output dimension in the weights array.
batch_axis – axis or sequence of axes in the weight array that should be ignored.
dtype – the dtype of the weights.
- Returns
An initializer.
Example:
>>> import jax, jax.numpy as jnp
>>> initializer = jax.nn.initializers.glorot_uniform()
>>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32)
Array([[ 0.50350785,  0.8088631 ,  0.81566876],
       [-0.6393332 , -0.6865721 ,  0.11003882]], dtype=float32)
- flax.linen.initializers.zeros(key, shape, dtype=<class 'jax.numpy.float64'>)#
An initializer that returns a constant array full of zeros.
The key argument is ignored.
>>> import jax, jax.numpy as jnp
>>> jax.nn.initializers.zeros(jax.random.PRNGKey(42), (2, 3), jnp.float32)
Array([[0., 0., 0.],
       [0., 0., 0.]], dtype=float32)
- flax.linen.initializers.zeros_init()#
Builds an initializer that returns a constant array full of zeros.
>>> import jax, jax.numpy as jnp
>>> from flax.linen.initializers import zeros_init
>>> zeros_initializer = zeros_init()
>>> zeros_initializer(jax.random.key(42), (2, 3), jnp.float32)
Array([[0., 0., 0.],
       [0., 0., 0.]], dtype=float32)
Summary
| Function | Description |
|---|---|
| constant | Builds an initializer that returns arrays full of a constant value. |
| delta_orthogonal | Builds an initializer for delta orthogonal kernels. |
| glorot_normal | Builds a Glorot normal initializer (aka Xavier normal initializer). |
| glorot_uniform | Builds a Glorot uniform initializer (aka Xavier uniform initializer). |
| he_normal | Builds a He normal initializer (aka Kaiming normal initializer). |
| he_uniform | Builds a He uniform initializer (aka Kaiming uniform initializer). |
| kaiming_normal | Builds a He normal initializer (aka Kaiming normal initializer). |
| kaiming_uniform | Builds a He uniform initializer (aka Kaiming uniform initializer). |
| lecun_normal | Builds a Lecun normal initializer. |
| lecun_uniform | Builds a Lecun uniform initializer. |
| normal | Builds an initializer that returns real normally-distributed random arrays. |
| ones | An initializer that returns a constant array full of ones. |
| ones_init | Builds an initializer that returns a constant array full of ones. |
| orthogonal | Builds an initializer that returns uniformly distributed orthogonal matrices. |
| uniform | Builds an initializer that returns real uniformly-distributed random arrays. |
| variance_scaling | Initializer that adapts its scale to the shape of the weights tensor. |
| xavier_normal | Builds a Glorot normal initializer (aka Xavier normal initializer). |
| xavier_uniform | Builds a Glorot uniform initializer (aka Xavier uniform initializer). |
| zeros | An initializer that returns a constant array full of zeros. |
| zeros_init | Builds an initializer that returns a constant array full of zeros. |