Parallel training
Ensembling on multiple devices
Scale up Flax Modules on multiple devices
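The two guides above cover running many model replicas at once and sharding a single model across devices. As a taste of the first topic, here is a minimal ensembling sketch in plain JAX; the tiny linear "model", its parameter shapes, and the ensemble size are all illustrative assumptions, not code from the guides:

```python
import jax
import jax.numpy as jnp

# Hypothetical tiny "model": params are a (w, b) pair. Illustrative only.
def predict(params, x):
    w, b = params
    return x @ w + b

def init(key):
    kw, kb = jax.random.split(key)
    return (jax.random.normal(kw, (3, 2)), jax.random.normal(kb, (2,)))

# Initialize an ensemble of 4 members from different random seeds.
# vmap stacks each parameter leaf along a new leading "ensemble" axis.
keys = jax.random.split(jax.random.PRNGKey(0), 4)
ensemble_params = jax.vmap(init)(keys)

x = jnp.ones((5, 3))
# Map the forward pass over the ensemble axis. On a host with multiple
# accelerators, jax.pmap with the same in_axes places one member per device.
preds = jax.vmap(predict, in_axes=(0, None))(ensemble_params, x)
mean_pred = preds.mean(axis=0)  # combine the ensemble by averaging
print(preds.shape, mean_pred.shape)  # (4, 5, 2) (5, 2)
```

The sketch uses `jax.vmap` so it also runs on a single device; the ensembling guide covers the `jax.pmap` variant that assigns one ensemble member to each physical device.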