flax.linen.initializers.ones_init
- flax.linen.initializers.ones_init()
Builds an initializer that returns a constant array full of ones.
>>> import jax, jax.numpy as jnp
>>> from flax.linen.initializers import ones_init
>>> ones_initializer = ones_init()
>>> ones_initializer(jax.random.key(42), (3, 2), jnp.float32)
Array([[1., 1.],
       [1., 1.],
       [1., 1.]], dtype=float32)