Metrics

class flax.nnx.metrics.Metric(*args, **kwargs)

Base class for metrics. Any subclass of Metric should implement the compute, reset, and update methods.

__init__()
compute()

Compute and return the value of the Metric.

reset()

Reset the Metric in place.

update(**kwargs)

Update the Metric in place.
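As a minimal sketch of this contract, the plain-Python class below tracks the largest value seen since the last reset. It is illustrative only: RunningMax and its argname parameter are hypothetical, it does not subclass flax.nnx.metrics.Metric, and it works on Python lists rather than JAX arrays.

```python
import math


class RunningMax:
    """Hypothetical metric tracking the largest value seen since the last reset.

    Plain Python; it only mirrors the compute/reset/update contract and is not
    a flax.nnx.metrics.Metric subclass.
    """

    def __init__(self, argname: str = "values"):
        self.argname = argname  # keyword that update() reads
        self.reset()

    def reset(self) -> None:
        # Restore the empty state in place.
        self.max_value = -math.inf

    def update(self, **kwargs) -> None:
        # Read only the configured keyword argument; ignore any others.
        for value in kwargs[self.argname]:
            self.max_value = max(self.max_value, value)

    def compute(self) -> float:
        # Report the current value without changing the state.
        return self.max_value


metric = RunningMax()
metric.update(values=[1, 5, 3])
metric.update(values=[2, 4])
print(metric.compute())  # 5
metric.reset()
print(metric.compute())  # -inf
```

Keeping update and reset as in-place mutations and compute as a pure read is what lets a metric be updated batch by batch and queried at any point.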

class flax.nnx.metrics.Average(*args, **kwargs)

Average metric: the mean of all values passed to update() since construction or the last reset().

Example usage:

>>> import jax.numpy as jnp
>>> from flax import nnx

>>> batch_loss = jnp.array([1, 2, 3, 4])
>>> batch_loss2 = jnp.array([3, 2, 1, 0])

>>> metrics = nnx.metrics.Average()
>>> metrics.compute()
Array(nan, dtype=float32)
>>> metrics.update(values=batch_loss)
>>> metrics.compute()
Array(2.5, dtype=float32)
>>> metrics.update(values=batch_loss2)
>>> metrics.compute()
Array(2., dtype=float32)
>>> metrics.reset()
>>> metrics.compute()
Array(nan, dtype=float32)
__init__(argname='values')

Pass in a string naming the keyword argument from which update() reads the new value. For example, constructing the metric as avg = Average('test') lets you update it with avg.update(test=new_value).

Parameters

argname – an optional string naming the keyword argument that update() reads the new value from. Defaults to 'values'.

compute()

Compute and return the average.

reset()

Reset this Metric.

update(**kwargs)

Update this Metric in place, using the value kwargs[self.argname], where self.argname is set at construction.

Parameters

**kwargs – keyword arguments that must include a self.argname entry mapping to the value used to update this metric.
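The update arithmetic can be sketched in plain Python by keeping a running total and count (the two fields visible in the MultiMetric repr below). RunningAverage is a hypothetical stand-in, not Flax code, and it accepts Python sequences instead of JAX arrays. With the doctest's two batches it gives 10 / 4 = 2.5 and then 16 / 8 = 2.0, and an empty metric yields 0 / 0, reported as NaN.

```python
import math


class RunningAverage:
    """Hypothetical plain-Python sketch of a running-average metric."""

    def __init__(self, argname: str = "values"):
        self.argname = argname  # keyword that update() reads
        self.reset()

    def reset(self) -> None:
        self.total = 0.0
        self.count = 0

    def update(self, **kwargs) -> None:
        # Accumulate the sum and number of elements of the new batch.
        values = kwargs[self.argname]
        self.total += sum(values)
        self.count += len(values)

    def compute(self) -> float:
        # 0 / 0 has no defined average, so report NaN like the doctest.
        return self.total / self.count if self.count else math.nan


avg = RunningAverage("test")
avg.update(test=[1, 2, 3, 4])  # total 10, count 4
print(avg.compute())           # 2.5
avg.update(test=[3, 2, 1, 0])  # total 16, count 8
print(avg.compute())           # 2.0
```

In this sketch the total and count are accumulated per element, so batches of different sizes are weighted by their number of elements rather than averaged batch by batch.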

class flax.nnx.metrics.Accuracy(*args, **kwargs)

Accuracy metric. Accuracy subclasses Average, so the two share the same reset and compute implementations. Unlike Average, Accuracy needs no argname string at construction, because its update() reads the logits and labels keyword arguments directly.

Example usage:

>>> from flax import nnx
>>> import jax, jax.numpy as jnp

>>> logits = jax.random.normal(jax.random.key(0), (5, 2))
>>> labels = jnp.array([1, 1, 0, 1, 0])
>>> logits2 = jax.random.normal(jax.random.key(1), (5, 2))
>>> labels2 = jnp.array([0, 1, 1, 1, 1])

>>> metrics = nnx.metrics.Accuracy()
>>> metrics.compute()
Array(nan, dtype=float32)
>>> metrics.update(logits=logits, labels=labels)
>>> metrics.compute()
Array(0.6, dtype=float32)
>>> metrics.update(logits=logits2, labels=labels2)
>>> metrics.compute()
Array(0.7, dtype=float32)
>>> metrics.reset()
>>> metrics.compute()
Array(nan, dtype=float32)
update(*, logits, labels, **_)

Update this Metric in place. Any other keyword arguments are ignored.

Parameters
  • logits – the predicted activations output by the model. An argmax is taken over the trailing dimension before the result is compared with labels.

  • labels – the ground-truth integer labels.
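Per the parameter descriptions above, an update reduces to an argmax over the trailing dimension followed by a comparison with labels; since Accuracy shares Average's compute, this suggests the reported value is the fraction of correct predictions seen so far. The doctest is consistent with that reading: 0.6 after five examples is 3 correct, and 0.7 after ten is 7 correct. The plain-Python sketch below assumes nested lists instead of JAX arrays; accuracy_counts is a hypothetical helper, not part of Flax.

```python
def accuracy_counts(logits, labels):
    """Hypothetical helper: (num_correct, num_examples) for one batch.

    logits: list of per-example score lists; the trailing axis indexes classes.
    labels: list of integer ground-truth classes, one per example.
    """
    if len(logits) != len(labels):
        raise ValueError("logits and labels must have the same number of examples")
    correct = 0
    for scores, label in zip(logits, labels):
        # argmax over the trailing (class) dimension; ties go to the lowest index
        predicted = max(range(len(scores)), key=lambda i: scores[i])
        correct += int(predicted == label)
    return correct, len(labels)


# Accumulate across batches the way a running average would.
total, count = 0, 0
for batch_logits, batch_labels in [
    ([[0.1, 0.9], [0.8, 0.2]], [1, 1]),  # predictions 1, 0 -> 1 correct
    ([[0.3, 0.7], [0.6, 0.4]], [1, 0]),  # predictions 1, 0 -> 2 correct
]:
    c, n = accuracy_counts(batch_logits, batch_labels)
    total, count = total + c, count + n
print(total / count)  # 0.75
```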

class flax.nnx.metrics.Welford(*args, **kwargs)

Uses Welford’s algorithm to compute the mean and variance of a stream of data.

Example usage:

>>> import jax.numpy as jnp
>>> from flax import nnx

>>> batch_loss = jnp.array([1, 2, 3, 4])
>>> batch_loss2 = jnp.array([3, 2, 1, 0])

>>> metrics = nnx.metrics.Welford()
>>> metrics.compute()
Statistics(mean=Array(0., dtype=float32), standard_error_of_mean=Array(nan, dtype=float32), standard_deviation=Array(nan, dtype=float32))
>>> metrics.update(values=batch_loss)
>>> metrics.compute()
Statistics(mean=Array(2.5, dtype=float32), standard_error_of_mean=Array(0.559017, dtype=float32), standard_deviation=Array(1.118034, dtype=float32))
>>> metrics.update(values=batch_loss2)
>>> metrics.compute()
Statistics(mean=Array(2., dtype=float32), standard_error_of_mean=Array(0.43301272, dtype=float32), standard_deviation=Array(1.2247449, dtype=float32))
>>> metrics.reset()
>>> metrics.compute()
Statistics(mean=Array(0., dtype=float32), standard_error_of_mean=Array(nan, dtype=float32), standard_deviation=Array(nan, dtype=float32))
__init__(argname='values')

Pass in a string naming the keyword argument from which update() reads the new value. For example, constructing the metric as wf = Welford('test') lets you update it with wf.update(test=new_value).

Parameters

argname – an optional string naming the keyword argument that update() reads the new value from. Defaults to 'values'.

compute()

Compute and return the running statistics as a Statistics dataclass with mean, standard_error_of_mean, and standard_deviation fields.

reset()

Reset this Metric.

update(**kwargs)

Update this Metric in place, using the value kwargs[self.argname], where self.argname is set at construction.

Parameters

**kwargs – keyword arguments that must include a self.argname entry mapping to the value used to update this metric.
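The doctest numbers can be reproduced with the textbook single-value form of Welford's update. The sketch below is plain Python, processes one value at a time, drops the argname handling, and uses hypothetical names (RunningWelford, Stats); it is not the Flax implementation. Judging from the doctest values, compute() reports the population standard deviation (dividing by the count) and standard_error_of_mean = standard_deviation / sqrt(count); for the first batch, 1.118034 / sqrt(4) = 0.559017.

```python
import math
from dataclasses import dataclass


@dataclass
class Stats:
    """Hypothetical stand-in for the Statistics dataclass returned by compute()."""
    mean: float
    standard_error_of_mean: float
    standard_deviation: float


class RunningWelford:
    """Hypothetical plain-Python sketch of Welford's online mean/variance."""

    def __init__(self):
        self.reset()

    def reset(self) -> None:
        self.count = 0
        self.mean = 0.0
        self.m2 = 0.0  # sum of squared deviations from the current mean

    def update(self, values) -> None:
        for x in values:
            self.count += 1
            delta = x - self.mean          # deviation from the old mean
            self.mean += delta / self.count
            self.m2 += delta * (x - self.mean)  # times deviation from the new mean

    def compute(self) -> Stats:
        if self.count == 0:
            # Matches the empty-state doctest: mean 0, spread undefined.
            return Stats(0.0, math.nan, math.nan)
        std = math.sqrt(self.m2 / self.count)  # population standard deviation
        return Stats(self.mean, std / math.sqrt(self.count), std)


wf = RunningWelford()
wf.update([1, 2, 3, 4])
print(wf.compute())  # mean 2.5, std ~1.118034, sem ~0.559017
wf.update([3, 2, 1, 0])
print(wf.compute())  # mean ~2.0, std ~1.2247449, sem ~0.4330127
```

Welford's form avoids subtracting two large, nearly equal sums (sum of squares minus squared sum), which is why it is preferred for long streams in floating point.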

class flax.nnx.metrics.MultiMetric(*args, **kwargs)

MultiMetric stores multiple metrics and updates them all in a single call.

Example usage:

>>> from flax import nnx
>>> import jax, jax.numpy as jnp

>>> metrics = nnx.MultiMetric(
...   accuracy=nnx.metrics.Accuracy(), loss=nnx.metrics.Average()
... )

>>> metrics
MultiMetric(
  accuracy=Accuracy(
    argname='values',
    total=MetricState(
      value=Array(0., dtype=float32)
    ),
    count=MetricState(
      value=Array(0, dtype=int32)
    )
  ),
  loss=Average(
    argname='values',
    total=MetricState(
      value=Array(0., dtype=float32)
    ),
    count=MetricState(
      value=Array(0, dtype=int32)
    )
  )
)

>>> metrics.accuracy
Accuracy(
  argname='values',
  total=MetricState(
    value=Array(0., dtype=float32)
  ),
  count=MetricState(
    value=Array(0, dtype=int32)
  )
)

>>> metrics.loss
Average(
  argname='values',
  total=MetricState(
    value=Array(0., dtype=float32)
  ),
  count=MetricState(
    value=Array(0, dtype=int32)
  )
)

>>> logits = jax.random.normal(jax.random.key(0), (5, 2))
>>> labels = jnp.array([1, 1, 0, 1, 0])
>>> logits2 = jax.random.normal(jax.random.key(1), (5, 2))
>>> labels2 = jnp.array([0, 1, 1, 1, 1])

>>> batch_loss = jnp.array([1, 2, 3, 4])
>>> batch_loss2 = jnp.array([3, 2, 1, 0])

>>> metrics.compute()
{'accuracy': Array(nan, dtype=float32), 'loss': Array(nan, dtype=float32)}
>>> metrics.update(logits=logits, labels=labels, values=batch_loss)
>>> metrics.compute()
{'accuracy': Array(0.6, dtype=float32), 'loss': Array(2.5, dtype=float32)}
>>> metrics.update(logits=logits2, labels=labels2, values=batch_loss2)
>>> metrics.compute()
{'accuracy': Array(0.7, dtype=float32), 'loss': Array(2., dtype=float32)}
>>> metrics.reset()
>>> metrics.compute()
{'accuracy': Array(nan, dtype=float32), 'loss': Array(nan, dtype=float32)}
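__init__(**metrics)

Pass the metrics to the constructor as keyword arguments, e.g. MultiMetric(keyword1=Average(), keyword2=Accuracy(), ...).

Parameters

**metrics – keyword arguments mapping names to Metric instances. Each name is used to access the corresponding Metric as an attribute (e.g. metrics.accuracy) and as its key in the dictionary returned by compute().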

compute()

Compute and return the values of all underlying Metrics as a dictionary mapping each name (the keyword arguments **metrics passed to the constructor) to the corresponding metric value.

reset()

Reset all underlying Metrics.

update(**updates)

Update all underlying Metrics in place. Every keyword argument in **updates is passed to the update method of every underlying Metric, so each Metric's update must accept, and may ignore, arguments intended for the others.

Parameters

**updates – the keyword arguments passed to each underlying Metric's update method.
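The fan-out described above can be pictured with the plain-Python sketch below. Collection and RunningSum are hypothetical stand-ins for MultiMetric and a built-in metric, not Flax code. The point is that every metric receives every keyword argument, so each update picks out the arguments it needs and ignores the rest, as Accuracy's signature update(*, logits, labels, **_) does.

```python
class RunningSum:
    """Hypothetical metric that reads only its own keyword argument."""

    def __init__(self, argname: str):
        self.argname = argname
        self.reset()

    def reset(self) -> None:
        self.total = 0

    def update(self, **kwargs) -> None:
        # Keyword arguments meant for other metrics are silently ignored.
        self.total += sum(kwargs[self.argname])

    def compute(self) -> int:
        return self.total


class Collection:
    """Hypothetical plain-Python sketch of MultiMetric-style fan-out."""

    def __init__(self, **metrics):
        self._metrics = metrics  # name -> metric

    def update(self, **updates) -> None:
        # Every metric receives every keyword argument.
        for metric in self._metrics.values():
            metric.update(**updates)

    def reset(self) -> None:
        for metric in self._metrics.values():
            metric.reset()

    def compute(self) -> dict:
        return {name: metric.compute() for name, metric in self._metrics.items()}


metrics = Collection(a=RunningSum("x"), b=RunningSum("y"))
metrics.update(x=[1, 2], y=[10])
print(metrics.compute())  # {'a': 3, 'b': 10}
```

One consequence of this design: metrics that share an argname (for example, two Average instances with the default 'values') both receive the same value on every update.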