flax.linen package

Linen is the Flax Module system. Read more about our design goals in the Linen README.

Module

class flax.linen.Module

Base class for all neural network modules. Layers and models should subclass this class.

All Flax Modules are Python 3.7 dataclasses. Since dataclasses take over __init__, you should instead override setup(), which is automatically called to initialize the module.

Modules can contain submodules, and in this way can be nested in a tree structure. Submodules can be assigned as regular attributes inside the setup() method.

You can define arbitrary “forward pass” methods on your Module subclass. While no methods are special-cased, __call__ is a popular choice because it allows you to use module instances as if they are functions:

from typing import Tuple

from flax import linen as nn

class Module(nn.Module):
  features: Tuple[int, ...] = (16, 4)

  def setup(self):
    self.dense1 = nn.Dense(self.features[0])
    self.dense2 = nn.Dense(self.features[1])

  def __call__(self, x):
    return self.dense2(nn.relu(self.dense1(x)))

Optionally, for more concise module implementations where submodule definitions are co-located with their usage, you can use the compact() decorator.

__setattr__(name, val)

Sets an attribute on this Module.

We overload setattr solely to support pythonic naming via assignment of submodules in the special setup() function:

self.submodule_name = MyModule(...)

We also support lists and other general pytrees, e.g.:

self.submodules = [MyModule0(...), MyModule1(...), ...]
Parameters
  • name – Attribute to set.

  • val – Value of the attribute.

apply(variables, *args, rngs=None, method=None, mutable=False, capture_intermediates=False, **kwargs)

Applies a module method to variables and returns output and modified variables.

Note that method should be set to call apply on a class method other than __call__. For instance, if a Transformer module has a method called encode, the following calls apply on that method:

model = Transformer()
encoded = model.apply({'params': params}, x, method=Transformer.encode)

If a bound method such as model.encode is provided, its underlying unbound function is used. For instance, the example below is equivalent to the one above:

encoded = model.apply({'params': params}, x, method=model.encode)

Note that method can also be a function that is not defined in Transformer. In that case, the function's first argument receives an instance of the Module class:

def other_fn(instance, x):
  # `instance` is the bound Transformer, so its methods and submodules can be used here.
  return instance.encode(x)

model.apply({'params': params}, x, method=other_fn)
Parameters
  • variables – A dictionary containing variables keyed by variable collections. See flax.core.variables for more details about variables.

  • *args – Positional arguments passed to the specified apply method.

  • rngs – a dict of PRNGKeys to initialize the PRNG sequences. The “params” PRNG sequence is used to initialize parameters.

  • method – A function to call apply on. This is generally a function in the module. If provided, applies this method. If not provided, applies the __call__ method of the module.

  • mutable – Can be bool, str, or list. Specifies which collections should be treated as mutable: bool: all/no collections are mutable. str: The name of a single mutable collection. list: A list of names of mutable collections.

  • capture_intermediates – If True, captures intermediate return values of all Modules inside the “intermediates” collection. By default only the return values of all __call__ methods are stored. A function can be passed to change the filter behavior. The filter function takes the Module instance and method name and returns a bool indicating whether the output of that method invocation should be stored.

  • **kwargs – Keyword arguments passed to the specified apply method.

Returns

If mutable is False, returns output. If any collections are mutable, returns (output, vars), where vars is a dict of the modified collections.

bind(variables, *args, rngs=None, mutable=False)

Creates an interactive Module instance by binding variables and RNGs.

bind provides an “interactive” instance of a Module directly without transforming a function with apply. This is particularly useful for debugging and interactive use cases like notebooks where a function would limit the ability to split up code into different cells.

Once the variables (and optionally RNGs) are bound to a Module it becomes a stateful object. Note that idiomatic JAX is functional and therefore an interactive instance does not mix well with vanilla JAX APIs. bind() should only be used for interactive experimentation, and in all other cases we strongly encourage users to use apply() instead.

Example:

import jax
import jax.numpy as jnp
import flax.linen as nn

class AutoEncoder(nn.Module):
  def setup(self):
    self.encoder = nn.Dense(3)
    self.decoder = nn.Dense(5)

  def __call__(self, x):
    return self.decoder(self.encoder(x))

