Guides

  • Flax fundamentals
    • JAX 101
    • Flax Basics
    • Managing Parameters and State
    • setup vs compact
    • Dealing with Flax Module arguments
  • Data preprocessing
    • Processing the entire Dataset
  • Training techniques
    • Batch normalization
    • Dropout
    • Learning rate scheduling
    • Transfer learning
    • Save and load checkpoints
  • Parallel training
    • Ensembling on multiple devices
    • Scale up Flax Modules on multiple devices with pjit
  • Model inspection
    • Model surgery
    • Extracting intermediate values
    • Extracting gradients of intermediate values
  • Converting and upgrading
    • Convert PyTorch models to Flax
    • Migrate checkpointing to Orbax
    • Upgrading my codebase to Optax
    • Upgrading my codebase to Linen
  • The Sharp Bits
    • 🔪 flax.linen.Dropout layer and randomness

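All of the fundamentals guides above build on the same basic Linen pattern: define a Module, initialize its parameters from a PRNG key, and apply it to data. As a quick orientation before diving into the guides, here is a minimal sketch of that pattern; the module name MLP and its layer sizes are illustrative only, not taken from any particular guide.

    import jax
    import jax.numpy as jnp
    import flax.linen as nn

    class MLP(nn.Module):
        """A small two-layer perceptron; sizes are arbitrary."""
        hidden: int = 16

        @nn.compact
        def __call__(self, x):
            x = nn.Dense(self.hidden)(x)   # first dense layer
            x = nn.relu(x)                 # nonlinearity
            return nn.Dense(1)(x)          # scalar output head

    model = MLP()
    x = jnp.ones((4, 8))                           # dummy batch: 4 examples, 8 features
    params = model.init(jax.random.PRNGKey(0), x)  # initialize parameters from a PRNG key
    y = model.apply(params, x)                     # forward pass; y has shape (4, 1)

The "Flax Basics" and "Managing Parameters and State" guides expand on this init/apply split and on mutable variable collections such as batch statistics.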
