Stochastic


class flax.nnx.Dropout(self, /, *args, **kwargs)

Create a dropout layer.

To enable dropout, call the train() method (or pass deterministic=False to the constructor or at call time).

To disable dropout, call the eval() method (or pass deterministic=True to the constructor or at call time).
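
The three ways of choosing the mode can be pictured with a toy class that tracks only the flag. ToyDropoutMode is hypothetical and is not Flax's implementation. Two details are assumptions of this sketch: an explicit call-time deterministic value overrides whatever the constructor, train(), or eval() stored, and the constructor default is False. In the example below, train() and eval() are called on the parent MLP and still affect its Dropout submodule.

```python
from typing import Optional


class ToyDropoutMode:
    """Hypothetical stand-in that models only how the mode flag is chosen."""

    def __init__(self, deterministic: bool = False):
        # Hypothetical default: dropout is active unless told otherwise.
        self.deterministic = deterministic

    def train(self) -> None:
        self.deterministic = False  # masking enabled

    def eval(self) -> None:
        self.deterministic = True  # masking disabled

    def resolve(self, deterministic: Optional[bool] = None) -> bool:
        # Assumption: an explicit call-time value takes precedence
        # over the stored attribute.
        return self.deterministic if deterministic is None else deterministic


layer = ToyDropoutMode()
layer.eval()
print(layer.resolve())                     # True (stored flag from eval())
print(layer.resolve(deterministic=False))  # False (call-time override)
```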

Example usage:

>>> from flax import nnx
>>> import jax.numpy as jnp

>>> class MLP(nnx.Module):
...   def __init__(self, rngs):
...     self.linear = nnx.Linear(in_features=3, out_features=4, rngs=rngs)
...     self.dropout = nnx.Dropout(0.5, rngs=rngs)
...   def __call__(self, x):
...     x = self.linear(x)
...     x = self.dropout(x)
...     return x

>>> model = MLP(rngs=nnx.Rngs(0))
>>> x = jnp.ones((1, 3))

>>> model.train() # use dropout
>>> model(x)
Array([[ 0.       ,  0.       , -1.592019 , -2.5238838]], dtype=float32)

>>> model.eval() # don't use dropout
>>> model(x)
Array([[ 1.0533503, -1.2679932, -0.7960095, -1.2619419]], dtype=float32)
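
The two outputs above are related by the 1 / (1 - rate) scaling described under deterministic. With rate 0.5, every unit that survives in training mode is exactly twice its evaluation-mode value, and dropped units are zero. The Linear weights are the same for both calls because nothing updates the model in between. The check below uses only the numbers printed in the example.

```python
import math

# Outputs copied from the doctest above (rate = 0.5).
eval_out = [1.0533503, -1.2679932, -0.7960095, -1.2619419]
train_out = [0.0, 0.0, -1.592019, -2.5238838]

rate = 0.5
scale = 1.0 / (1.0 - rate)  # 2.0 for rate = 0.5

for e, t in zip(eval_out, train_out):
    if t == 0.0:
        print(f"{e:+.7f} -> dropped")
    else:
        # Each surviving unit equals its eval-mode value times the scale.
        assert math.isclose(t, e * scale, rel_tol=1e-6)
        print(f"{e:+.7f} -> kept, scaled to {t:+.7f}")
```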

Parameters
  • rate – the dropout probability (not the keep rate).

  • broadcast_dims – dimensions along which the same dropout mask is shared.

  • deterministic – if False, the inputs are masked and the surviving values are scaled by 1 / (1 - rate); if True, no mask is applied and the inputs are returned unchanged.

  • rng_collection – the name of the RNG collection to use when requesting an RNG key.

  • rngs – the random number generator state (an nnx.Rngs in the example above) from which the dropout key is drawn.
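
The parameters above can be illustrated with a framework-free sketch: inverted dropout with keep probability 1 - rate, a mask shared along broadcast_dims, and a deterministic switch. The sketch works on a 2-D list of floats and uses Python's random module instead of JAX RNG keys, so it does not reproduce Flax's random numbers. dropout_2d and its seed argument are hypothetical, and its handling of rate == 0 and rate == 1 is this sketch's own choice rather than a statement about Flax.

```python
import random
from typing import List, Sequence

Matrix = List[List[float]]


def dropout_2d(x: Matrix, rate: float, *, deterministic: bool,
               broadcast_dims: Sequence[int] = (), seed: int = 0) -> Matrix:
    """Inverted dropout on a 2-D list of floats (hypothetical helper)."""
    if deterministic or rate == 0.0:
        return [row[:] for row in x]  # no mask: inputs returned unchanged
    if rate == 1.0:
        return [[0.0 for _ in row] for row in x]  # everything dropped
    rows, cols = len(x), len(x[0])
    # The mask has size 1 along every broadcast dimension, so one mask
    # entry is reused for all positions along that dimension.
    mask_rows = 1 if 0 in broadcast_dims else rows
    mask_cols = 1 if 1 in broadcast_dims else cols
    keep_prob = 1.0 - rate
    rng = random.Random(seed)  # stand-in for a JAX key drawn from rngs
    mask = [[rng.random() < keep_prob for _ in range(mask_cols)]
            for _ in range(mask_rows)]
    scale = 1.0 / keep_prob  # survivors are scaled by 1 / (1 - rate)
    return [[x[i][j] * scale if mask[i % mask_rows][j % mask_cols] else 0.0
             for j in range(cols)]
            for i in range(rows)]


x = [[1.0, 2.0, 3.0, 4.0] for _ in range(3)]
print(dropout_2d(x, 0.5, deterministic=True))  # unchanged copy of x
print(dropout_2d(x, 0.5, deterministic=False, broadcast_dims=(0,)))  # same column mask in every row
```

Dropped units become zero and survivors are multiplied by 1 / (1 - rate), so the expected value of each output equals its input. This is why no rescaling is needed when deterministic=True.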