# The Sharp Bits 🔪

## `flax.linen.Dropout` layer and randomness