x = jnp.ones((16, 9))
ae = AutoEncoder()
variables = ae.init(jax.random.PRNGKey(0), x)
model = ae.bind(variables)
z = model.encoder(x)
x_reconstructed = model.decoder(z)
Parameters
  • variables – A dictionary containing variables keyed by variable collections. See flax.core.variables for more details about variables.

  • *args – Positional arguments (not used).

  • rngs – a dict of PRNGKeys to initialize the PRNG sequences.

  • mutable – Can be bool, str, or list. Specifies which collections should be treated as mutable: bool: all/no collections are mutable. str: The name of a single mutable collection. list: A list of names of mutable collections.

Returns

A copy of this instance with bound variables and RNGs.

init(rngs, *args, method=None, mutable=DenyList(deny='intermediates'), capture_intermediates=False, **kwargs)

Initializes a module method with variables and returns modified variables.

Jitting init initializes a model lazily using only the shapes of the provided arguments, and avoids computing the forward pass with actual values. Example:

jit_init = jax.jit(SomeModule(...).init)
jit_init(rng, jnp.ones(input_shape, jnp.float32))
Parameters
  • rngs – The rngs for the variable collections.

  • *args – Positional arguments passed to the init function.

  • method – An optional method. If provided, applies this method. If not provided, applies the __call__ method.

  • mutable – Can be bool, str, or list. Specifies which collections should be treated as mutable: bool: all/no collections are mutable. str: The name of a single mutable collection. list: A list of names of mutable collections. By default all collections except “intermediates” are mutable.

  • capture_intermediates – If True, captures intermediate return values of all Modules inside the “intermediates” collection. By default only the return values of all __call__ methods are stored. A function can be passed to change the filter behavior. The filter function takes the Module instance and method name and returns a bool indicating whether the output of that method invocation should be stored.

  • **kwargs – Keyword arguments passed to the init function.

Returns

The initialized variable dict.

init_with_output(rngs, *args, method=None, mutable=DenyList(deny='intermediates'), capture_intermediates=False, **kwargs)

Initializes a module method with variables and returns output and modified variables.

Parameters
  • rngs – The rngs for the variable collections.

  • *args – Positional arguments passed to the init function.

  • method – An optional method. If provided, applies this method. If not provided, applies the __call__ method.

  • mutable – Can be bool, str, or list. Specifies which collections should be treated as mutable: bool: all/no collections are mutable. str: The name of a single mutable collection. list: A list of names of mutable collections. By default all collections except “intermediates” are mutable.

  • capture_intermediates – If True, captures intermediate return values of all Modules inside the “intermediates” collection. By default only the return values of all __call__ methods are stored. A function can be passed to change the filter behavior. The filter function takes the Module instance and method name and returns a bool indicating whether the output of that method invocation should be stored.

  • **kwargs – Keyword arguments passed to the init function.

Returns

(output, vars), where vars is a dict of the modified collections.

is_initializing()

Returns True if running under self.init(…) or nn.init(…)().

This is a helper method to handle the common case of simple initialization where we wish to have setup logic occur when only called under module.init or nn.init. For more complicated multi-phase initialization scenarios it is better to test for the mutability of particular variable collections or for the presence of particular variables that potentially need to be initialized.

make_rng(name)

Returns a new RNG key from a given RNG sequence for this Module.

The new RNG key is split from the previous one. Thus, every call to make_rng returns a new RNG key, while still guaranteeing full reproducibility.

Parameters

name – The RNG sequence name.

Returns

The newly generated RNG key.

param(name, init_fn, *init_args)

Declares and returns a parameter in this Module.

Parameters are read-only variables in the collection named “params”. See flax.core.variables for more details on variables.

The first argument of init_fn is assumed to be a PRNG key, which is provided automatically and does not have to be passed using init_args:

mean = self.param('mean', nn.initializers.lecun_normal(), (2, 2))

In the example above, the initializer returned by lecun_normal() expects two arguments, key and shape, but only shape has to be provided explicitly; key is set automatically using the PRNG for params that is passed when initializing the module using init().

Parameters
  • name – The parameter name.

  • init_fn – The function that will be called to compute the initial value of this variable. This function will only be called the first time this parameter is used in this module.

  • *init_args – The arguments to pass to init_fn.

Returns

The value of the initialized parameter.

setup()

Initializes a Module lazily (similar to a lazy __init__).

setup is called once lazily on a module instance when a module is bound, immediately before any other methods like __call__ are invoked, or before a setup-defined attribute on self is accessed.

