flax.linen.initializers.zeros
flax.linen.initializers.zeros(key, shape, dtype=<class 'jax.numpy.float64'>)
An initializer that returns a constant array full of zeros.
The key argument is ignored.

>>> import jax, jax.numpy as jnp
>>> jax.nn.initializers.zeros(jax.random.PRNGKey(42), (2, 3), jnp.float32)
Array([[0., 0., 0.],
       [0., 0., 0.]], dtype=float32)
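Because the key is ignored, the initializer is deterministic: any two keys produce the same all-zeros array. The minimal sketch below checks this. Like the doctest, it calls jax.nn.initializers.zeros and assumes only that jax is installed; the variable names are illustrative.

```python
import jax
import jax.numpy as jnp

# An initializer with the (key, shape, dtype) signature shown above.
init = jax.nn.initializers.zeros

# Two different PRNG keys; the key argument is ignored, so results match.
a = init(jax.random.PRNGKey(0), (2, 3), jnp.float32)
b = init(jax.random.PRNGKey(1), (2, 3), jnp.float32)

print(a.shape, a.dtype)       # (2, 3) float32
print(bool(jnp.all(a == b)))  # True: identical despite different keys
print(bool(jnp.all(a == 0)))  # True: every entry is zero
```

Since the output never depends on the key, an initializer like this is a natural choice for parameters such as biases, where no random symmetry breaking is needed.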