Examples

  • Core examples
    • Image classification
    • Reinforcement learning
    • Natural language processing
    • Generative models
    • Graph modeling
    • Contributing to core Flax examples
  • Google Research examples
    • Attention
    • Computer vision
    • Diffusion
    • Domain adaptation
    • Generalization
    • Meta learning
    • Model efficiency
    • Neural rendering / NeRF
    • Optimization
    • Quantization
    • Reinforcement learning
    • Sequence models / Model parallelism
    • Simulation
  • Repositories that use Flax
    • 🤗 Hugging Face
    • 🥑 DALLE Mini
    • Scenic
    • Big Vision
    • T5X
  • Community examples
    • Models
    • Examples
    • Tutorials
    • Contributing policy
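
All of these examples are built on the same Flax Linen API. As a minimal, hypothetical sketch of the pattern they share (the MLP module, layer sizes, and input shape below are illustrative, not taken from any specific example), defining, initializing, and applying a model looks like this:

    import jax
    import jax.numpy as jnp
    import flax.linen as nn

    class MLP(nn.Module):
        """A hypothetical two-layer perceptron, for illustration only."""
        features: int = 10  # number of output classes (illustrative)

        @nn.compact
        def __call__(self, x):
            x = nn.Dense(features=64)(x)  # hidden layer
            x = nn.relu(x)
            return nn.Dense(features=self.features)(x)  # logits

    model = MLP()
    x = jnp.ones((1, 28 * 28))                     # dummy flattened-image batch
    params = model.init(jax.random.PRNGKey(0), x)  # initialize parameters
    logits = model.apply(params, x)                # forward pass

The core examples expand this pattern with data pipelines, optimizers, and training loops; the repository and community listings above show the same API used at much larger scale.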
