Stochastic

class flax.experimental.nnx.Dropout(*args, **kwargs)

Create a dropout layer.

rate

The dropout probability, i.e. the fraction of inputs that are zeroed (_not_ the keep rate).

Type

float

broadcast_dims

Dimensions along which a single dropout mask is shared.

Type

Sequence[int]
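
To illustrate what sharing a mask along a dimension means, here is a minimal plain-JAX sketch. The helper `shared_mask` is hypothetical and is not part of nnx; it assumes the mask is sampled with size 1 along each broadcast dimension and then broadcast to the input shape.

```python
import jax
import jax.numpy as jnp


def shared_mask(key, shape, rate, broadcast_dims=()):
    """Hypothetical helper: a keep-mask shared along `broadcast_dims`."""
    mask_shape = list(shape)
    for dim in broadcast_dims:
        mask_shape[dim] = 1  # sample a single value along this axis
    keep = jax.random.bernoulli(key, p=1.0 - rate, shape=tuple(mask_shape))
    # Repeat the sampled values along the shared axes.
    return jnp.broadcast_to(keep, shape)


key = jax.random.PRNGKey(0)
mask = shared_mask(key, (4, 8), rate=0.5, broadcast_dims=(0,))

# Axis 0 shares one sample, so every row of the mask is identical.
rows_identical = bool(jnp.all(mask == mask[0]))
```

With `broadcast_dims=(0,)` on a `(4, 8)` input, the same 8 keep/drop decisions are applied to all 4 rows.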

deterministic

If false, the inputs are masked and the values that are kept are scaled by 1 / (1 - rate). If true, no mask is applied and the inputs are returned as is.

Type

bool
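
The two modes can be sketched in plain JAX as follows. `dropout_sketch` is a hypothetical stand-in written for illustration under the description above; it is not the library's implementation.

```python
import jax
import jax.numpy as jnp


def dropout_sketch(x, rate, key=None, deterministic=False):
    """Hypothetical stand-in for the behaviour described above."""
    if deterministic:
        return x  # no mask applied, inputs returned as is
    keep_prob = 1.0 - rate
    keep = jax.random.bernoulli(key, p=keep_prob, shape=x.shape)
    # Kept entries are scaled by 1 / (1 - rate); dropped entries become 0.
    return jnp.where(keep, x / keep_prob, jnp.zeros_like(x))


x = jnp.ones((2, 5))
y_eval = dropout_sketch(x, rate=0.2, deterministic=True)
y_train = dropout_sketch(x, rate=0.2, key=jax.random.PRNGKey(0))
```

Each input is kept with probability 1 - rate and scaled by 1 / (1 - rate), so the expected value of every output element equals the corresponding input.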

rng_collection

The name of the RNG collection to use when requesting an RNG key.

Type

str
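
As a rough illustration of requesting a key by collection name, the sketch below uses a small hypothetical registry, `NamedKeys`, written in plain JAX. It is an assumption made for illustration and does not reproduce the nnx RNG API; the collection names `dropout` and `params` are likewise only examples.

```python
import jax
import jax.numpy as jnp


class NamedKeys:
    """Hypothetical registry of named RNG streams (not the nnx API)."""

    def __init__(self, **seeds):
        self._keys = {name: jax.random.PRNGKey(seed) for name, seed in seeds.items()}
        self._counts = {name: 0 for name in seeds}

    def request(self, collection):
        # Derive a fresh key from the stream's base key and a running counter.
        key = jax.random.fold_in(self._keys[collection], self._counts[collection])
        self._counts[collection] += 1
        return key


keys = NamedKeys(dropout=0, params=1)
k1 = keys.request("dropout")
k2 = keys.request("dropout")

# Successive requests from the same collection yield different keys.
distinct = not bool(jnp.array_equal(k1, k2))
```

Keeping a separate named stream for dropout means its randomness can be seeded independently of other uses, such as parameter initialization.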