flax.linen.Dropout

class flax.linen.Dropout(rate, broadcast_dims=(), deterministic=None, rng_collection='dropout', parent=<flax.linen.module._Sentinel object>, name=None)

Create a dropout layer.

Note: When using Module.apply(), make sure to include an RNG seed named 'dropout'. Dropout isn't used during variable initialization, so Module.init() doesn't need one. Example usage:

>>> import flax.linen as nn
>>> import jax, jax.numpy as jnp

>>> class MLP(nn.Module):
...   @nn.compact
...   def __call__(self, x, train):
...     x = nn.Dense(4)(x)
...     x = nn.Dropout(0.5, deterministic=not train)(x)
...     return x

>>> model = MLP()
>>> x = jnp.ones((1, 3))
>>> variables = model.init(jax.random.key(0), x, train=False) # don't use dropout
>>> model.apply(variables, x, train=False) # don't use dropout
Array([[-0.88686204, -0.5928178 , -0.5184689 , -0.4345976 ]], dtype=float32)
>>> model.apply(variables, x, train=True, rngs={'dropout': jax.random.key(1)}) # use dropout
Array([[ 0.       , -1.1856356, -1.0369378,  0.       ]], dtype=float32)

rate

The dropout probability (not the keep rate!).

Type: float
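
A quick check of this convention (a minimal sketch reusing the imports from the example above; Dropout has no parameters, so an empty variables dict is passed to apply()):

>>> drop_all = nn.Dropout(rate=1.0)
>>> y = drop_all.apply({}, jnp.ones((2, 3)), deterministic=False,
...                    rngs={'dropout': jax.random.key(0)})
>>> bool((y == 0.0).all())  # rate is the drop probability: rate=1.0 drops everything
True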

broadcast_dims

Dimensions that will share the same dropout mask.

Type: Sequence[int]
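
For example, broadcasting the mask over the batch dimension makes every row share one mask (a minimal sketch):

>>> drop = nn.Dropout(rate=0.5, broadcast_dims=(0,))
>>> x = jnp.ones((4, 8))
>>> y = drop.apply({}, x, deterministic=False, rngs={'dropout': jax.random.key(0)})
>>> bool((y == y[0]).all())  # the same mask is applied to every row
True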

deterministic

If false, the inputs are scaled by 1 / (1 - rate) and masked; if true, no mask is applied and the inputs are returned as-is.

Type: Optional[bool]
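
deterministic can also be left unset on the module and supplied per call; with deterministic=True the layer is the identity and no RNG is needed (a minimal sketch):

>>> drop = nn.Dropout(rate=0.5)
>>> x = jnp.ones((1, 4))
>>> bool((drop.apply({}, x, deterministic=True) == x).all())  # no masking, no rng
True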

rng_collection

The RNG collection name to use when requesting an RNG key.

Type: str
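
For example, with a custom collection name the key must be supplied under that name in rngs ('my_dropout' is an illustrative name, not a Flax convention):

>>> drop = nn.Dropout(rate=0.5, rng_collection='my_dropout')
>>> y = drop.apply({}, jnp.ones((1, 4)), deterministic=False,
...                rngs={'my_dropout': jax.random.key(0)})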

__call__(inputs, deterministic=None, rng=None)

Applies a random dropout mask to the input.

Parameters
  • inputs – the inputs that should be randomly masked.

  • deterministic – if false, the inputs are scaled by 1 / (1 - rate) and masked; if true, no mask is applied and the inputs are returned as-is.

  • rng – an optional PRNGKey used as the random key; if not specified, one will be generated using make_rng with the rng_collection name.

Returns

The masked inputs, reweighted to preserve the mean.
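
For example, an explicit rng passed through apply() reaches __call__ directly, sidestepping make_rng; with rate=0.5 every kept entry of an all-ones input is rescaled to 1 / (1 - 0.5) = 2.0 (a minimal sketch):

>>> drop = nn.Dropout(rate=0.5)
>>> x = jnp.ones((1, 8))
>>> y = drop.apply({}, x, deterministic=False, rng=jax.random.key(0))
>>> bool(jnp.all((y == 0.0) | (y == 2.0)))  # kept entries scaled by 1 / (1 - rate)
True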