This can happen in three cases:

  1. Immediately when invoking apply(), init() or init_with_output().

  2. Once the module is given a name by being assigned to an attribute of another module inside the other module’s setup method (see __setattr__()):

    class MyModule(nn.Module):
      def setup(self):
        submodule = nn.Conv(...)

        # Accessing `submodule` attributes does not yet work here.

        # The following line invokes `self.__setattr__`, which gives
        # `submodule` the name "conv1".
        self.conv1 = submodule

        # Accessing `submodule` attributes or methods is now safe and
        # causes its setup() to be called once.
    
  3. Once a module is constructed inside a method wrapped with compact(), immediately before another method is called or a setup-defined attribute is accessed.

sow(col, name, value, reduce_fn=<function <lambda>>, init_fn=<function <lambda>>)

Stores a value in a collection.

Collections can be used to collect intermediate values without the overhead of explicitly passing a container through each Module call.

If the target collection is not mutable, sow behaves like a no-op and returns False.

Example:

import jax
import jax.numpy as jnp
import flax.linen as nn

class Foo(nn.Module):
  @nn.compact
  def __call__(self, x):
    h = nn.Dense(4)(x)
    self.sow('intermediates', 'h', h)
    return nn.Dense(2)(h)

x = jnp.ones((16, 9))
model = Foo()
variables = model.init(jax.random.PRNGKey(0), x)
y, state = model.apply(variables, x, mutable=['intermediates'])
print(state['intermediates'])  # {'h': (...,)}

By default the values are stored in a tuple and each stored value is appended at the end. This way all intermediates can be tracked when the same module is called multiple times. Alternatively, a custom init/reduce function can be passed:

class Foo2(nn.Module):
  @nn.compact
  def __call__(self, x):
    init_fn = lambda: 0
    reduce_fn = lambda a, b: a + b
    self.sow('intermediates', 'h', x,
             init_fn=init_fn, reduce_fn=reduce_fn)
    self.sow('intermediates', 'h', x * 2,
             init_fn=init_fn, reduce_fn=reduce_fn)
    return x

model = Foo2()
variables = model.init(jax.random.PRNGKey(0), x)
y, state = model.apply(variables, jnp.ones((1, 1)), mutable=['intermediates'])
print(state['intermediates'])  # ==> {'h': [[3.]]}
Parameters
  • col – The name of the variable collection.

  • name – The name of the variable.

  • value – The value of the variable.

  • reduce_fn – The function used to combine the existing value with the new value. The default is to append the value to a tuple.

  • init_fn – For the first value stored, reduce_fn will be passed the result of init_fn together with the value to be stored. The default is an empty tuple.

Returns

True if the value has been stored successfully, False otherwise.

tabulate(rngs, *args, method=None, mutable=True, depth=None, exclude_methods=(), **kwargs)

Creates a summary of the Module represented as a table.

This method has the same signature as init, but instead of returning the variables, it returns the string summarizing the Module in a table. tabulate uses jax.eval_shape to run the forward computation without consuming any FLOPs or allocating memory.

Example:

import jax
import jax.numpy as jnp
import flax.linen as nn

class Foo(nn.Module):
    @nn.compact
    def __call__(self, x):
        h = nn.Dense(4)(x)
        return nn.Dense(2)(h)

x = jnp.ones((16, 9))

print(Foo().tabulate(jax.random.PRNGKey(0), x))

This gives the following output:

                   Foo Summary
┏━━━━━━━━━┳━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━┓
┃ path    ┃ outputs       ┃ params               ┃
┡━━━━━━━━━╇━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━┩
│ Inputs  │ float32[16,9] │                      │
├─────────┼───────────────┼──────────────────────┤
│ Dense_0 │ float32[16,4] │ bias: float32[4]     │
│         │               │ kernel: float32[9,4] │
│         │               │                      │
│         │               │ 40 (160 B)           │
├─────────┼───────────────┼──────────────────────┤
│ Dense_1 │ float32[16,2] │ bias: float32[2]     │
│         │               │ kernel: float32[4,2] │
│         │               │                      │
│         │               │ 10 (40 B)            │
├─────────┼───────────────┼──────────────────────┤
│ Foo     │ float32[16,2] │                      │
├─────────┼───────────────┼──────────────────────┤
│         │         Total │ 50 (200 B)           │
└─────────┴───────────────┴──────────────────────┘

          Total Parameters: 50 (200 B)

Note: the order of rows in the table does not reflect execution order; instead, it follows the order of keys in variables, which are sorted alphabetically.

