Warning

This package is deprecated. See flax.linen for our new module API.

flax.nn package (deprecated)

Core: Module abstraction

class flax.nn.Module(*args, **kwargs)

get_param(name)

Retrieves a parameter within the module’s apply function.

Parameters

name – The name of the parameter.

Returns

The value of the parameter.

classmethod init(_rng, *args, name=None, **kwargs)

Initializes the module parameters.

Parameters
  • _rng – the random number generator used to initialize parameters.

  • *args – arguments passed to the module’s apply function

  • name – name of this module.

  • **kwargs – keyword arguments passed to the module’s apply function

Returns

A pair consisting of the model output and the initialized parameters

classmethod init_by_shape(_rng, input_specs, *args, name=None, **kwargs)

Initializes the module parameters.

This method initializes the module parameters without running the computation on actual input values. Initializer functions can therefore depend on the shapes of the inputs but not on their values.

Example:

import jax
import jax.numpy as jnp

input_shape = (batch_size, image_size, image_size, 3)
model_output, initial_params = model.init_by_shape(
    jax.random.PRNGKey(0), input_specs=[(input_shape, jnp.float32)])

Parameters
  • _rng – the random number generator used to initialize parameters.

  • input_specs – an iterable of (shape, dtype) pairs specifying the inputs

  • *args – arguments passed to the module’s apply function

  • name – name of this module.

  • **kwargs – keyword arguments passed to the module’s apply function

Returns

A pair consisting of the model output and the initialized parameters
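The idea of working from shapes alone can be illustrated with plain JAX, independently of the deprecated flax.nn API. The sketch below uses jax.eval_shape, which traces a function with abstract inputs and returns only the shapes and dtypes of its outputs. The dense function and the shapes are hypothetical, and this is not a claim about how init_by_shape is implemented internally.

```python
import jax
import jax.numpy as jnp

def dense(kernel, inputs):
    # Hypothetical layer: a plain matrix product along the last axis.
    return jnp.dot(inputs, kernel)

# Abstract stand-ins that carry only a shape and a dtype, no data.
inputs_spec = jax.ShapeDtypeStruct((8, 32), jnp.float32)
kernel_spec = jax.ShapeDtypeStruct((32, 4), jnp.float32)

# eval_shape traces `dense` without allocating or computing real arrays.
out_spec = jax.eval_shape(dense, kernel_spec, inputs_spec)
print(out_spec.shape, out_spec.dtype)
```

Because only shapes and dtypes flow through the trace, anything that needs the input values themselves cannot be evaluated this way, which mirrors the restriction on initializer functions stated above.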

param(name, shape, initializer)

Defines a parameter within the module’s apply function.

Parameters
  • name – The name of the parameter.

  • shape – The shape of the parameter. If None, the param can be any type.

  • initializer – An initializer function taking an RNG and the shape as arguments.

Returns

The value of the parameter.

classmethod partial(*, name=None, **kwargs)

Partially applies a module with the given arguments.

Unlike functools.partial this will return a subclass of Module.

Parameters
  • name – the name used for the module.

  • **kwargs – the keyword arguments to be applied.

Returns

A subclass of Module which partially applies the given keyword arguments.
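The difference from functools.partial can be sketched in plain Python. The Layer class below is a hypothetical stand-in, not the flax.nn.Module implementation; it only illustrates the pattern of a classmethod that returns a new subclass carrying the partially applied keyword arguments.

```python
import functools

class Layer:
    """Hypothetical stand-in for a module class (not flax.nn.Module)."""
    defaults = {}

    @classmethod
    def partial(cls, *, name=None, **kwargs):
        # Build a new subclass whose defaults include the given kwargs.
        merged = {**cls.defaults, **kwargs}
        return type(name or cls.__name__, (cls,), {"defaults": merged})

    @classmethod
    def call(cls, x, **kwargs):
        # Explicit keyword arguments override the stored defaults.
        args = {**cls.defaults, **kwargs}
        return x * args.get("scale", 1)

Doubler = Layer.partial(name="Doubler", scale=2)
print(issubclass(Doubler, Layer))  # the result is still a class
print(Doubler.call(3))

plain = functools.partial(Layer.call, scale=2)
print(isinstance(plain, type))     # a callable object, not a class
```

Since the result is itself a class, it can be partially applied again or used wherever the original class is expected, which a functools.partial object does not allow.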

classmethod shared(*, name=None, **kwargs)

Partially applies a module and shares parameters across calls.

Parameters
  • name – name of this module.

  • **kwargs – keyword arguments that should be partially applied.

Returns

A subclass of Module that shares parameters when called multiple times.

state(name, shape=None, initializer=None, collection=None)

Declare a state variable within the module’s apply function.

A state variable has an attribute named value, which is updated by assigning to it. For example:

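from flax import nn
from flax.nn import initializers

class Example(nn.Module):
  def apply(self, inputs, decay=0.9):
    # Keep an exponential moving average of the inputs as module state.
    ema = self.state('ema', inputs.shape, initializers.zeros)
    ema.value = decay * ema.value + (1 - decay) * inputs
    return inputs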

By default, Modules are stateless. See flax.nn.stateful to enable stateful computations.

Parameters
  • name – the name of the state variable.

  • shape – optional shape passed to the initializer (default: None)

  • initializer – optional initializer function taking an RNG and the shape as arguments.

  • collection – optional flax.nn.Collection used to store the state. By default the state collection passed to the nn.stateful context is used.

Returns

An instance of ModuleState.
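The update rule in the Example module above is an exponential moving average. A minimal plain-Python sketch of the same arithmetic, using scalar inputs and a zero starting value as assumptions, shows how the stored value evolves over repeated calls:

```python
decay = 0.9
ema = 0.0  # starting value, matching a zeros initializer

history = []
for inputs in [1.0, 1.0, 1.0]:
    # Same rule as in Example.apply: blend the old state with the new input.
    ema = decay * ema + (1 - decay) * inputs
    history.append(ema)

print(history)
```

With a constant input, the stored value moves a fraction (1 - decay) of the remaining distance toward that input on every step.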

Core: Additional

module(fun)

DEPRECATION WARNING: The flax.nn module is Deprecated, use flax.linen instead.

Model(module, params)

DEPRECATION WARNING: The flax.nn module is Deprecated, use flax.linen instead.

Collection([state])

DEPRECATION WARNING: The flax.nn module is Deprecated, use flax.linen instead.

capture_module_outputs()

DEPRECATION WARNING: The flax.nn module is Deprecated, use flax.linen instead.

stateful([state, mutable])

DEPRECATION WARNING: The flax.nn module is Deprecated, use flax.linen instead.

get_state()

module_method(fn)

DEPRECATION WARNING: The flax.nn module is Deprecated, use flax.linen instead.

Linear modules

Dense(inputs, features[, bias, dtype, …])

Applies a linear transformation to the inputs along the last dimension.

DenseGeneral(inputs, features[, axis, …])

Applies a linear transformation to the inputs along multiple dimensions.

Conv(inputs, features, kernel_size[, …])

Applies a convolution to the inputs.

Embed(inputs, num_embeddings, features[, …])

Embeds the inputs along the last dimension.
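As a rough illustration of what "a linear transformation along the last dimension" means for Dense, the sketch below contracts the last axis of the inputs with a kernel and adds a bias. The function name and the shapes are hypothetical, and the deprecated module's own parameter handling and initializers are not reproduced.

```python
import jax.numpy as jnp

def dense_sketch(inputs, kernel, bias):
    # Contract the last axis of `inputs` with the first axis of `kernel`.
    return jnp.dot(inputs, kernel) + bias

inputs = jnp.ones((2, 5, 3))   # (..., in_features)
kernel = jnp.ones((3, 4))      # (in_features, features)
bias = jnp.zeros((4,))

outputs = dense_sketch(inputs, kernel, bias)
print(outputs.shape)
```

All leading dimensions are preserved; only the last one changes from the input feature size to the number of output features.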

Normalization

BatchNorm(x[, batch_stats, …])

Normalizes the input using batch statistics.

LayerNorm(x[, epsilon, dtype, bias, scale, …])

Applies layer normalization on the input.

GroupNorm(x[, num_groups, group_size, …])

Applies group normalization to the input (arxiv.org/abs/1803.08494).
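For orientation, layer normalization can be sketched in a few lines of jax.numpy. This is a simplified version under stated assumptions: it normalizes over the last axis only, the epsilon value is an assumed default, and the learned scale and bias that the module can apply are omitted.

```python
import jax.numpy as jnp

def layer_norm_sketch(x, epsilon=1e-6):
    # Statistics are computed per example over the last (feature) axis.
    mean = jnp.mean(x, axis=-1, keepdims=True)
    var = jnp.var(x, axis=-1, keepdims=True)
    return (x - mean) / jnp.sqrt(var + epsilon)

x = jnp.array([[1.0, 2.0, 3.0, 4.0]])
y = layer_norm_sketch(x)
print(y)
```

Batch normalization uses the same formula but takes the statistics across the batch, and group normalization takes them over groups of channels.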

Pooling

max_pool(inputs, window_shape[, strides, …])

DEPRECATION WARNING: The flax.nn module is Deprecated, use flax.linen instead.

avg_pool(inputs, window_shape[, strides, …])

DEPRECATION WARNING: The flax.nn module is Deprecated, use flax.linen instead.

Activation functions

celu(x[, alpha])

Continuously-differentiable exponential linear unit activation.

elu(x[, alpha])

Exponential linear unit activation function.

gelu(x[, approximate])

Gaussian error linear unit activation function.

glu(x[, axis])

Gated linear unit activation function.

log_sigmoid(x)

Log-sigmoid activation function.

log_softmax(x[, axis])

Log-Softmax function.

relu(x)

Rectified linear unit activation function.

sigmoid(x)

Sigmoid activation function.

soft_sign(x)

Soft-sign activation function.

softmax(x[, axis])

Softmax function.

softplus(x)

Softplus activation function.

swish(x)

SiLU activation function.
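Functions with the same names are available in jax.nn, so a few of them can be tried without importing the deprecated package. Using jax.nn here is an assumption made for illustration; the input values are arbitrary.

```python
import jax
import jax.numpy as jnp

x = jnp.array([-1.0, 0.0, 2.0])

relu_out = jax.nn.relu(x)                   # max(x, 0)
sigmoid_out = jax.nn.sigmoid(x)             # 1 / (1 + exp(-x))
swish_out = jax.nn.swish(x)                 # x * sigmoid(x), also called SiLU
probs = jax.nn.softmax(x, axis=-1)          # non-negative, sums to 1 along axis
log_probs = jax.nn.log_softmax(x, axis=-1)  # log of the softmax, computed stably

print(relu_out, probs)
```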

Stochastic functions

make_rng()

DEPRECATION WARNING: The flax.nn module is Deprecated, use flax.linen instead.

stochastic(rng)

DEPRECATION WARNING: The flax.nn module is Deprecated, use flax.linen instead.

is_stochastic()

DEPRECATION WARNING: The flax.nn module is Deprecated, use flax.linen instead.

dropout(inputs, rate[, deterministic, rng])

DEPRECATION WARNING: The flax.nn module is Deprecated, use flax.linen instead.
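The listing gives only the signature of dropout, so the sketch below shows the common "inverted dropout" formulation with an explicit JAX random key. It is an assumption about typical behavior, not a copy of the deprecated implementation, and dropout_sketch is a hypothetical name.

```python
import jax
import jax.numpy as jnp

def dropout_sketch(inputs, rate, rng, deterministic=False):
    # In deterministic mode (e.g. evaluation) the inputs pass through unchanged.
    if deterministic or rate == 0.0:
        return inputs
    keep_prob = 1.0 - rate
    mask = jax.random.bernoulli(rng, p=keep_prob, shape=inputs.shape)
    # Rescale the kept values so the expected output equals the input.
    return jnp.where(mask, inputs / keep_prob, 0.0)

rng = jax.random.PRNGKey(0)
x = jnp.ones((4, 4))
train_out = dropout_sketch(x, rate=0.5, rng=rng)
eval_out = dropout_sketch(x, rate=0.5, rng=rng, deterministic=True)
print(train_out)
```

Passing the key explicitly keeps the function pure, so the same key always produces the same mask.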

Attention primitives

dot_product_attention(query, key, value[, …])

Computes dot-product attention given query, key, and value. DEPRECATION WARNING: The flax.nn module is Deprecated, use flax.linen instead. Learn more and find an upgrade guide at https://github.com/google/flax/blob/main/flax/linen/README.md

SelfAttention

alias of flax.nn.base.MultiHeadDotProductAttention
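Dot-product attention over query, key, and value can be sketched for a single head as follows. This uses the standard scaled formulation as an assumption; masking, dropout, biases, and multiple heads, which the library function may support, are left out, and attention_sketch is a hypothetical name.

```python
import jax
import jax.numpy as jnp

def attention_sketch(query, key, value):
    # query: (q_len, depth), key: (kv_len, depth), value: (kv_len, v_depth)
    depth = query.shape[-1]
    scores = jnp.einsum("qd,kd->qk", query, key) / jnp.sqrt(depth)
    # Each query position gets a distribution over key positions.
    weights = jax.nn.softmax(scores, axis=-1)
    return jnp.einsum("qk,kv->qv", weights, value)

query = jnp.ones((2, 4))
key = jnp.zeros((3, 4))
value = jnp.arange(15.0).reshape(3, 5)
out = attention_sketch(query, key, value)
print(out.shape)
```

With all-zero keys every score is equal, so each output row is simply the average of the value rows, which is a convenient sanity check.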

RNN primitives

LSTMCell(carry, inputs[, gate_fn, …])

A long short-term memory (LSTM) cell.

OptimizedLSTMCell(carry, inputs[, gate_fn, …])

A long short-term memory (LSTM) cell.

GRUCell(carry, inputs[, gate_fn, …])

Gated recurrent unit (GRU) cell.