flax.linen.initializers.variance_scaling
- flax.linen.initializers.variance_scaling(scale, mode, distribution, in_axis=-2, out_axis=-1, batch_axis=(), dtype=<class 'jax.numpy.float64'>)
Initializer that adapts its scale to the shape of the weights tensor.
With distribution="truncated_normal" or distribution="normal", samples are drawn from a (truncated) normal distribution with a mean of zero and a standard deviation (after truncation, if applicable) of \(\sqrt{\frac{scale}{n}}\), where n is the number of input units in the weights tensor if mode="fan_in", the number of output units if mode="fan_out", or the average of the numbers of input and output units if mode="fan_avg".
This initializer can be configured with in_axis, out_axis, and batch_axis to work with general convolutional or dense layers; axes that are not in any of those arguments are assumed to be the “receptive field” (convolution kernel spatial axes), and the product of their sizes multiplies both the number of input units and the number of output units.
With distribution="truncated_normal", the absolute values of the samples are truncated at 2 standard deviations before scaling, i.e. a standard normal is truncated to [-2, 2] and then rescaled so that the standard deviation after truncation equals \(\sqrt{\frac{scale}{n}}\).
With distribution="uniform", samples are drawn from a uniform interval if dtype is real, or from a uniform disk if dtype is complex, with a mean of zero and a standard deviation of \(\sqrt{\frac{scale}{n}}\), where n is defined above.
- Parameters
scale – scaling factor (positive float).
mode – one of "fan_in", "fan_out", or "fan_avg".
distribution – random distribution to use. One of "truncated_normal", "normal", or "uniform".
in_axis – axis or sequence of axes of the input dimension in the weights array.
out_axis – axis or sequence of axes of the output dimension in the weights array.
batch_axis – axis or sequence of axes in the weights array that should be ignored when computing the numbers of input and output units.
dtype – the dtype of the weights.