flax.linen.Dropout
- class flax.linen.Dropout(rate, broadcast_dims=(), deterministic=None, rng_collection='dropout', parent=<flax.linen.module._Sentinel object>, name=None)
Create a dropout layer.
Note: When using `Module.apply()`, make sure to include an RNG seed named 'dropout'. For example: `model.apply({'params': params}, inputs=inputs, train=True, rngs={'dropout': dropout_rng})`
- rate
the dropout probability (_not_ the keep rate!).
- Type
float
- broadcast_dims
dimensions that will share the same dropout mask
- Type
Sequence[int]
- deterministic
if false, the inputs are scaled by 1 / (1 - rate) and masked, whereas if true, no mask is applied and the inputs are returned as-is.
- Type
Optional[bool]
- rng_collection
the rng collection name to use when requesting an rng key.
- Type
str
- __call__(inputs, deterministic=None, rng=None)
Applies a random dropout mask to the input.
- Parameters
inputs – the inputs that should be randomly masked.
deterministic – if false, the inputs are scaled by 1 / (1 - rate) and masked, whereas if true, no mask is applied and the inputs are returned as-is.
rng – an optional PRNGKey used as the random key; if not specified, one is generated using `make_rng` with the `rng_collection` name.
- Returns
The masked inputs, reweighted to preserve the mean.