flax.experimental.nnx — Experimental API. See the NNX page for more details. module Module Module.sow() Module.iter_modules() Module.eval() Module.init Module.is_initializing() Module.iter_modules() Module.set_attributes() Module.train() nn Activation functions celu() elu() gelu() glu() hard_sigmoid() hard_silu() hard_swish() hard_tanh() leaky_relu() log_sigmoid() log_softmax() logsumexp() one_hot() relu() selu() sigmoid() silu() soft_sign() softmax() softplus() standardize() swish() tanh() Attention MultiHeadAttention combine_masks() dot_product_attention() make_attention_mask() make_causal_mask() Initializers constant() delta_orthogonal() glorot_normal() glorot_uniform() he_normal() he_uniform() kaiming_normal() kaiming_uniform() lecun_normal() lecun_uniform() normal() truncated_normal() ones() ones_init() orthogonal() uniform() variance_scaling() xavier_normal() xavier_uniform() zeros() zeros_init() Linear Conv Embed Linear LinearGeneral Einsum Normalization BatchNorm LayerNorm RMSNorm Stochastic Dropout rnglib Rngs RngStream spmd get_partition_spec() get_named_sharding() with_partitioning() with_sharding_constraint() training Metrics Metric Average Accuracy MultiMetric Optimizer Optimizer transforms Remat Scan Vmap grad() jit() remat() scan() value_and_grad() vmap() variables BatchStat Cache Empty Intermediate Param Variable VariableMetadata with_metadata() helpers Dict List Sequential TrainState TrainState.replace() visualization display()