Metrics
- class flax.experimental.nnx.metrics.Metric(*args, **kwargs)
- class flax.experimental.nnx.metrics.Average(*args, **kwargs)
- class flax.experimental.nnx.metrics.Accuracy(*args, **kwargs)
- class flax.experimental.nnx.metrics.MultiMetric(*args, **kwargs)
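Stores multiple named metrics and updates them in a single call. Each keyword argument passed to the constructor names one metric, update() takes the keyword arguments the individual metrics need, and compute() returns a dict keyed by the metric names.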
Example usage:
>>> import jax, jax.numpy as jnp
>>> from flax.experimental import nnx
...
>>> logits = jax.random.normal(jax.random.key(0), (5, 2))
>>> labels = jnp.array([1, 1, 0, 1, 0])
>>> logits2 = jax.random.normal(jax.random.key(1), (5, 2))
>>> labels2 = jnp.array([0, 1, 1, 1, 1])
...
>>> batch_loss = jnp.array([1, 2, 3, 4])
>>> batch_loss2 = jnp.array([3, 2, 1, 0])
...
>>> metrics = nnx.MultiMetric(
...   accuracy=nnx.metrics.Accuracy(), loss=nnx.metrics.Average()
... )
>>> metrics.compute()
{'accuracy': Array(nan, dtype=float32), 'loss': Array(nan, dtype=float32)}
>>> metrics.update(logits=logits, labels=labels, values=batch_loss)
>>> metrics.compute()
{'accuracy': Array(0.6, dtype=float32), 'loss': Array(2.5, dtype=float32)}
>>> metrics.update(logits=logits2, labels=labels2, values=batch_loss2)
>>> metrics.compute()
{'accuracy': Array(0.7, dtype=float32), 'loss': Array(2., dtype=float32)}
>>> metrics.reset()
>>> metrics.compute()
{'accuracy': Array(nan, dtype=float32), 'loss': Array(nan, dtype=float32)}
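The numbers in the example follow from running totals: the loss of 2.5 is the mean of the first four values, 2.0 is the mean of all eight, and nan appears whenever nothing has been accumulated. The pure-Python sketch below illustrates that bookkeeping. It is not the library's implementation: the class names `RunningAverage`, `RunningAccuracy`, and `MultiMetricSketch` are hypothetical, and the logits are hand-picked stand-ins (the random logits from the example are not reproduced) chosen so that 3 of 5 and then 4 of 5 predictions are correct.

```python
import math


class RunningAverage:
    """Hypothetical stand-in for an averaging metric: keeps a total and a count."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.total = 0.0
        self.count = 0

    def update(self, *, values, **_):
        # Extra keyword arguments meant for other metrics are ignored.
        self.total += sum(values)
        self.count += len(values)

    def compute(self):
        # With nothing accumulated there is no mean to report, so return nan.
        return self.total / self.count if self.count else math.nan


class RunningAccuracy(RunningAverage):
    """Hypothetical stand-in: averages 1/0 correctness of argmax predictions."""

    def update(self, *, logits, labels, **_):
        preds = [max(range(len(row)), key=row.__getitem__) for row in logits]
        correct = [float(p == y) for p, y in zip(preds, labels)]
        super().update(values=correct)


class MultiMetricSketch:
    """Hypothetical container that forwards each call to every named metric."""

    def __init__(self, **metrics):
        self._metrics = metrics

    def reset(self):
        for metric in self._metrics.values():
            metric.reset()

    def update(self, **updates):
        for metric in self._metrics.values():
            metric.update(**updates)

    def compute(self):
        return {name: m.compute() for name, m in self._metrics.items()}


# Hand-picked logits (assumption): argmax gives [1, 0, 1, 1, 0] and [0, 1, 1, 1, 0].
logits = [[0.1, 0.9], [0.8, 0.2], [0.3, 0.7], [0.4, 0.6], [0.9, 0.1]]
labels = [1, 1, 0, 1, 0]
logits2 = [[0.7, 0.3], [0.2, 0.8], [0.4, 0.6], [0.1, 0.9], [0.6, 0.4]]
labels2 = [0, 1, 1, 1, 1]

metrics = MultiMetricSketch(accuracy=RunningAccuracy(), loss=RunningAverage())
empty = metrics.compute()

metrics.update(logits=logits, labels=labels, values=[1, 2, 3, 4])
first = metrics.compute()    # 3 of 5 correct, loss total 10 over 4 values

metrics.update(logits=logits2, labels=labels2, values=[3, 2, 1, 0])
second = metrics.compute()   # 7 of 10 correct, loss total 16 over 8 values

metrics.reset()
after_reset = metrics.compute()
```

In this sketch each metric accepts only the keyword arguments it needs and discards the rest, which is what allows a single `update` call to serve metrics with different inputs.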