flax.nn.module

flax.nn.module(fun)

DEPRECATION WARNING: The flax.nn module is deprecated; use flax.linen instead. Learn more and find an upgrade guide at https://github.com/google/flax/blob/master/flax/linen/README.md

Convert a function into the apply method of a new Module.

This is a convenient shortcut for writing higher-level modules that don’t need access to self to create parameters directly.

Example usage:

from flax import nn  # deprecated pre-Linen API

@nn.module
def DenseLayer(x, features):
  x = nn.Dense(x, features)
  x = nn.relu(x)
  return x

This is exactly equivalent to defining the following nn.Module subclass:

class DenseLayer(nn.Module):
  def apply(self, x, features):
    x = nn.Dense(x, features)
    x = nn.relu(x)
    return x

Parameters

fun – the function to convert.

Returns

New Module subclass.
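
The conversion described above can be illustrated without Flax at all. The following is a minimal sketch in plain Python, not the library's actual implementation: the Module base class, the module decorator, and the Scale example are hypothetical stand-ins written for this illustration. It shows one way a decorator can wrap a function as the apply method of a dynamically created subclass that reuses the function's name.

```python
import functools


class Module:
    """Hypothetical stand-in for a module base class (not flax.nn.Module)."""

    def apply(self, *args, **kwargs):
        raise NotImplementedError


def module(fun):
    """Wrap `fun` as the `apply` method of a new Module subclass (sketch)."""

    @functools.wraps(fun)
    def apply(self, *args, **kwargs):
        del self  # the wrapped function never uses the instance
        return fun(*args, **kwargs)

    # Build the subclass dynamically, reusing the function's name.
    return type(fun.__name__, (Module,), {"apply": apply})


@module
def Scale(x, factor):
    # Hypothetical example function; any callable works the same way.
    return x * factor
```

In this sketch, Scale is a class rather than a function after decoration: issubclass(Scale, Module) holds, and Scale().apply(3, factor=2) forwards its arguments to the original function. Because the function never receives self, this pattern only suits modules that do not need the instance, which matches the restriction stated in the description.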