Extracting intermediate values
This pattern will show you how to extract intermediate values from a module.
Let's start with this simple CNN that uses nn.compact.
from flax import linen as nn
import jax
import jax.numpy as jnp
from typing import Sequence

class CNN(nn.Module):
  @nn.compact
  def __call__(self, x):
    x = nn.Conv(features=32, kernel_size=(3, 3))(x)
    x = nn.relu(x)
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = nn.Conv(features=64, kernel_size=(3, 3))(x)
    x = nn.relu(x)
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = x.reshape((x.shape[0], -1))  # flatten
    x = nn.Dense(features=256)(x)
    x = nn.relu(x)
    x = nn.Dense(features=10)(x)
    x = nn.log_softmax(x)
    return x

Because this module uses nn.compact, we don't have direct access to intermediate values. There are a few ways to expose them:
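For reference, here is a minimal usage sketch (the 28x28x1 input shape matches the MNIST-like batch used later in this howto); calling apply only returns the final log-probabilities:

x = jnp.ones((1, 28, 28, 1))
variables = CNN().init(jax.random.PRNGKey(0), x)
y = CNN().apply(variables, x)  # shape (1, 10); intermediate activations are not exposed
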
Store intermediate values in a new variable collection
The CNN can be augmented with calls to sow to store intermediates as follows:
class SowCNN(nn.Module):
  @nn.compact
  def __call__(self, x):
    x = nn.Conv(features=32, kernel_size=(3, 3))(x)
    x = nn.relu(x)
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = nn.Conv(features=64, kernel_size=(3, 3))(x)
    x = nn.relu(x)
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = x.reshape((x.shape[0], -1))  # flatten
    self.sow('intermediates', 'features', x)
    x = nn.Dense(features=256)(x)
    x = nn.relu(x)
    x = nn.Dense(features=10)(x)
    x = nn.log_softmax(x)
    return x

sow acts as a no-op when the variable collection is not mutable, which makes it well suited for debugging and optional tracking of intermediates. The 'intermediates' collection is also used by the capture_intermediates API (see the capture_intermediates section below).

Note that by default sow appends values every time it is called. This is necessary because, once instantiated, a module could be called multiple times in its parent module, and we want to capture all the sowed values. So make sure you do not feed the stored intermediates back in as part of the variables dict; otherwise every call would grow that tuple and trigger a recompile. To override the default append behavior, specify init_fn and reduce_fn (see Module.sow()).
class SowCNN2(nn.Module):
  @nn.compact
  def __call__(self, x):
    mod = SowCNN(name='SowCNN')
    return mod(x) + mod(x)  # Calling same module instance twice.

@jax.jit
def init(key, x):
  variables = SowCNN2().init(key, x)
  # By default the 'intermediates' collection is not mutable during init.
  # So variables will only contain 'params' here.
  return variables

@jax.jit
def predict(variables, x):
  # If mutable='intermediates' is not specified, then .sow() acts as a noop.
  output, mod_vars = SowCNN2().apply(variables, x, mutable='intermediates')
  features = mod_vars['intermediates']['SowCNN']['features']
  return output, features

batch = jnp.ones((1, 28, 28, 1))
variables = init(jax.random.PRNGKey(0), batch)
preds, feats = predict(variables, batch)
assert len(feats) == 2  # Tuple with two values since module was called twice.
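As noted above, you can override the append behavior by passing init_fn and reduce_fn to sow. A minimal sketch that keeps only the most recently sowed value (the LastSowDense name and layer sizes are just for illustration):

class LastSowDense(nn.Module):
  @nn.compact
  def __call__(self, x):
    x = nn.Dense(features=4)(x)
    # reduce_fn receives the previously stored value and the new value;
    # returning only the new value replaces instead of appending.
    self.sow('intermediates', 'features', x, reduce_fn=lambda prev, cur: cur)
    return nn.Dense(features=2)(x)

With this, the stored entry is a single array rather than a growing tuple, even if the module is called multiple times.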
Refactor module into submodules

This is a useful pattern when it's clear how you want to split up your module. Any submodule you expose in setup can be used directly. In the limit, you can define all submodules in setup and avoid using nn.compact altogether.
class RefactoredCNN(nn.Module):
  def setup(self):
    self.features = Features()
    self.classifier = Classifier()

  def __call__(self, x):
    x = self.features(x)
    x = self.classifier(x)
    return x

class Features(nn.Module):
  @nn.compact
  def __call__(self, x):
    x = nn.Conv(features=32, kernel_size=(3, 3))(x)
    x = nn.relu(x)
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = nn.Conv(features=64, kernel_size=(3, 3))(x)
    x = nn.relu(x)
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = x.reshape((x.shape[0], -1))  # flatten
    return x

class Classifier(nn.Module):
  @nn.compact
  def __call__(self, x):
    x = nn.Dense(features=256)(x)
    x = nn.relu(x)
    x = nn.Dense(features=10)(x)
    x = nn.log_softmax(x)
    return x

@jax.jit
def init(key, x):
  variables = RefactoredCNN().init(key, x)
  return variables['params']

@jax.jit
def features(params, x):
  return RefactoredCNN().apply({"params": params}, x,
                               method=lambda module, x: module.features(x))

params = init(jax.random.PRNGKey(0), batch)
features(params, batch)
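The same method=lambda trick works for any submodule exposed in setup. For example, a hypothetical classify function that runs only the classifier head on precomputed features:

@jax.jit
def classify(params, feats):
  return RefactoredCNN().apply({"params": params}, feats,
                               method=lambda module, x: module.classifier(x))

logits = classify(params, features(params, batch))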
Use capture_intermediates

Linen supports capturing the intermediate return values of submodules automatically, without any code changes. This pattern should be considered the "sledgehammer" approach to capturing intermediates: it is very useful as a debugging and inspection tool, but the other patterns described in this howto give you more explicit control over what is captured.
In the following code example we check if any intermediate activations are non-finite (NaN or infinite):
@jax.jit
def init(key, x):
  variables = CNN().init(key, x)
  return variables

@jax.jit
def predict(variables, x):
  y, state = CNN().apply(variables, x, capture_intermediates=True, mutable=["intermediates"])
  intermediates = state['intermediates']
  fin = jax.tree_util.tree_map(lambda xs: jnp.all(jnp.isfinite(xs)), intermediates)
  return y, fin

variables = init(jax.random.PRNGKey(0), batch)
y, is_finite = predict(variables, batch)
all_finite = all(jax.tree_util.tree_leaves(is_finite))
assert all_finite, "non-finite intermediate detected!"
By default, only the intermediates of __call__ methods are collected. Alternatively, you can pass a custom filter based on the Module instance and the method name.
filter_Dense = lambda mdl, method_name: isinstance(mdl, nn.Dense)
filter_encodings = lambda mdl, method_name: method_name == "encode"
y, state = CNN().apply(variables, batch, capture_intermediates=filter_Dense, mutable=["intermediates"])
dense_intermediates = state['intermediates']
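The captured values are typically stored under the auto-generated submodule names (here, e.g., Dense_0 and Dense_1), with the method name nested below and each leaf wrapped in a tuple. A quick way to inspect what was collected:

print(jax.tree_util.tree_map(jnp.shape, dense_intermediates))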
Use Sequential

You could also define CNN using a simple implementation of a Sequential combinator (this is quite common in more stateful approaches). This may be useful for very simple models and gives you arbitrary model surgery. But it can be very limiting: if you even want to add one conditional, you are forced to refactor away from Sequential and structure your model more explicitly.
class Sequential(nn.Module):
  layers: Sequence[nn.Module]

  def __call__(self, x):
    for layer in self.layers:
      x = layer(x)
    return x

def SeqCNN():
  return Sequential([
      nn.Conv(features=32, kernel_size=(3, 3)),
      nn.relu,
      lambda x: nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2)),
      nn.Conv(features=64, kernel_size=(3, 3)),
      nn.relu,
      lambda x: nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2)),
      lambda x: x.reshape((x.shape[0], -1)),  # flatten
      nn.Dense(features=256),
      nn.relu,
      nn.Dense(features=10),
      nn.log_softmax,
  ])

@jax.jit
def init(key, x):
  variables = SeqCNN().init(key, x)
  return variables['params']

@jax.jit
def features(params, x):
  # Apply only the first seven layers (the convolutional feature extractor).
  return Sequential(SeqCNN().layers[0:7]).apply({"params": params}, x)

params = init(jax.random.PRNGKey(0), batch)
features(params, batch)
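As a sanity check: with nn.Conv's default 'SAME' padding and two 2x2 average poolings, a 28x28x1 input produces 7 * 7 * 64 = 3136 flattened feature values per example:

feats = features(params, batch)
assert feats.shape == (1, 3136)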
Extracting gradients of intermediate values

For debugging purposes, it can be useful to extract the gradients of intermediate values. This can be done with the Module.perturb() method on the desired values.
class Model(nn.Module):
  @nn.compact
  def __call__(self, x):
    x = nn.relu(nn.Dense(8)(x))
    x = self.perturb('hidden', x)
    x = nn.Dense(2)(x)
    x = self.perturb('logits', x)
    return x
perturb adds a variable to the 'perturbations' collection by default. It behaves like an identity function, and the gradient of each perturbation matches the gradient of the value it wraps. To get the perturbations, just initialize the model:
x = jnp.empty((1, 4))  # dummy input data
y = jnp.empty((1, 2))  # dummy target data
model = Model()
variables = model.init(jax.random.PRNGKey(1), x)
params, perturbations = variables['params'], variables['perturbations']
Finally, compute the gradients of the loss with respect to the perturbations; these match the gradients of the intermediate values:
def loss_fn(params, perturbations, x, y):
  y_pred = model.apply({'params': params, 'perturbations': perturbations}, x)
  return jnp.mean((y_pred - y) ** 2)

intermediate_grads = jax.grad(loss_fn, argnums=1)(params, perturbations, x, y)
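To see what was captured, you can inspect the shapes of the gradient tree; each leaf matches the shape of the corresponding intermediate value (here 'hidden' is (1, 8) and 'logits' is (1, 2)):

print(jax.tree_util.tree_map(jnp.shape, intermediate_grads))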