API Reference
- flax.config package
- flax.core.frozen_dict package
- flax.cursor package
- flax.errors package
AlreadyExistsError
ApplyModuleInvalidMethodError
ApplyScopeInvalidVariablesStructureError
ApplyScopeInvalidVariablesTypeError
AssignSubModuleError
CallCompactUnboundModuleError
CallSetupUnboundModuleError
CallUnbindOnUnboundModuleError
CursorFindError
DescriptorAttributeError
IncorrectPostInitOverrideError
InvalidCheckpointError
InvalidFilterError
InvalidInstanceModuleError
InvalidRngError
InvalidScopeError
JaxTransformError
LazyInitError
MPACheckpointingRequiredError
MPARestoreDataCorruptedError
MPARestoreTargetRequiredError
ModifyScopeVariableError
MultipleMethodsCompactError
NameInUseError
PartitioningUnspecifiedError
ReservedModuleAttributeError
ScopeCollectionNotFound
ScopeParamNotFoundError
ScopeParamShapeError
ScopeVariableNotFoundError
SetAttributeFrozenModuleError
SetAttributeInModuleSetupError
TransformTargetError
TransformedMethodReturnValueError
TraverseTreeError
- flax.jax_utils package
- flax.linen
- Module
- Init/Apply
- Layers
- Linear Modules
- Pooling
- Normalization
- Combinators
- Stochastic
- Attention
- Recurrent
RNNCellBase
LSTMCell
OptimizedLSTMCell
GRUCell
RNN
Bidirectional
- flax.linen.Dense
- flax.linen.DenseGeneral
- flax.linen.Conv
- flax.linen.ConvTranspose
- flax.linen.ConvLocal
- flax.linen.Embed
- flax.linen.BatchNorm
- flax.linen.LayerNorm
- flax.linen.GroupNorm
- flax.linen.RMSNorm
- flax.linen.Sequential
- flax.linen.Dropout
- flax.linen.SelfAttention
- flax.linen.MultiHeadDotProductAttention
- flax.linen.RNNCellBase
- flax.linen.LSTMCell
- flax.linen.OptimizedLSTMCell
- flax.linen.GRUCell
- flax.linen.RNN
- flax.linen.Bidirectional
- flax.linen.max_pool
- flax.linen.avg_pool
- flax.linen.pool
- flax.linen.dot_product_attention_weights
- flax.linen.dot_product_attention
- flax.linen.make_attention_mask
- flax.linen.make_causal_mask
- Activation functions
PReLU
celu()
elu()
gelu()
glu()
hard_sigmoid()
hard_silu()
hard_swish()
hard_tanh()
leaky_relu()
log_sigmoid()
log_softmax()
logsumexp()
one_hot()
relu()
relu6()
selu()
sigmoid()
silu()
soft_sign()
softmax()
softplus()
standardize()
swish()
tanh()
- flax.linen.activation.PReLU
- flax.linen.activation.celu
- flax.linen.activation.elu
- flax.linen.activation.gelu
- flax.linen.activation.glu
- flax.linen.activation.hard_sigmoid
- flax.linen.activation.hard_silu
- flax.linen.activation.hard_swish
- flax.linen.activation.hard_tanh
- flax.linen.activation.leaky_relu
- flax.linen.activation.log_sigmoid
- flax.linen.activation.log_softmax
- flax.linen.activation.logsumexp
- flax.linen.activation.one_hot
- flax.linen.activation.relu
- flax.linen.activation.relu6
- flax.linen.activation.selu
- flax.linen.activation.sigmoid
- flax.linen.activation.silu
- flax.linen.activation.soft_sign
- flax.linen.activation.softmax
- flax.linen.activation.softplus
- flax.linen.activation.standardize
- flax.linen.activation.swish
- flax.linen.activation.tanh
- Initializers
constant()
delta_orthogonal()
glorot_normal()
glorot_uniform()
he_normal()
he_uniform()
kaiming_normal()
kaiming_uniform()
lecun_normal()
lecun_uniform()
normal()
ones()
ones_init()
orthogonal()
uniform()
variance_scaling()
xavier_normal()
xavier_uniform()
zeros()
zeros_init()
- flax.linen.initializers.constant
- flax.linen.initializers.delta_orthogonal
- flax.linen.initializers.glorot_normal
- flax.linen.initializers.glorot_uniform
- flax.linen.initializers.he_normal
- flax.linen.initializers.he_uniform
- flax.linen.initializers.kaiming_normal
- flax.linen.initializers.kaiming_uniform
- flax.linen.initializers.lecun_normal
- flax.linen.initializers.lecun_uniform
- flax.linen.initializers.normal
- flax.linen.initializers.ones
- flax.linen.initializers.ones_init
- flax.linen.initializers.orthogonal
- flax.linen.initializers.uniform
- flax.linen.initializers.variance_scaling
- flax.linen.initializers.xavier_normal
- flax.linen.initializers.xavier_uniform
- flax.linen.initializers.zeros
- flax.linen.initializers.zeros_init
- Transformations
vmap()
scan()
jit()
remat()
remat_scan()
map_variables()
jvp()
vjp()
custom_vjp()
while_loop()
cond()
switch()
- flax.linen.vmap
- flax.linen.scan
- flax.linen.jit
- flax.linen.remat
- flax.linen.remat_scan
- flax.linen.map_variables
- flax.linen.jvp
- flax.linen.vjp
- flax.linen.custom_vjp
- flax.linen.while_loop
- flax.linen.cond
- flax.linen.switch
- Inspection
- Variable dictionary
- SPMD
Partitioned()
with_partitioning()
get_partition_spec()
get_sharding()
LogicallyPartitioned()
logical_axis_rules()
set_logical_axis_rules()
get_logical_axis_rules()
logical_to_mesh_axes()
logical_to_mesh()
logical_to_mesh_sharding()
with_logical_constraint()
with_logical_partitioning()
- flax.linen.Partitioned
- flax.linen.with_partitioning
- flax.linen.get_partition_spec
- flax.linen.get_sharding
- flax.linen.LogicallyPartitioned
- flax.linen.logical_axis_rules
- flax.linen.set_logical_axis_rules
- flax.linen.get_logical_axis_rules
- flax.linen.logical_to_mesh_axes
- flax.linen.logical_to_mesh
- flax.linen.logical_to_mesh_sharding
- flax.linen.with_logical_constraint
- flax.linen.with_logical_partitioning
- Decorators
- Profiling
- flax.serialization package
- flax.struct package
- flax.traceback_util package
- flax.training package
- flax.traverse_util package