Toggle navigation sidebar
Toggle in-page Table of Contents
Flax documentation
Quickstart
Overview
Installation
Examples
Guided Tour
JAX for the Impatient
Flax Basics
Annotated MNIST
How do I ...?
Managing Parameters and State
Ensembling on multiple devices
Learning Rate Scheduling
Extracting intermediate values
Model Surgery
Convert PyTorch Models to Flax
Upgrading my Codebase to Optax
Upgrading my Codebase to Linen
Processing the entire Dataset
Design Notes
Dealing with Module Arguments
Lifted Transformations
Linen Design Principles
The Module lifecycle
Should I use setup or nn.compact?
FLIPs
Additional material
The Flax Philosophy
How to Contribute
API reference
flax.linen package
flax.linen.enable_named_call
flax.linen.disable_named_call
flax.linen.override_named_call
flax.linen.tabulate
flax.linen.vmap
flax.linen.scan
flax.linen.jit
flax.linen.remat
flax.linen.remat_scan
flax.linen.map_variables
flax.linen.jvp
flax.linen.vjp
flax.linen.custom_vjp
flax.linen.while_loop
flax.linen.cond
flax.linen.switch
flax.linen.Dense
flax.linen.DenseGeneral
flax.linen.Conv
flax.linen.ConvTranspose
flax.linen.ConvLocal
flax.linen.Embed
flax.linen.BatchNorm
flax.linen.LayerNorm
flax.linen.GroupNorm
flax.linen.max_pool
flax.linen.avg_pool
flax.linen.pool
flax.linen.celu
flax.linen.elu
flax.linen.gelu
flax.linen.glu
flax.linen.log_sigmoid
flax.linen.log_softmax
flax.linen.relu
flax.linen.sigmoid
flax.linen.soft_sign
flax.linen.softmax
flax.linen.softplus
flax.linen.swish
flax.linen.PReLU
flax.linen.Sequential
flax.linen.dot_product_attention_weights
flax.linen.dot_product_attention
flax.linen.make_attention_mask
flax.linen.make_causal_mask
flax.linen.SelfAttention
flax.linen.MultiHeadDotProductAttention
flax.linen.Dropout
flax.linen.LSTMCell
flax.linen.OptimizedLSTMCell
flax.linen.GRUCell
flax.serialization package
flax.core.frozen_dict package
flax.struct package
flax.jax_utils package
flax.traceback_util package
flax.traverse_util package
flax.training package
flax.config package
flax.errors package
Index
_
|
A
|
B
|
C
|
D
|
E
|
F
|
G
|
H
|
I
|
J
|
K
|
L
|
M
|
N
|
O
|
P
|
Q
|
R
|
S
|
T
|
U
|
V
|
W
_
__call__() (flax.linen.BatchNorm method)
(flax.linen.Conv method)
(flax.linen.ConvLocal method)
(flax.linen.ConvTranspose method)
(flax.linen.Dense method)
(flax.linen.DenseGeneral method)
(flax.linen.Dropout method)
(flax.linen.Embed method)
(flax.linen.GroupNorm method)
(flax.linen.GRUCell method)
(flax.linen.LayerNorm method)
(flax.linen.LSTMCell method)
(flax.linen.MultiHeadDotProductAttention method)
(flax.linen.OptimizedLSTMCell method)
(flax.linen.SelfAttention method)
(flax.linen.Sequential method)
__init__() (flax.linen.PReLU method)
(flax.optim.AdaBelief method)
(flax.optim.Adadelta method)
(flax.optim.Adafactor method)
(flax.optim.Adagrad method)
(flax.optim.Adam method)
(flax.optim.GradientDescent method)
(flax.optim.LAMB method)
(flax.optim.LARS method)
(flax.optim.Momentum method)
(flax.optim.RMSProp method)
(flax.optim.WeightNorm method)
(flax.traverse_util.ModelParamTraversal method)
__setattr__() (flax.linen.Module method)
A
activation_fn (flax.linen.GRUCell attribute)
(flax.linen.LSTMCell attribute)
(flax.linen.OptimizedLSTMCell attribute)
AdaBelief (class in flax.optim)
Adadelta (class in flax.optim)
Adafactor (class in flax.optim)
Adagrad (class in flax.optim)
Adam (class in flax.optim)
apply() (flax.linen.Module method)
(in module flax.linen)
apply_gradient() (flax.optim.Optimizer method)
(flax.optim.OptimizerDef method)
apply_gradients() (flax.training.train_state.TrainState method)
apply_param_gradient() (flax.optim.OptimizerDef method)
ApplyModuleInvalidMethodError
ApplyScopeInvalidVariablesStructureError
ApplyScopeInvalidVariablesTypeError
AssignSubModuleError
attention_fn (flax.linen.MultiHeadDotProductAttention attribute)
avg_pool() (in module flax.linen)
axis (flax.linen.BatchNorm attribute)
(flax.linen.DenseGeneral attribute)
axis_index_groups (flax.linen.BatchNorm attribute)
axis_name (flax.linen.BatchNorm attribute)
B
batch_dims (flax.linen.DenseGeneral attribute)
BatchNorm (class in flax.linen)
best_metric (flax.training.early_stopping.EarlyStopping attribute)
beta1 (flax.optim.AdaBelief attribute)
beta2 (flax.optim.AdaBelief attribute)
bias_init (flax.linen.BatchNorm attribute)
(flax.linen.ConvTranspose attribute)
(flax.linen.Dense attribute)
(flax.linen.DenseGeneral attribute)
(flax.linen.GroupNorm attribute)
(flax.linen.GRUCell attribute)
(flax.linen.LayerNorm attribute)
(flax.linen.LSTMCell attribute)
(flax.linen.MultiHeadDotProductAttention attribute)
(flax.linen.OptimizedLSTMCell attribute)
bind() (flax.linen.Module method)
broadcast_dims (flax.linen.Dropout attribute)
broadcast_dropout (flax.linen.MultiHeadDotProductAttention attribute)
C
CallCompactUnboundModuleError
celu() (in module flax.linen)
compact() (in module flax.linen)
compose() (flax.traverse_util.Traversal method)
cond() (in module flax.linen)
Conv (class in flax.linen)
convert_pre_linen() (in module flax.training.checkpoints)
ConvLocal (class in flax.linen)
ConvTranspose (class in flax.linen)
copy() (flax.core.frozen_dict.FrozenDict method)
create() (flax.optim.OptimizerDef method)
(flax.training.train_state.TrainState class method)
create_constant_learning_rate_schedule() (in module flax.training.lr_schedule)
create_cosine_learning_rate_schedule() (in module flax.training.lr_schedule)
create_stepped_learning_rate_schedule() (in module flax.training.lr_schedule)
custom_vjp() (in module flax.linen)
D
dataclass() (in module flax.struct)
decode (flax.linen.MultiHeadDotProductAttention attribute)
Dense (class in flax.linen)
DenseGeneral (class in flax.linen)
deterministic (flax.linen.Dropout attribute)
(flax.linen.MultiHeadDotProductAttention attribute)
disable_named_call() (in module flax.linen)
dot_product_attention() (in module flax.linen)
dot_product_attention_weights() (in module flax.linen)
Dropout (class in flax.linen)
dropout_rate (flax.linen.MultiHeadDotProductAttention attribute)
dtype (flax.linen.BatchNorm attribute)
(flax.linen.ConvTranspose attribute)
(flax.linen.Dense attribute)
(flax.linen.DenseGeneral attribute)
(flax.linen.Embed attribute)
(flax.linen.GroupNorm attribute)
(flax.linen.GRUCell attribute)
(flax.linen.LayerNorm attribute)
(flax.linen.LSTMCell attribute)
(flax.linen.MultiHeadDotProductAttention attribute)
(flax.linen.OptimizedLSTMCell attribute)
E
each() (flax.traverse_util.Traversal method)
EarlyStopping (class in flax.training.early_stopping)
elu() (in module flax.linen)
Embed (class in flax.linen)
embedding_init (flax.linen.Embed attribute)
enable_named_call() (in module flax.linen)
eps (flax.optim.AdaBelief attribute)
epsilon (flax.linen.BatchNorm attribute)
(flax.linen.GroupNorm attribute)
(flax.linen.LayerNorm attribute)
F
feature_axes (flax.linen.LayerNorm attribute)
features (flax.linen.ConvTranspose attribute)
(flax.linen.Dense attribute)
(flax.linen.DenseGeneral attribute)
(flax.linen.Embed attribute)
filter() (flax.traverse_util.Traversal method)
flatten_dict() (in module flax.traverse_util)
flax.config
module
flax.core.variables
module
flax.errors
module
flax.jax_utils
module
flax.linen
module, [1]
flax.linen.transforms
module
flax.optim
module
flax.serialization
module
flax.struct
module
flax.traceback_util
module
flax.training.checkpoints
module
flax.training.lr_schedule
module
flax.traverse_util
module
flax_filter_frames (in module flax.config)
flax_profile (in module flax.config)
freeze() (in module flax.core.frozen_dict)
from_bytes() (in module flax.serialization)
from_state_dict() (in module flax.serialization)
FrozenDict (class in flax.core.frozen_dict)
G
gate_fn (flax.linen.GRUCell attribute)
(flax.linen.LSTMCell attribute)
(flax.linen.OptimizedLSTMCell attribute)
GDACheckpointingRequiredError
GDARestoreTargetRequiredError
gelu() (in module flax.linen)
glu() (in module flax.linen)
GradientDescent (class in flax.optim)
group_size (flax.linen.GroupNorm attribute)
GroupNorm (class in flax.linen)
GRUCell (class in flax.linen)
H
hide_flax_in_tracebacks() (in module flax.traceback_util)
I
init() (flax.linen.Module method)
(in module flax.linen)
init_param_state() (flax.optim.OptimizerDef method)
init_with_output() (flax.linen.Module method)
(in module flax.linen)
InvalidCheckpointError
InvalidFilterError
InvalidRngError
InvalidScopeError
is_initializing() (flax.linen.Module method)
iterate() (flax.traverse_util.Traversal method)
(flax.traverse_util.TraverseAttr method)
(flax.traverse_util.TraverseCompose method)
(flax.traverse_util.TraverseEach method)
(flax.traverse_util.TraverseFilter method)
(flax.traverse_util.TraverseId method)
(flax.traverse_util.TraverseItem method)
(flax.traverse_util.TraverseMerge method)
(flax.traverse_util.TraverseTree method)
J
JaxTransformError
jit() (in module flax.linen)
jvp() (in module flax.linen)
K
kernel_dilation (flax.linen.ConvTranspose attribute)
kernel_init (flax.linen.ConvTranspose attribute)
(flax.linen.Dense attribute)
(flax.linen.DenseGeneral attribute)
(flax.linen.GRUCell attribute)
(flax.linen.LSTMCell attribute)
(flax.linen.MultiHeadDotProductAttention attribute)
(flax.linen.OptimizedLSTMCell attribute)
kernel_size (flax.linen.ConvTranspose attribute)
L
LAMB (class in flax.optim)
LARS (class in flax.optim)
latest_checkpoint() (in module flax.training.checkpoints)
LayerNorm (class in flax.linen)
learning_rate (flax.optim.AdaBelief attribute)
log_sigmoid() (in module flax.linen)
log_softmax() (in module flax.linen)
LSTMCell (class in flax.linen)
M
make_attention_mask() (in module flax.linen)
make_causal_mask() (in module flax.linen)
make_rng() (flax.linen.Module method)
map_variables() (in module flax.linen)
mask (flax.linen.ConvTranspose attribute)
max_pool() (in module flax.linen)
merge() (flax.traverse_util.Traversal method)
min_delta (flax.training.early_stopping.EarlyStopping attribute)
ModelParamTraversal (class in flax.traverse_util)
ModifyScopeVariableError
module
flax.config
flax.core.variables
flax.errors
flax.jax_utils
flax.linen, [1]
flax.linen.transforms
flax.optim
flax.serialization
flax.struct
flax.traceback_util
flax.training.checkpoints
flax.training.lr_schedule
flax.traverse_util
Module (class in flax.linen)
Momentum (class in flax.optim)
momentum (flax.linen.BatchNorm attribute)
msgpack_restore() (in module flax.serialization)
msgpack_serialize() (in module flax.serialization)
MultiHeadDotProductAttention (class in flax.linen)
MultiOptimizer (class in flax.optim)
MultipleMethodsCompactError
N
NameInUseError
negative_slope_init (flax.linen.PReLU attribute)
nowrap() (in module flax.linen)
num_embeddings (flax.linen.Embed attribute)
num_groups (flax.linen.GroupNorm attribute)
num_heads (flax.linen.MultiHeadDotProductAttention attribute)
O
OptimizedLSTMCell (class in flax.linen)
Optimizer (class in flax.optim)
optimizer_def (flax.optim.Optimizer attribute)
OptimizerDef (class in flax.optim)
out_features (flax.linen.MultiHeadDotProductAttention attribute)
override_named_call() (in module flax.linen)
P
pad_shard_unpad() (in module flax.jax_utils)
padding (flax.linen.ConvTranspose attribute)
param() (flax.linen.Module method)
param_dtype (flax.linen.BatchNorm attribute)
(flax.linen.ConvTranspose attribute)
(flax.linen.Dense attribute)
(flax.linen.DenseGeneral attribute)
(flax.linen.Embed attribute)
(flax.linen.GroupNorm attribute)
(flax.linen.GRUCell attribute)
(flax.linen.LayerNorm attribute)
(flax.linen.LSTMCell attribute)
(flax.linen.MultiHeadDotProductAttention attribute)
(flax.linen.OptimizedLSTMCell attribute)
(flax.linen.PReLU attribute)
partial_eval_by_shape() (in module flax.jax_utils)
patience (flax.training.early_stopping.EarlyStopping attribute)
patience_count (flax.training.early_stopping.EarlyStopping attribute)
pmean() (in module flax.jax_utils)
pool() (in module flax.linen)
pop() (flax.core.frozen_dict.FrozenDict method)
precision (flax.linen.ConvTranspose attribute)
(flax.linen.Dense attribute)
(flax.linen.DenseGeneral attribute)
(flax.linen.MultiHeadDotProductAttention attribute)
prefetch_to_device() (in module flax.jax_utils)
PReLU (class in flax.linen)
pretty_repr() (flax.core.frozen_dict.FrozenDict method)
PyTreeNode (class in flax.struct)
Q
qkv_features (flax.linen.MultiHeadDotProductAttention attribute)
R
rate (flax.linen.Dropout attribute)
recurrent_kernel_init (flax.linen.GRUCell attribute)
(flax.linen.LSTMCell attribute)
(flax.linen.OptimizedLSTMCell attribute)
reduction_axes (flax.linen.LayerNorm attribute)
register_serialization_state() (in module flax.serialization)
relu (in module flax.linen)
remat() (in module flax.linen)
remat_scan() (in module flax.linen)
replicate() (in module flax.jax_utils)
ReservedModuleAttributeError
restore_checkpoint() (in module flax.training.checkpoints)
RMSProp (class in flax.optim)
S
save_checkpoint() (in module flax.training.checkpoints)
scale_init (flax.linen.BatchNorm attribute)
(flax.linen.GroupNorm attribute)
(flax.linen.LayerNorm attribute)
scan() (in module flax.linen)
ScopeCollectionNotFound
ScopeParamNotFoundError
ScopeParamShapeError
ScopeVariableNotFoundError
SelfAttention (class in flax.linen)
Sequential (class in flax.linen)
set() (flax.traverse_util.Traversal method)
SetAttributeFrozenModuleError
SetAttributeInModuleSetupError
setup() (flax.linen.Module method)
should_stop (flax.training.early_stopping.EarlyStopping attribute)
show_flax_in_tracebacks() (in module flax.traceback_util)
sigmoid() (in module flax.linen)
soft_sign() (in module flax.linen)
softmax() (in module flax.linen)
softplus() (in module flax.linen)
sow() (flax.linen.Module method)
state (flax.optim.Optimizer attribute)
strides (flax.linen.ConvTranspose attribute)
swish() (in module flax.linen)
switch() (in module flax.linen)
T
tabulate() (flax.linen.Module method)
(in module flax.linen)
target (flax.optim.Optimizer attribute)
to_bytes() (in module flax.serialization)
to_state_dict() (in module flax.serialization)
TrainState (class in flax.training.train_state)
TransformedMethodReturnValueError
TransformTargetError
Traversal (class in flax.traverse_util)
TraverseAttr (class in flax.traverse_util)
TraverseCompose (class in flax.traverse_util)
TraverseEach (class in flax.traverse_util)
TraverseFilter (class in flax.traverse_util)
TraverseId (class in flax.traverse_util)
TraverseItem (class in flax.traverse_util)
TraverseMerge (class in flax.traverse_util)
TraverseTree (class in flax.traverse_util)
tree() (flax.traverse_util.Traversal method)
tree_flatten() (flax.core.frozen_dict.FrozenDict method)
U
unflatten_dict() (in module flax.traverse_util)
unfreeze() (flax.core.frozen_dict.FrozenDict method)
(in module flax.core.frozen_dict)
unreplicate() (in module flax.jax_utils)
update() (flax.training.early_stopping.EarlyStopping method)
(flax.traverse_util.Traversal method)
(flax.traverse_util.TraverseAttr method)
(flax.traverse_util.TraverseCompose method)
(flax.traverse_util.TraverseEach method)
(flax.traverse_util.TraverseFilter method)
(flax.traverse_util.TraverseId method)
(flax.traverse_util.TraverseItem method)
(flax.traverse_util.TraverseMerge method)
(flax.traverse_util.TraverseTree method)
update_hyper_params() (flax.optim.MultiOptimizer method)
(flax.optim.OptimizerDef method)
use_bias (flax.linen.BatchNorm attribute)
(flax.linen.ConvTranspose attribute)
(flax.linen.Dense attribute)
(flax.linen.DenseGeneral attribute)
(flax.linen.GroupNorm attribute)
(flax.linen.LayerNorm attribute)
(flax.linen.MultiHeadDotProductAttention attribute)
use_running_average (flax.linen.BatchNorm attribute)
use_scale (flax.linen.BatchNorm attribute)
(flax.linen.GroupNorm attribute)
(flax.linen.LayerNorm attribute)
V
Variable (class in flax.core.variables)
variable() (flax.linen.Module method)
variables (flax.linen.Module property)
vjp() (in module flax.linen)
vmap() (in module flax.linen)
W
weight_decay (flax.optim.AdaBelief attribute)
WeightNorm (class in flax.optim)
while_loop() (in module flax.linen)