Parameters
  • rngs – The rngs for the variable collections.

  • *args – The arguments to the forward computation.

  • method – An optional method. If provided, applies this method. If not provided, applies the __call__ method.

  • mutable – Can be bool, str, or list. Specifies which collections should be treated as mutable: bool: all/no collections are mutable. str: The name of a single mutable collection. list: A list of names of mutable collections. By default all collections except ‘intermediates’ are mutable.

  • depth – Controls how many submodules deep the summary can go. By default it's None, which means no limit. If a submodule is not shown because of the depth limit, its parameter count and bytes are added to the row of its first shown ancestor, so that the sum of all rows always adds up to the total number of parameters of the Module.

  • exclude_methods – A sequence of strings that specifies which methods should be ignored. In case a module calls a helper method from its main method, use this argument to exclude the helper method from the summary to avoid ambiguity.

  • **kwargs – keyword arguments to pass to the forward computation.

Returns

A string summarizing the Module.

variable(col, name, init_fn=None, *init_args)

Declares and returns a variable in this Module.

See flax.core.variables for more information. See also param() for a shorthand way to define read-only variables in the “params” collection.

Unlike param(), all arguments to init_fn must be passed explicitly:

key = self.make_rng('stats')
mean = self.variable('stats', 'mean', nn.initializers.lecun_normal(), key, (2, 2))

In the example above, the initializer returned by lecun_normal() expects two arguments, key and shape, and both have to be passed explicitly. The PRNG for stats has to be provided explicitly when calling init() and apply().

Parameters
  • col – The variable collection name.

  • name – The variable name.

  • init_fn – The function that will be called to compute the initial value of this variable. This function will only be called the first time this variable is used in this module. If None, the variable must already be initialized; otherwise an error is raised.

  • *init_args – The arguments to pass to init_fn.

Returns

A flax.core.variables.Variable that can be read or set via the “.value” attribute. Throws an error if the variable exists already.

property variables

Returns the variables in this module.

Init/Apply

flax.linen.apply(fn, module, mutable=False, capture_intermediates=False)

Creates an apply function to call fn with a bound module.

Unlike Module.apply this function returns a new function with the signature (variables, *args, rngs=None, **kwargs) -> T where T is the return type of fn. If mutable is not False the return type is a tuple where the second item is a FrozenDict with the mutated variables.

The apply function that is returned can be directly composed with JAX transformations like jax.jit:

def f(foo, x):
  z = foo.encode(x)
  y = foo.decode(z)
  # ...
  return y

foo = Foo()
f_jitted = jax.jit(nn.apply(f, foo))
f_jitted(variables, x)
Parameters
  • fn – The function that should be applied. The first argument passed will be a module instance of the module with variables and RNGs bound to it.

  • module – The Module that will be used to bind variables and RNGs to. The Module passed as the first argument to fn will be a clone of module.

  • mutable – Can be bool, str, or list. Specifies which collections should be treated as mutable: bool: all/no collections are mutable. str: The name of a single mutable collection. list: A list of names of mutable collections.

  • capture_intermediates – If True, captures intermediate return values of all Modules inside the “intermediates” collection. By default only the return values of all __call__ methods are stored. A function can be passed to change the filter behavior. The filter function takes the Module instance and method name and returns a bool indicating whether the output of that method invocation should be stored.

Returns

The apply function wrapping fn.

flax.linen.init(fn, module, mutable=DenyList(deny='intermediates'), capture_intermediates=False)

Creates an init function to call fn with a bound module.

Unlike Module.init this function returns a new function with the signature (rngs, *args, **kwargs) -> variables. The rngs can be a dict of PRNGKeys or a single PRNGKey, which is equivalent to passing a dict with one PRNGKey named “params”.

The init function that is returned can be directly composed with JAX transformations like jax.jit:

def f(foo, x):
  z = foo.encode(x)
  y = foo.decode(z)
  # ...
  return y

foo = Foo()
f_jitted = jax.jit(nn.init(f, foo))
variables = f_jitted(rng, x)
Parameters
  • fn – The function that should be applied. The first argument passed will be a module instance of the module with variables and RNGs bound to it.

  • module – The Module that will be used to bind variables and RNGs to. The Module passed as the first argument to fn will be a clone of module.

  • mutable – Can be bool, str, or list. Specifies which collections should be treated as mutable: bool: all/no collections are mutable. str: The name of a single mutable collection. list: A list of names of mutable collections. By default all collections except “intermediates” are mutable.

  • capture_intermediates – If True, captures intermediate return values of all Modules inside the “intermediates” collection. By default only the return values of all __call__ methods are stored. A function can be passed to change the filter behavior. The filter function takes the Module instance and method name and returns a bool indicating whether the output of that method invocation should be stored.

Returns

