Migrating

- From PyTorch to Jax and Flax
  - Quickstart with Jax Arrays
  - Building Neural Networks
  - PyTorch layers vs NNX layers
  - Training Neural Networks
  - Porting PyTorch weights to NNX
- NNX 0.10 to NNX 0.11
  - Using Rngs in NNX Transforms
  - Loading Checkpoints with RNGs
  - Optimizer Updates
  - Pytrees containing NNX Objects
- Flax Linen to Flax NNX
  - Basic Module definition
  - Variable creation
  - Training step and compilation
  - Collections and variable types
  - Using multiple methods
  - Transformations
  - Scan over layers
- Using TrainState in Flax NNX
- Haiku to Flax NNX
  - Basic Module definition
  - Variable creation
  - Training step and compilation
  - Handling non-parameter states
  - Using multiple methods
  - Transformations
  - Scan over layers
  - Top-level Haiku functions vs top-level Flax modules