Stochastic
- class flax.nnx.Dropout(rate, broadcast_dims=(), deterministic=False, rng_collection='dropout', rngs=None)[source]
Create a dropout layer.

To use dropout, call the train() method (or pass deterministic=False in the constructor or at call time). To disable dropout, call the eval() method (or pass deterministic=True in the constructor or at call time).

Example usage:
>>> from flax import nnx
>>> import jax.numpy as jnp

>>> class MLP(nnx.Module):
...   def __init__(self, rngs):
...     self.linear = nnx.Linear(in_features=3, out_features=4, rngs=rngs)
...     self.dropout = nnx.Dropout(0.5, rngs=rngs)
...   def __call__(self, x):
...     x = self.linear(x)
...     x = self.dropout(x)
...     return x

>>> model = MLP(rngs=nnx.Rngs(0))
>>> x = jnp.ones((1, 3))

>>> model.train()  # use dropout
>>> model(x)
Array([[ 0.       ,  0.       , -1.592019 , -2.5238838]], dtype=float32)

>>> model.eval()  # don't use dropout
>>> model(x)
Array([[ 1.0533503, -1.2679932, -0.7960095, -1.2619419]], dtype=float32)
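Since deterministic can also be passed at call time, as noted above, dropout can be switched off for a single call without touching the module's train/eval state. A minimal sketch (the layer, seed, and shapes here are illustrative, not taken from the example above):

>>> layer = nnx.Dropout(rate=0.5, rngs=nnx.Rngs(0))
>>> x = jnp.ones((1, 4))
>>> y = layer(x, deterministic=True)  # override for this call only: no mask, no scaling
>>> bool((y == x).all())
True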
- Parameters
  - rate – the dropout probability (_not_ the keep rate!).
  - broadcast_dims – dimensions that will share the same dropout mask (see the sketch after this list).
  - deterministic – if false, the inputs are scaled by 1 / (1 - rate) and masked; if true, no mask is applied and the inputs are returned as-is.
  - rng_collection – the rng collection name to use when requesting an rng key.
  - rngs – rng key.
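The deterministic scaling and the mask sharing controlled by broadcast_dims can both be checked numerically. A short sketch, assuming mask sharing means the mask has size 1 along each broadcast dimension (seeds and shapes are arbitrary):

>>> # With rate=0.5 and all-ones input, every output entry is either
>>> # 0. (dropped) or 1 / (1 - 0.5) == 2. (kept and rescaled).
>>> layer = nnx.Dropout(rate=0.5, rngs=nnx.Rngs(0))
>>> y = layer(jnp.ones((2, 8)))
>>> bool(((y == 0.) | (y == 2.)).all())
True

>>> # broadcast_dims=(0,) reuses one mask across axis 0, so every row
>>> # drops the same columns.
>>> shared = nnx.Dropout(rate=0.5, broadcast_dims=(0,), rngs=nnx.Rngs(1))
>>> y = shared(jnp.ones((4, 8)))
>>> bool((y == y[0]).all())
True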