The init function wrapping fn.

flax.linen.init_with_output(fn, module, mutable=DenyList(deny='intermediates'), capture_intermediates=False)

Creates an init function to call fn with a bound module that also returns the function outputs.

Unlike Module.init_with_output this function returns a new function with the signature (rngs, *args, **kwargs) -> (T, variables) where T is the return type of fn. The rngs can be a dict of PRNGKeys or a single PRNGKey, which is equivalent to passing a dict with one PRNGKey named “params”.

The init function that is returned can be directly composed with JAX transformations like jax.jit:

def f(foo, x):
  z = foo.encode(x)
  y = foo.decode(z)
  # ...
  return y

foo = Foo()
f_jitted = jax.jit(nn.init_with_output(f, foo))
y, variables = f_jitted(rng, x)
Parameters
  • fn – The function that should be applied. The first argument passed will be a module instance of the module with variables and RNGs bound to it.

  • module – The Module that will be used to bind variables and RNGs to. The Module passed as the first argument to fn will be a clone of module.

  • mutable – Can be bool, str, or list. Specifies which collections should be treated as mutable: bool: all/no collections are mutable. str: The name of a single mutable collection. list: A list of names of mutable collections. By default all collections except “intermediates” are mutable.

  • capture_intermediates – If True, captures intermediate return values of all Modules inside the “intermediates” collection. By default only the return values of all __call__ methods are stored. A function can be passed to change the filter behavior. The filter function takes the Module instance and method name and returns a bool indicating whether the output of that method invocation should be stored.

Returns

The init function wrapping fn.

Variables

A variable dict is a normal Python dictionary, which is a container for one or more “variable collections”, each of which is a nested dictionary whose leaves are jax.numpy arrays.

The different variable collections share the same nested tree structure.

For example, consider the following variable dictionary:

{
  "params": {
    "Conv1": { "weight": ..., "bias": ... },
    "BatchNorm1": { "scale": ..., "mean": ... },
    "Conv2": {...}
  },
  "batch_stats": {
    "BatchNorm1": { "moving_mean": ..., "moving_average": ...}
  }
}

In this case, the "BatchNorm1" key lives in both the "params" and "batch_stats" collections. This reflects the fact that the submodule named "BatchNorm1" has both trainable parameters (the "params" collection) and other non-trainable variables (the "batch_stats" collection).

class flax.core.variables.Variable(scope, collection, name)

A Variable object allows mutable access to a variable in a VariableDict.

Variables are identified by a collection (e.g., “batch_stats”) and a name (e.g., “moving_mean”). The value property gives access to the variable’s content and can be assigned to for mutation.

Compact methods

flax.linen.compact(fun)

Marks the given module method allowing inlined submodules.

Methods wrapped in @compact can define submodules directly within the method.

For instance:

@nn.compact
def __call__(self, x, features):
  x = nn.Dense(features)(x)
  ...

At most one method in each Module may be wrapped with @compact.

Parameters

fun – The Module method to mark as compact.

Returns

The given function fun marked as compact.

No wrap methods

flax.linen.nowrap(fun)

Marks the given module method as a helper method that needn’t be wrapped.

Methods wrapped in @nowrap are private helper methods that needn’t be wrapped with the state handler or a separate named_call transform.

This is needed in several concrete instances:
  • if you have a helper method that returns Modules or Variables to prevent it from being functionalized by named_call. (Functionalized methods can’t return Modules/Variables.)

  • if you’re subclassing a method like Module.param and don’t want this overridden core function decorated with the state management wrapper.

  • if you want a method to be callable from an unbound Module (e.g., a function that constructs arguments and doesn’t depend on params/RNGs).

For instance:

import flax.linen as nn

class MyModule(nn.Module):
  num_features: int

  @nn.nowrap
  def _make_dense(self, num_features):
    return nn.Dense(num_features)

  @nn.compact
  def __call__(self, x):
    # now safe to use constructor helper even if using named_call
    dense = self._make_dense(self.num_features)
    return dense(x)
Parameters

fun – The Module method to mark as nowrap.

Returns

The given function fun marked as nowrap.

Profiling

enable_named_call()

Enables named call wrapping for labelling profile traces.

disable_named_call()

Disables named call wrapping.

override_named_call([enable])

Returns a context manager that enables/disables named call wrapping.

Inspection

tabulate(module, rngs[, method, mutable, ...])

Returns a function that creates a summary of the Module represented as a table.

Transformations

JAX transformations on Modules.

