nn
Neural network layers and activation functions used in NNX Modules.
See the NNX page for more details.
- Activation functions
- Attention
  - MultiHeadAttention
    - MultiHeadAttention.num_heads
    - MultiHeadAttention.in_features
    - MultiHeadAttention.qkv_features
    - MultiHeadAttention.out_features
    - MultiHeadAttention.dtype
    - MultiHeadAttention.param_dtype
    - MultiHeadAttention.broadcast_dropout
    - MultiHeadAttention.dropout_rate
    - MultiHeadAttention.deterministic
    - MultiHeadAttention.precision
    - MultiHeadAttention.kernel_init
    - MultiHeadAttention.out_kernel_init
    - MultiHeadAttention.bias_init
    - MultiHeadAttention.out_bias_init
    - MultiHeadAttention.use_bias
    - MultiHeadAttention.attention_fn
    - MultiHeadAttention.decode
    - MultiHeadAttention.normalize_qk
    - MultiHeadAttention.rngs
    - MultiHeadAttention.__call__()
    - MultiHeadAttention.init_cache()
  - combine_masks()
  - dot_product_attention()
  - make_attention_mask()
  - make_causal_mask()
- Dtypes
- Initializers
- Linear
  - Conv
  - ConvTranspose
    - ConvTranspose.in_features
    - ConvTranspose.out_features
    - ConvTranspose.kernel_size
    - ConvTranspose.strides
    - ConvTranspose.padding
    - ConvTranspose.kernel_dilation
    - ConvTranspose.use_bias
    - ConvTranspose.mask
    - ConvTranspose.dtype
    - ConvTranspose.param_dtype
    - ConvTranspose.precision
    - ConvTranspose.kernel_init
    - ConvTranspose.bias_init
    - ConvTranspose.transpose_kernel
    - ConvTranspose.rngs
    - ConvTranspose.__call__()
  - Embed
  - Linear
  - LinearGeneral
  - Einsum
- LoRA
- Normalization
  - BatchNorm
    - BatchNorm.num_features
    - BatchNorm.use_running_average
    - BatchNorm.axis
    - BatchNorm.momentum
    - BatchNorm.epsilon
    - BatchNorm.dtype
    - BatchNorm.param_dtype
    - BatchNorm.use_bias
    - BatchNorm.use_scale
    - BatchNorm.bias_init
    - BatchNorm.scale_init
    - BatchNorm.axis_name
    - BatchNorm.axis_index_groups
    - BatchNorm.use_fast_variance
    - BatchNorm.rngs
    - BatchNorm.__call__()
  - LayerNorm
    - LayerNorm.num_features
    - LayerNorm.epsilon
    - LayerNorm.dtype
    - LayerNorm.param_dtype
    - LayerNorm.use_bias
    - LayerNorm.use_scale
    - LayerNorm.bias_init
    - LayerNorm.scale_init
    - LayerNorm.reduction_axes
    - LayerNorm.feature_axes
    - LayerNorm.axis_name
    - LayerNorm.axis_index_groups
    - LayerNorm.use_fast_variance
    - LayerNorm.rngs
    - LayerNorm.__call__()
  - RMSNorm
  - GroupNorm
    - GroupNorm.num_features
    - GroupNorm.num_groups
    - GroupNorm.group_size
    - GroupNorm.epsilon
    - GroupNorm.dtype
    - GroupNorm.param_dtype
    - GroupNorm.use_bias
    - GroupNorm.use_scale
    - GroupNorm.bias_init
    - GroupNorm.scale_init
    - GroupNorm.reduction_axes
    - GroupNorm.axis_name
    - GroupNorm.axis_index_groups
    - GroupNorm.use_fast_variance
    - GroupNorm.rngs
    - GroupNorm.__call__()
- Recurrent
- Stochastic