Index _ | A | B | C | D | E | F | G | H | I | J | K | L | M | N | O | P | Q | R | S | T | U | V | W | X | Z _ __call__() (flax.linen.BatchNorm method) (flax.linen.Bidirectional method) (flax.linen.Conv method) (flax.linen.ConvLocal method) (flax.linen.ConvTranspose method) (flax.linen.Dense method) (flax.linen.DenseGeneral method) (flax.linen.Dropout method) (flax.linen.Embed method) (flax.linen.GroupNorm method) (flax.linen.GRUCell method) (flax.linen.LayerNorm method) (flax.linen.LSTMCell method) (flax.linen.MultiHeadDotProductAttention method) (flax.linen.OptimizedLSTMCell method) (flax.linen.RNN method) (flax.linen.RNNCellBase method) (flax.linen.SelfAttention method) (flax.linen.Sequential method) __init__() (flax.linen.activation.PReLU method) (flax.linen.LogicallyPartitioned method) (flax.linen.Partitioned method) (flax.traverse_util.ModelParamTraversal method) __setattr__() (flax.linen.Module method) A activation_fn (flax.linen.GRUCell attribute) (flax.linen.LSTMCell attribute) (flax.linen.OptimizedLSTMCell attribute) AlreadyExistsError apply() (flax.linen.Module method) (in module flax.linen) apply_gradients() (flax.training.train_state.TrainState method) ApplyModuleInvalidMethodError ApplyScopeInvalidVariablesStructureError ApplyScopeInvalidVariablesTypeError AssignSubModuleError attention_fn (flax.linen.MultiHeadDotProductAttention attribute) avg_pool() (in module flax.linen) axis (flax.linen.BatchNorm attribute) (flax.linen.DenseGeneral attribute) axis_index_groups (flax.linen.BatchNorm attribute) (flax.linen.GroupNorm attribute) (flax.linen.LayerNorm attribute) axis_name (flax.linen.BatchNorm attribute) (flax.linen.GroupNorm attribute) (flax.linen.LayerNorm attribute) B batch_dims (flax.linen.DenseGeneral attribute) BatchNorm (class in flax.linen) best_metric (flax.training.early_stopping.EarlyStopping attribute) bias_init (flax.linen.BatchNorm attribute) (flax.linen.Conv attribute) (flax.linen.ConvLocal attribute) (flax.linen.ConvTranspose attribute) (flax.linen.Dense attribute) (flax.linen.DenseGeneral attribute) (flax.linen.GroupNorm attribute) (flax.linen.GRUCell attribute) (flax.linen.LayerNorm attribute) (flax.linen.LSTMCell attribute) (flax.linen.MultiHeadDotProductAttention attribute) (flax.linen.OptimizedLSTMCell attribute) Bidirectional (class in flax.linen) bind() (flax.linen.Module method) Bound Module broadcast_dims (flax.linen.Dropout attribute) broadcast_dropout (flax.linen.MultiHeadDotProductAttention attribute) C CallCompactUnboundModuleError CallSetupUnboundModuleError CallUnbindOnUnboundModuleError cell (flax.linen.RNN attribute) cell_size (flax.linen.RNN attribute) celu() (in module flax.linen.activation) Compact / Non-compact Module compact() (in module flax.linen) compose() (flax.traverse_util.Traversal method) cond() (in module flax.linen) constant() (in module flax.linen.initializers) Conv (class in flax.linen) convert_pre_linen() (in module flax.training.checkpoints) ConvLocal (class in flax.linen) ConvTranspose (class in flax.linen) copy() (flax.core.frozen_dict.FrozenDict method) (in module flax.core.frozen_dict) create() (flax.training.train_state.TrainState class method) create_constant_learning_rate_schedule() (in module flax.training.lr_schedule) create_cosine_learning_rate_schedule() (in module flax.training.lr_schedule) create_stepped_learning_rate_schedule() (in module flax.training.lr_schedule) custom_vjp() (in module flax.linen) D dataclass() (in module flax.struct) decode (flax.linen.MultiHeadDotProductAttention attribute) define_bool_state() (in module flax.configurations) delta_orthogonal() (in module flax.linen.initializers) Dense (class in flax.linen) DenseGeneral (class in flax.linen) DescriptorAttributeError deterministic (flax.linen.Dropout attribute) (flax.linen.MultiHeadDotProductAttention attribute) disable_named_call() (in module flax.linen) dot_product_attention() (in module flax.linen) dot_product_attention_weights() (in module flax.linen) Dropout (class in flax.linen) dropout_rate (flax.linen.MultiHeadDotProductAttention attribute) dtype (flax.linen.BatchNorm attribute) (flax.linen.Conv attribute) (flax.linen.ConvLocal attribute) (flax.linen.ConvTranspose attribute) (flax.linen.Dense attribute) (flax.linen.DenseGeneral attribute) (flax.linen.Embed attribute) (flax.linen.GroupNorm attribute) (flax.linen.GRUCell attribute) (flax.linen.LayerNorm attribute) (flax.linen.LSTMCell attribute) (flax.linen.MultiHeadDotProductAttention attribute) (flax.linen.OptimizedLSTMCell attribute) E each() (flax.traverse_util.Traversal method) EarlyStopping (class in flax.training.early_stopping) elu() (in module flax.linen.activation) Embed (class in flax.linen) embedding_init (flax.linen.Embed attribute) enable_named_call() (in module flax.linen) epsilon (flax.linen.BatchNorm attribute) (flax.linen.GroupNorm attribute) (flax.linen.LayerNorm attribute) F feature_axes (flax.linen.LayerNorm attribute) feature_group_count (flax.linen.Conv attribute) (flax.linen.ConvLocal attribute) features (flax.linen.Conv attribute) (flax.linen.ConvLocal attribute) (flax.linen.ConvTranspose attribute) (flax.linen.Dense attribute) (flax.linen.DenseGeneral attribute) (flax.linen.Embed attribute) filter() (flax.traverse_util.Traversal method) flatten_dict() (in module flax.traverse_util) flax.configurations module flax.core.variables module flax.errors module flax.jax_utils module flax.linen module, [1] flax.linen.activation module flax.linen.initializers module flax.linen.spmd module flax.linen.transforms module flax.serialization module flax.struct module flax.traceback_util module flax.training.checkpoints module flax.training.lr_schedule module flax.traverse_util module Folding in freeze() (in module flax.core.frozen_dict) from_bytes() (in module flax.serialization) from_state_dict() (in module flax.serialization) FrozenDict (class in flax.core.frozen_dict) Functional core G gate_fn (flax.linen.GRUCell attribute) (flax.linen.LSTMCell attribute) (flax.linen.OptimizedLSTMCell attribute) gelu() (in module flax.linen.activation) get_logical_axis_rules() (in module flax.linen) get_metrics() (in module flax.training.common_utils) get_partition_spec() (in module flax.linen) get_sharding() (in module flax.linen) glorot_normal() (in module flax.linen.initializers) glorot_uniform() (in module flax.linen.initializers) glu() (in module flax.linen.activation) group_size (flax.linen.GroupNorm attribute) GroupNorm (class in flax.linen) GRUCell (class in flax.linen) H hard_sigmoid() (in module flax.linen.activation) hard_silu() (in module flax.linen.activation) hard_swish() (in module flax.linen.activation) hard_tanh() (in module flax.linen.activation) he_normal() (in module flax.linen.initializers) he_uniform() (in module flax.linen.initializers) hide_flax_in_tracebacks() (in module flax.traceback_util) I IncorrectPostInitOverrideError init() (flax.linen.Module method) (in module flax.linen) init_with_output() (flax.linen.Module method) (in module flax.linen) input_dilation (flax.linen.Conv attribute) (flax.linen.ConvLocal attribute) InvalidCheckpointError InvalidFilterError InvalidInstanceModuleError InvalidRngError InvalidScopeError is_initializing() (flax.linen.Module method) iterate() (flax.traverse_util.Traversal method) (flax.traverse_util.TraverseAttr method) (flax.traverse_util.TraverseCompose method) (flax.traverse_util.TraverseEach method) (flax.traverse_util.TraverseFilter method) (flax.traverse_util.TraverseId method) (flax.traverse_util.TraverseItem method) (flax.traverse_util.TraverseMerge method) (flax.traverse_util.TraverseTree method) J JaxTransformError jit() (in module flax.linen) jvp() (in module flax.linen) K kaiming_normal() (in module flax.linen.initializers) kaiming_uniform() (in module flax.linen.initializers) keep_order (flax.linen.RNN attribute) kernel_dilation (flax.linen.Conv attribute) (flax.linen.ConvLocal attribute) (flax.linen.ConvTranspose attribute) kernel_init (flax.linen.Conv attribute) (flax.linen.ConvLocal attribute) (flax.linen.ConvTranspose attribute) (flax.linen.Dense attribute) (flax.linen.DenseGeneral attribute) (flax.linen.GRUCell attribute) (flax.linen.LSTMCell attribute) (flax.linen.MultiHeadDotProductAttention attribute) (flax.linen.OptimizedLSTMCell attribute) kernel_size (flax.linen.Conv attribute) (flax.linen.ConvLocal attribute) (flax.linen.ConvTranspose attribute) L latest_checkpoint() (in module flax.training.checkpoints) LayerNorm (class in flax.linen) Lazy initialization LazyInitError leaky_relu() (in module flax.linen.activation) lecun_normal() (in module flax.linen.initializers) lecun_uniform() (in module flax.linen.initializers) Lifted transformation log_sigmoid() (in module flax.linen.activation) log_softmax() (in module flax.linen.activation) logical_axis_rules() (in module flax.linen) logical_to_mesh() (in module flax.linen) logical_to_mesh_axes() (in module flax.linen) logical_to_mesh_sharding() (in module flax.linen) LogicallyPartitioned (class in flax.linen) logsumexp() (in module flax.linen.activation) LSTMCell (class in flax.linen) M make_attention_mask() (in module flax.linen) make_causal_mask() (in module flax.linen) make_rng() (flax.linen.Module method) map_variables() (in module flax.linen) mask (flax.linen.Conv attribute) (flax.linen.ConvLocal attribute) (flax.linen.ConvTranspose attribute) max_pool() (in module flax.linen) merge() (flax.traverse_util.Traversal method) min_delta (flax.training.early_stopping.EarlyStopping attribute) ModelParamTraversal (class in flax.traverse_util) ModifyScopeVariableError Module module (class in flax.linen) flax.configurations flax.core.variables flax.errors flax.jax_utils flax.linen, [1] flax.linen.activation flax.linen.initializers flax.linen.spmd flax.linen.transforms flax.serialization flax.struct flax.traceback_util flax.training.checkpoints flax.training.lr_schedule flax.traverse_util momentum (flax.linen.BatchNorm attribute) MPACheckpointingRequiredError MPARestoreDataCorruptedError MPARestoreTargetRequiredError MPARestoreTypeNotMatchError msgpack_restore() (in module flax.serialization) msgpack_serialize() (in module flax.serialization) MultiHeadDotProductAttention (class in flax.linen) MultipleMethodsCompactError N NameInUseError negative_slope_init (flax.linen.activation.PReLU attribute) normal() (in module flax.linen.initializers) nowrap() (in module flax.linen) num_embeddings (flax.linen.Embed attribute) num_groups (flax.linen.GroupNorm attribute) num_heads (flax.linen.MultiHeadDotProductAttention attribute) O one_hot() (in module flax.linen.activation) onehot() (in module flax.training.common_utils) ones() (in module flax.linen.initializers) ones_init() (in module flax.linen.initializers) OptimizedLSTMCell (class in flax.linen) orthogonal() (in module flax.linen.initializers) out_features (flax.linen.MultiHeadDotProductAttention attribute) override_named_call() (in module flax.linen) P pad_shard_unpad() (in module flax.jax_utils) padding (flax.linen.Conv attribute) (flax.linen.ConvLocal attribute) (flax.linen.ConvTranspose attribute) param() (flax.linen.Module method) param_dtype (flax.linen.activation.PReLU attribute) (flax.linen.BatchNorm attribute) (flax.linen.Conv attribute) (flax.linen.ConvLocal attribute) (flax.linen.ConvTranspose attribute) (flax.linen.Dense attribute) (flax.linen.DenseGeneral attribute) (flax.linen.Embed attribute) (flax.linen.GroupNorm attribute) (flax.linen.GRUCell attribute) (flax.linen.LayerNorm attribute) (flax.linen.LSTMCell attribute) (flax.linen.MultiHeadDotProductAttention attribute) (flax.linen.OptimizedLSTMCell attribute) Params / parameters partial_eval_by_shape() (in module flax.jax_utils) Partitioned (class in flax.linen) PartitioningUnspecifiedError path_aware_map() (in module flax.traverse_util) patience (flax.training.early_stopping.EarlyStopping attribute) patience_count (flax.training.early_stopping.EarlyStopping attribute) perturb() (flax.linen.Module method) pmean() (in module flax.jax_utils) pool() (in module flax.linen) pop() (flax.core.frozen_dict.FrozenDict method) (in module flax.core.frozen_dict) precision (flax.linen.Conv attribute) (flax.linen.ConvLocal attribute) (flax.linen.ConvTranspose attribute) (flax.linen.Dense attribute) (flax.linen.DenseGeneral attribute) (flax.linen.MultiHeadDotProductAttention attribute) prefetch_to_device() (in module flax.jax_utils) PReLU (class in flax.linen.activation) pretty_repr() (flax.core.frozen_dict.FrozenDict method) (in module flax.core.frozen_dict) PyTreeNode (class in flax.struct) Q qkv_features (flax.linen.MultiHeadDotProductAttention attribute) R rate (flax.linen.Dropout attribute) recurrent_kernel_init (flax.linen.GRUCell attribute) (flax.linen.LSTMCell attribute) (flax.linen.OptimizedLSTMCell attribute) reduction_axes (flax.linen.LayerNorm attribute) register_serialization_state() (in module flax.serialization) relu (in module flax.linen.activation) relu6 (in module flax.linen.activation) remat() (in module flax.linen) remat_scan() (in module flax.linen) replicate() (in module flax.jax_utils) ReservedModuleAttributeError restore_checkpoint() (in module flax.training.checkpoints) return_carry (flax.linen.RNN attribute) reverse (flax.linen.RNN attribute) RNG sequences rng_collection (flax.linen.Dropout attribute) RNN (class in flax.linen) RNNCellBase (class in flax.linen) S save_checkpoint() (in module flax.training.checkpoints) save_checkpoint_multiprocess() (in module flax.training.checkpoints) scale_init (flax.linen.BatchNorm attribute) (flax.linen.GroupNorm attribute) (flax.linen.LayerNorm attribute) scan() (in module flax.linen) Scope ScopeCollectionNotFound ScopeParamNotFoundError ScopeParamShapeError ScopeVariableNotFoundError SelfAttention (class in flax.linen) selu() (in module flax.linen.activation) Sequential (class in flax.linen) set() (flax.traverse_util.Traversal method) set_logical_axis_rules() (in module flax.linen) SetAttributeFrozenModuleError SetAttributeInModuleSetupError setup() (flax.linen.Module method) Shape inference shard() (in module flax.training.common_utils) shard_prng_key() (in module flax.training.common_utils) should_stop (flax.training.early_stopping.EarlyStopping attribute) show_flax_in_tracebacks() (in module flax.traceback_util) sigmoid() (in module flax.linen.activation) silu() (in module flax.linen.activation) soft_sign() (in module flax.linen.activation) softmax() (in module flax.linen.activation) softplus() (in module flax.linen.activation) sow() (flax.linen.Module method) split_rngs (flax.linen.RNN attribute) stack_forest() (in module flax.training.common_utils) standardize() (in module flax.linen.activation) static_bool_env() (in module flax.configurations) strides (flax.linen.Conv attribute) (flax.linen.ConvLocal attribute) (flax.linen.ConvTranspose attribute) swish() (in module flax.linen.activation) switch() (in module flax.linen) T tabulate() (flax.linen.Module method) (in module flax.linen) tanh() (in module flax.linen.activation) time_major (flax.linen.RNN attribute) to_bytes() (in module flax.serialization) to_state_dict() (in module flax.serialization) TrainState (class in flax.training.train_state) TransformedMethodReturnValueError TransformTargetError transpose_kernel (flax.linen.ConvTranspose attribute) Traversal (class in flax.traverse_util) TraverseAttr (class in flax.traverse_util) TraverseCompose (class in flax.traverse_util) TraverseEach (class in flax.traverse_util) TraverseFilter (class in flax.traverse_util) TraverseId (class in flax.traverse_util) TraverseItem (class in flax.traverse_util) TraverseMerge (class in flax.traverse_util) TraverseTree (class in flax.traverse_util) tree() (flax.traverse_util.Traversal method) U unbind() (flax.linen.Module method) unflatten_dict() (in module flax.traverse_util) unfreeze() (flax.core.frozen_dict.FrozenDict method) (in module flax.core.frozen_dict) uniform() (in module flax.linen.initializers) unreplicate() (in module flax.jax_utils) unroll (flax.linen.RNN attribute) update() (flax.training.early_stopping.EarlyStopping method) (flax.traverse_util.Traversal method) (flax.traverse_util.TraverseAttr method) (flax.traverse_util.TraverseCompose method) (flax.traverse_util.TraverseEach method) (flax.traverse_util.TraverseFilter method) (flax.traverse_util.TraverseId method) (flax.traverse_util.TraverseItem method) (flax.traverse_util.TraverseMerge method) (flax.traverse_util.TraverseTree method) use_bias (flax.linen.BatchNorm attribute) (flax.linen.Conv attribute) (flax.linen.ConvLocal attribute) (flax.linen.ConvTranspose attribute) (flax.linen.Dense attribute) (flax.linen.DenseGeneral attribute) (flax.linen.GroupNorm attribute) (flax.linen.LayerNorm attribute) (flax.linen.MultiHeadDotProductAttention attribute) use_running_average (flax.linen.BatchNorm attribute) use_scale (flax.linen.BatchNorm attribute) (flax.linen.GroupNorm attribute) (flax.linen.LayerNorm attribute) V Variable (class in flax.core.variables) Variable collections Variable dictionary variable() (flax.linen.Module method) variable_axes (flax.linen.RNN attribute) variable_broadcast (flax.linen.RNN attribute) variable_carry (flax.linen.RNN attribute) variables (flax.linen.Module property) variance_scaling() (in module flax.linen.initializers) vjp() (in module flax.linen) vmap() (in module flax.linen) W while_loop() (in module flax.linen) with_logical_constraint() (in module flax.linen) with_logical_partitioning() (in module flax.linen) with_partitioning() (in module flax.linen) X xavier_normal() (in module flax.linen.initializers) xavier_uniform() (in module flax.linen.initializers) Z zeros() (in module flax.linen.initializers) zeros_init() (in module flax.linen.initializers)