# Neural network layers and activation functions used in NNX Modules. See the NNX page for more details. Activation functions celu() elu() gelu() glu() hard_sigmoid() hard_silu() hard_swish() hard_tanh() leaky_relu() log_sigmoid() log_softmax() logsumexp() one_hot() relu() selu() sigmoid() silu() soft_sign() softmax() softplus() standardize() swish() tanh() Attention MultiHeadAttention MultiHeadAttention.__call__() MultiHeadAttention.init_cache() combine_masks() dot_product_attention() make_attention_mask() make_causal_mask() Dtypes canonicalize_dtype() promote_dtype() Initializers constant() delta_orthogonal() glorot_normal() glorot_uniform() he_normal() he_uniform() kaiming_normal() kaiming_uniform() lecun_normal() lecun_uniform() normal() truncated_normal() ones() ones_init() orthogonal() uniform() variance_scaling() xavier_normal() xavier_uniform() zeros() zeros_init() Linear Conv Conv.__call__() ConvTranspose ConvTranspose.__call__() Embed Embed.__call__() Embed.attend() Linear Linear.__call__() LinearGeneral LinearGeneral.__call__() Einsum Einsum.__call__() LoRA LoRA LoRA.__call__() LoRALinear LoRALinear.__call__() Normalization BatchNorm BatchNorm.__call__() LayerNorm LayerNorm.__call__() RMSNorm RMSNorm.__call__() GroupNorm GroupNorm.__call__() Recurrent LSTMCell LSTMCell.__call__() LSTMCell.initialize_carry() OptimizedLSTMCell OptimizedLSTMCell.__call__() OptimizedLSTMCell.initialize_carry() SimpleCell SimpleCell.__call__() SimpleCell.initialize_carry() GRUCell GRUCell.__call__() GRUCell.initialize_carry() RNN RNN.__call__() Bidirectional Bidirectional.__call__() flip_sequences() Stochastic Dropout