Parallel training

  • Ensembling on multiple devices (pattern sketched below)
  • Scale up Flax Modules on multiple devices (pattern sketched below)
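
Both guides build on the same JAX primitives. As a taste of the ensembling pattern, here is a minimal, hypothetical sketch (a toy MLP Module chosen for illustration, not the guide's exact code): each device gets its own parameter set by splitting a PRNG key, and jax.pmap runs init and apply across all devices at once.

    import jax
    import jax.numpy as jnp
    import flax.linen as nn

    class MLP(nn.Module):
        @nn.compact
        def __call__(self, x):
            x = nn.relu(nn.Dense(features=32)(x))
            return nn.Dense(features=10)(x)

    model = MLP()
    num_devices = jax.local_device_count()

    # One PRNG key per device -> one independent parameter set per device.
    keys = jax.random.split(jax.random.PRNGKey(0), num_devices)
    x = jnp.ones((1, 64))  # dummy input, broadcast to every device

    # Map init/apply over the leading (device) axis of the keys/params.
    params = jax.pmap(model.init, in_axes=(0, None))(keys, x)
    logits = jax.pmap(model.apply, in_axes=(0, None))(params, x)
    print(logits.shape)  # (num_devices, 1, 10): one output per ensemble member

The scaling guide instead keeps a single model and shards its data (and, for large models, its parameters) across devices. Below is a minimal sketch of the data-parallel case, again under the same toy-MLP assumption, using jax.jit with a one-dimensional 'data' mesh.

    import numpy as np
    from jax.sharding import Mesh, NamedSharding, PartitionSpec

    # Reuses `model` from the sketch above.
    mesh = Mesh(np.array(jax.devices()), axis_names=('data',))
    xs = jnp.ones((8, 64))  # batch size must divide evenly across devices
    params = model.init(jax.random.PRNGKey(0), xs)

    # Split the batch along the 'data' mesh axis; replicate the parameters.
    xs = jax.device_put(xs, NamedSharding(mesh, PartitionSpec('data')))
    params = jax.device_put(params, NamedSharding(mesh, PartitionSpec()))

    logits = jax.jit(model.apply)(params, xs)  # runs sharded across the mesh

Both sketches fit on a single host; the guides linked above contain the full training loops and, for the scaling guide, how to partition parameters with flax.linen.with_partitioning.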

