Stochastic#

class flax.nnx.Dropout(*args, **kwargs)[source]#

Create a dropout layer.

To enable dropout, call the train() method (or pass deterministic=False in the constructor or at call time).

To disable dropout, call the eval() method (or pass deterministic=True in the constructor or at call time).

Example usage:

>>> from flax import nnx
>>> import jax.numpy as jnp

>>> class MLP(nnx.Module):
...   def __init__(self, rngs):
...     self.linear = nnx.Linear(in_features=3, out_features=4, rngs=rngs)
...     self.dropout = nnx.Dropout(0.5, rngs=rngs)
...   def __call__(self, x):
...     x = self.linear(x)
...     x = self.dropout(x)
...     return x

>>> model = MLP(rngs=nnx.Rngs(0))
>>> x = jnp.ones((1, 3))

>>> model.train() # use dropout
>>> model(x)
Array([[-0.9353421,  0.       ,  1.434417 ,  0.       ]], dtype=float32)

>>> model.eval() # don't use dropout
>>> model(x)
Array([[-0.46767104, -0.7213411 ,  0.7172085 , -0.31562346]], dtype=float32)
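The deterministic flag can also be overridden for a single call, independent of the current train/eval mode. A minimal sketch, invoking the submodules directly:

>>> model.train()
>>> h = model.linear(x)
>>> y = model.dropout(h, deterministic=True)  # no dropout for this call only
>>> bool(jnp.all(y == h))
True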

rate#

The dropout probability (not the keep rate!).

Type

float

broadcast_dims#

Dimensions that will share the same dropout mask.

Type

collections.abc.Sequence[int]
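
The dropout mask is broadcast along the listed axes, so entire slices are dropped or kept together. A minimal sketch (which rows survive depends on the seed):

>>> layer = nnx.Dropout(rate=0.5, broadcast_dims=(1,), rngs=nnx.Rngs(0))
>>> y = layer(jnp.ones((4, 3)))
>>> # each row is either all 0.0 or all 2.0 (kept values are scaled by 1 / (1 - 0.5))
>>> bool(jnp.all((y == 0.0).all(axis=1) | (y == 2.0).all(axis=1)))
True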

deterministic#

If False, the inputs are masked and the surviving values are scaled by 1 / (1 - rate); if True, no mask is applied and the inputs are returned as-is.

Type

bool
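
The 1 / (1 - rate) scaling keeps the expected value of the output equal to that of the input. A quick sanity check, as a sketch:

>>> layer = nnx.Dropout(rate=0.5, rngs=nnx.Rngs(0))
>>> y = layer(jnp.ones((1000,)))
>>> # every element is either dropped (0.0) or scaled to 1 / (1 - 0.5) == 2.0,
>>> # so the mean stays near 1.0 in expectation
>>> bool(jnp.all((y == 0.0) | (y == 2.0)))
True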

rng_collection#

The RNG collection name to use when requesting an RNG key (defaults to 'dropout').

Type

str
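
By default the key is drawn from the 'dropout' stream; a different stream name can be requested. A sketch, assuming a matching stream is provided to nnx.Rngs:

>>> layer = nnx.Dropout(rate=0.5, rng_collection='noise', rngs=nnx.Rngs(noise=0))
>>> layer(jnp.ones((1, 4))).shape
(1, 4)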

rngs#

The Rngs container used to generate the dropout mask. If None, rngs must be provided at call time.

Type

flax.nnx.rnglib.Rngs | None
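
When rngs is None at construction, keys can instead be supplied per call. A minimal sketch:

>>> layer = nnx.Dropout(rate=0.5)  # no rngs stored on the layer
>>> layer(jnp.ones((1, 4)), rngs=nnx.Rngs(dropout=0)).shape
(1, 4)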