JAX functional transformations operate on pure functions. Flax extends these transformations to also operate on Modules, which have stateful variables and PRNG sequences. We refer to these extended versions as “lifted transformations”.

A lifted transformation can be applied to a Module class or a function that takes a Module instance as its first argument.

vmap(target[, variable_axes, split_rngs, ...])

A lifted version of jax.vmap.

scan(target[, variable_axes, ...])

A lifted version of jax.lax.scan.

jit(target[, variables, rngs, ...])

Lifted version of jax.jit.

remat(target[, variables, rngs, concrete, ...])

Lifted version of jax.checkpoint.

remat_scan(target[, lengths, policy, ...])

Combines remat and scan for memory efficiency and constant time compilation.

map_variables(target[, mapped_collections, ...])

Map Variables inside a module.

jvp(fn, mdl, primals, tangents, ...[, ...])

A lifted version of jax.jvp.

vjp(fn, mdl, *primals[, has_aux, ...])

A lifted version of jax.vjp.

custom_vjp(fn, forward_fn, backward_fn[, ...])

Lifted version of jax.custom_vjp.

while_loop(cond_fn, body_fn, mdl, init[, ...])

Lifted version of jax.lax.while_loop.

cond(pred, true_fun, false_fun, mdl, *operands)

Lifted version of jax.lax.cond.

switch(index, branches, mdl, *operands[, ...])

Lifted version of jax.lax.switch.

Linear modules

Dense(features[, use_bias, dtype, ...])

A linear transformation applied over the last dimension of the input.

DenseGeneral(features[, axis, batch_dims, ...])

A linear transformation with flexible axes.

Conv(features, kernel_size[, strides, ...])

Convolution Module wrapping lax.conv_general_dilated.

ConvTranspose(features, kernel_size[, ...])

Convolution Module wrapping lax.conv_transpose.

ConvLocal(features, kernel_size[, strides, ...])

Local convolution Module wrapping lax.conv_general_dilated_local.

Embed(num_embeddings, features[, dtype, ...])

Embedding Module.

Normalization

BatchNorm([use_running_average, axis, ...])

BatchNorm Module.

LayerNorm([epsilon, dtype, param_dtype, ...])

Layer normalization (https://arxiv.org/abs/1607.06450).

GroupNorm([num_groups, group_size, epsilon, ...])

Group normalization (arxiv.org/abs/1803.08494).

Pooling

max_pool(inputs, window_shape[, strides, ...])

Pools the input by taking the maximum of a window slice.

avg_pool(inputs, window_shape[, strides, ...])

Pools the input by taking the average over a window.

pool(inputs, init, reduce_fn, window_shape, ...)

Helper function to define pooling functions.

Activation functions

celu(x[, alpha])

Continuously-differentiable exponential linear unit activation.

elu(x[, alpha])

Exponential linear unit activation function.

gelu(x[, approximate])

Gaussian error linear unit activation function.

glu(x[, axis])

Gated linear unit activation function.

log_sigmoid(x)

Log-sigmoid activation function.

log_softmax(x[, axis, where, initial])

Log-Softmax function.

relu

Rectified linear unit activation function.

sigmoid(x)

Sigmoid activation function.

soft_sign(x)

Soft-sign activation function.

softmax(x[, axis, where, initial])

Softmax function.

softplus(x)

Softplus activation function.

swish(x)

SiLU activation function.

PReLU([param_dtype, negative_slope_init, ...])

Parametric Rectified Linear Unit (PReLU) activation function.

Combinators

Sequential(layers[, parent, name])

Applies a linear chain of Modules.

Attention primitives

dot_product_attention_weights(query, key[, ...])

Computes dot-product attention weights given query and key.

dot_product_attention(query, key, value[, ...])

Computes dot-product attention given query, key, and value.

make_attention_mask(query_input, key_input)

Mask-making helper for attention weights.

make_causal_mask(x[, extra_batch_dims, dtype])

Make a causal mask for self-attention.

SelfAttention(num_heads[, dtype, ...])

Self-attention special case of multi-head dot-product attention.

MultiHeadDotProductAttention(num_heads[, ...])

Multi-head dot-product attention.

Stochastic

Dropout(rate[, broadcast_dims, ...])

Create a dropout layer.

RNN primitives

LSTMCell([gate_fn, activation_fn, ...])

LSTM cell.

OptimizedLSTMCell([gate_fn, activation_fn, ...])

More efficient LSTM Cell that concatenates state components before matmul.

GRUCell([gate_fn, activation_fn, ...])

GRU cell.