flax.linen.initializers.zeros_init

flax.linen.initializers.zeros_init()

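Builds an initializer that returns a constant array full of zeros. The returned function is called as `initializer(key, shape, dtype)`. Because the output is always zeros, the PRNG key does not affect the result.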

>>> import jax, jax.numpy as jnp
>>> from flax.linen.initializers import zeros_init
>>> zeros_initializer = zeros_init()
>>> zeros_initializer(jax.random.key(42), (2, 3), jnp.float32)
Array([[0., 0., 0.],
       [0., 0., 0.]], dtype=float32)