flax.linen.initializers.ones_init

flax.linen.initializers.ones_init()

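Builds an initializer that returns a constant array of ones. The returned initializer is called with a PRNG key, a shape, and a dtype; the key is accepted for consistency with other initializers but does not affect the result.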

>>> import jax, jax.numpy as jnp
>>> from flax.linen.initializers import ones_init
>>> ones_initializer = ones_init()
>>> ones_initializer(jax.random.PRNGKey(42), (3, 2), jnp.float32)
Array([[1., 1.],
       [1., 1.],
       [1., 1.]], dtype=float32)