Metrics

Metrics#

class flax.experimental.nnx.metrics.Metric(*args, **kwargs)#
class flax.experimental.nnx.metrics.Average(*args, **kwargs)#
class flax.experimental.nnx.metrics.Accuracy(*args, **kwargs)#
class flax.experimental.nnx.metrics.MultiMetric(*args, **kwargs)#

MultiMetric class to store multiple metrics and update them in a single call.

Example usage:

>>> import jax, jax.numpy as jnp
>>> from flax.experimental import nnx

>>> logits = jax.random.normal(jax.random.key(0), (5, 2))
>>> labels = jnp.array([1, 1, 0, 1, 0])
>>> logits2 = jax.random.normal(jax.random.key(1), (5, 2))
>>> labels2 = jnp.array([0, 1, 1, 1, 1])

>>> batch_loss = jnp.array([1, 2, 3, 4])
>>> batch_loss2 = jnp.array([3, 2, 1, 0])

>>> metrics = nnx.MultiMetric(
...   accuracy=nnx.metrics.Accuracy(), loss=nnx.metrics.Average()
... )
>>> metrics.compute()
{'accuracy': Array(nan, dtype=float32), 'loss': Array(nan, dtype=float32)}
>>> metrics.update(logits=logits, labels=labels, values=batch_loss)
>>> metrics.compute()
{'accuracy': Array(0.6, dtype=float32), 'loss': Array(2.5, dtype=float32)}
>>> metrics.update(logits=logits2, labels=labels2, values=batch_loss2)
>>> metrics.compute()
{'accuracy': Array(0.7, dtype=float32), 'loss': Array(2., dtype=float32)}
>>> metrics.reset()
>>> metrics.compute()
{'accuracy': Array(nan, dtype=float32), 'loss': Array(nan, dtype=float32)}