Metrics
- class flax.nnx.metrics.Metric(*args, **kwargs)
Base class for metrics. Any class that subclasses Metric should implement the compute, reset and update methods.
- __init__()
- compute()
Compute and return the value of the Metric.
- reset()
Reset the Metric in place.
- update(**kwargs)
Update the Metric in place.
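The division of labour between the three methods can be illustrated with a plain-Python stand-in. RunningMax is a hypothetical class that does not subclass flax.nnx.metrics.Metric. It only shows the contract: reset and update mutate the object in place, and compute reads the current value without changing the state.

```python
import math


class RunningMax:
    """Hypothetical metric illustrating the compute/reset/update contract.

    A plain-Python stand-in: it does not subclass flax.nnx.metrics.Metric.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        # Return the metric to its empty state by mutating self in place.
        self.current_max = -math.inf

    def update(self, *, values, **_):
        # Fold a new batch into the running state, again in place.
        for v in values:
            self.current_max = max(self.current_max, v)

    def compute(self):
        # Read out the current value without changing the state.
        return self.current_max


metric = RunningMax()
metric.update(values=[1, 5, 2])
metric.update(values=[3, 4])
print(metric.compute())  # 5
metric.reset()
print(metric.compute())  # -inf
```

Because compute is a pure read, one object can be updated every step and queried only when a value is actually needed.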
- class flax.nnx.metrics.Average(*args, **kwargs)
Average metric. Computes the mean of all values passed to update() since the last reset().
Example usage:
>>> import jax.numpy as jnp
>>> from flax import nnx
>>> batch_loss = jnp.array([1, 2, 3, 4])
>>> batch_loss2 = jnp.array([3, 2, 1, 0])
>>> metrics = nnx.metrics.Average()
>>> metrics.compute()
Array(nan, dtype=float32)
>>> metrics.update(values=batch_loss)
>>> metrics.compute()
Array(2.5, dtype=float32)
>>> metrics.update(values=batch_loss2)
>>> metrics.compute()
Array(2., dtype=float32)
>>> metrics.reset()
>>> metrics.compute()
Array(nan, dtype=float32)
- __init__(argname='values')
Pass in a string naming the keyword argument that update() will read the new value from. For example, constructing the metric as avg = Average('test') lets you update it with avg.update(test=new_value).
Parameters:
argname – an optional string naming the keyword argument that update() will read the new value from. Defaults to 'values'.
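The argname mechanism lets several metrics draw different values from one shared set of keyword arguments. A minimal sketch in plain Python follows. LastValue is a hypothetical class, and the names 'loss' and 'grad_norm' are illustrative only. Each instance reads only kwargs[argname] and ignores the other keys.

```python
class LastValue:
    """Hypothetical metric that records the latest kwargs[argname]."""

    def __init__(self, argname='values'):
        self.argname = argname
        self.value = None

    def update(self, **kwargs):
        # Only the entry named by argname is read; other keys are ignored.
        self.value = kwargs[self.argname]


loss_metric = LastValue('loss')
norm_metric = LastValue('grad_norm')
step_outputs = {'loss': 0.25, 'grad_norm': 1.5}
for m in (loss_metric, norm_metric):
    m.update(**step_outputs)  # both receive the same keyword arguments
print(loss_metric.value, norm_metric.value)  # 0.25 1.5
```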
- compute()
Compute and return the average.
- reset()
Reset this Metric.
- update(**kwargs)
Update this Metric in place, using the value in kwargs[self.argname], where self.argname is set at construction.
Parameters:
**kwargs – keyword arguments that must contain a self.argname entry mapping to the value used to update this metric.
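The Average repr in the MultiMetric example shows total and count states. This suggests the average is kept as a running sum plus an element count. The plain-Python sketch below works under that assumption, and RunningAverage is a hypothetical name. One consequence is that the result is weighted by the number of elements. It therefore differs from the mean of per-batch means when batches have different sizes.

```python
import math


class RunningAverage:
    """Plain-Python sketch of a total/count running average (assumed design)."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.total = 0.0
        self.count = 0

    def update(self, values):
        # Accumulate the sum and the number of elements, not the batch mean.
        self.total += sum(values)
        self.count += len(values)

    def compute(self):
        # An empty metric has no meaningful mean, so report NaN,
        # consistent with the nan shown in the Average example.
        return self.total / self.count if self.count else math.nan


avg = RunningAverage()
print(avg.compute())      # nan
avg.update([1, 2, 3, 4])  # batch mean 2.5
avg.update([0, 0])        # batch mean 0.0
print(avg.compute())      # 10 / 6, not (2.5 + 0.0) / 2
```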
- class flax.nnx.metrics.Accuracy(*args, **kwargs)
Accuracy metric. This metric subclasses Average, so the two share the same reset and compute implementations. Unlike Average, Accuracy does not take a string argument at construction.
Example usage:
>>> from flax import nnx
>>> import jax, jax.numpy as jnp
>>> logits = jax.random.normal(jax.random.key(0), (5, 2))
>>> labels = jnp.array([1, 1, 0, 1, 0])
>>> logits2 = jax.random.normal(jax.random.key(1), (5, 2))
>>> labels2 = jnp.array([0, 1, 1, 1, 1])
>>> metrics = nnx.metrics.Accuracy()
>>> metrics.compute()
Array(nan, dtype=float32)
>>> metrics.update(logits=logits, labels=labels)
>>> metrics.compute()
Array(0.6, dtype=float32)
>>> metrics.update(logits=logits2, labels=labels2)
>>> metrics.compute()
Array(0.7, dtype=float32)
>>> metrics.reset()
>>> metrics.compute()
Array(nan, dtype=float32)
- update(*, logits, labels, **_)
Update this Metric in place. Any additional keyword arguments are accepted and ignored.
Parameters:
logits – the predicted activations output by the model. These are argmax-ed along the trailing dimension before being compared to the labels.
labels – the ground-truth integer labels.
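Since Accuracy subclasses Average, accuracy can be read as the running average of 0/1 "prediction was correct" indicators. The plain-Python sketch below is built on that reading. RunningAccuracy and argmax are hypothetical helpers. The sketch takes the argmax of each row, which is the trailing dimension for 2-D logits, compares it with the label, and swallows extra keywords through **_ as the documented signature does.

```python
def argmax(row):
    # Index of the largest activation in one row (first index on ties).
    return max(range(len(row)), key=lambda i: row[i])


class RunningAccuracy:
    """Plain-Python sketch: accuracy as a running average of 0/1 hits."""

    def __init__(self):
        self.total = 0
        self.count = 0

    def update(self, *, logits, labels, **_):
        # **_ absorbs unrelated keyword arguments, mirroring the documented signature.
        for row, label in zip(logits, labels):
            self.total += int(argmax(row) == label)
            self.count += 1

    def compute(self):
        return self.total / self.count if self.count else float('nan')


acc = RunningAccuracy()
acc.update(logits=[[0.1, 0.9], [2.0, -1.0], [0.3, 0.2]], labels=[1, 0, 1])
acc.update(logits=[[-1.0, 1.0]], labels=[1], values=[5.0])  # 'values' is ignored
print(acc.compute())  # 0.75 (3 correct out of 4)
```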
- class flax.nnx.metrics.Welford(*args, **kwargs)
Uses Welford’s algorithm to compute running statistics (mean, standard deviation and standard error of the mean) over a stream of data in a single pass.
Example usage:
>>> import jax.numpy as jnp
>>> from flax import nnx
>>> batch_loss = jnp.array([1, 2, 3, 4])
>>> batch_loss2 = jnp.array([3, 2, 1, 0])
>>> metrics = nnx.metrics.Welford()
>>> metrics.compute()
Statistics(mean=Array(0., dtype=float32), standard_error_of_mean=Array(nan, dtype=float32), standard_deviation=Array(nan, dtype=float32))
>>> metrics.update(values=batch_loss)
>>> metrics.compute()
Statistics(mean=Array(2.5, dtype=float32), standard_error_of_mean=Array(0.559017, dtype=float32), standard_deviation=Array(1.118034, dtype=float32))
>>> metrics.update(values=batch_loss2)
>>> metrics.compute()
Statistics(mean=Array(2., dtype=float32), standard_error_of_mean=Array(0.43301272, dtype=float32), standard_deviation=Array(1.2247449, dtype=float32))
>>> metrics.reset()
>>> metrics.compute()
Statistics(mean=Array(0., dtype=float32), standard_error_of_mean=Array(nan, dtype=float32), standard_deviation=Array(nan, dtype=float32))
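Welford's algorithm keeps a running count, a running mean and m2, the sum of squared deviations from the mean. It updates them one observation at a time, which avoids the cancellation error of the naive sum-of-squares formula. Below is a minimal element-by-element sketch in plain Python, where the function name welford is hypothetical. Flax's Welford operates on whole JAX arrays and may merge batches differently, and this sketch does not attempt to reproduce that. Feeding it both batches from the example gives mean 2 and m2 = 12, so m2 / 8 = 1.5 and sqrt(1.5) is about 1.2247, in line with the doctest's standard_deviation.

```python
import math


def welford(stream):
    """Welford's online algorithm: one pass, O(1) state (count, mean, m2)."""
    count, mean, m2 = 0, 0.0, 0.0
    for x in stream:
        count += 1
        delta = x - mean          # deviation from the old mean
        mean += delta / count     # updated mean
        m2 += delta * (x - mean)  # uses both the old and the new mean
    return count, mean, m2


count, mean, m2 = welford([1, 2, 3, 4, 3, 2, 1, 0])
variance = m2 / count  # population (divide-by-n) variance
print(mean, math.sqrt(variance))  # approximately 2.0 and 1.2247
```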
- __init__(argname='values')
Pass in a string naming the keyword argument that update() will read the new value from. For example, constructing the metric as wf = Welford('test') lets you update it with wf.update(test=new_value).
Parameters:
argname – an optional string naming the keyword argument that update() will read the new value from. Defaults to 'values'.
- compute()
Compute and return the mean, standard error of the mean and standard deviation in a Statistics dataclass object.
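Assume the state consists of a count n, a mean and m2, the sum of squared deviations. Under that assumption the two spread fields follow from standard_deviation = sqrt(m2 / n) and standard_error_of_mean = standard_deviation / sqrt(n). The divide-by-n (population) form is an inference. It is the form that reproduces the doctest numbers: for [1, 2, 3, 4], m2 = 5, so sqrt(5 / 4) is about 1.118034 and 1.118034 / 2 is about 0.559017. Stats and stats_from_state below are hypothetical stand-ins, not the library's Statistics class.

```python
import math
from typing import NamedTuple


class Stats(NamedTuple):
    """Hypothetical stand-in for the Statistics object returned by compute()."""
    mean: float
    standard_error_of_mean: float
    standard_deviation: float


def stats_from_state(count, mean, m2):
    if count == 0:
        # Nothing observed yet: mean stays at 0, spread is undefined.
        return Stats(0.0, math.nan, math.nan)
    std = math.sqrt(m2 / count)  # population standard deviation
    return Stats(mean, std / math.sqrt(count), std)


# State after the first batch [1, 2, 3, 4]: n = 4, mean = 2.5, m2 = 5.0.
print(stats_from_state(4, 2.5, 5.0))  # approximately (2.5, 0.559017, 1.118034)
```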
- reset()
Reset this Metric.
- update(**kwargs)
Update this Metric in place, using the value in kwargs[self.argname], where self.argname is set at construction.
Parameters:
**kwargs – keyword arguments that must contain a self.argname entry mapping to the value used to update this metric.
- class flax.nnx.metrics.MultiMetric(*args, **kwargs)
Container that stores multiple metrics and updates them all with a single call.
Example usage:
>>> from flax import nnx
>>> import jax, jax.numpy as jnp
>>> metrics = nnx.MultiMetric(
...     accuracy=nnx.metrics.Accuracy(), loss=nnx.metrics.Average()
... )
>>> metrics
MultiMetric(
  accuracy=Accuracy(
    argname='values',
    total=MetricState(
      value=Array(0., dtype=float32)
    ),
    count=MetricState(
      value=Array(0, dtype=int32)
    )
  ),
  loss=Average(
    argname='values',
    total=MetricState(
      value=Array(0., dtype=float32)
    ),
    count=MetricState(
      value=Array(0, dtype=int32)
    )
  )
)
>>> metrics.accuracy
Accuracy(
  argname='values',
  total=MetricState(
    value=Array(0., dtype=float32)
  ),
  count=MetricState(
    value=Array(0, dtype=int32)
  )
)
>>> metrics.loss
Average(
  argname='values',
  total=MetricState(
    value=Array(0., dtype=float32)
  ),
  count=MetricState(
    value=Array(0, dtype=int32)
  )
)
>>> logits = jax.random.normal(jax.random.key(0), (5, 2))
>>> labels = jnp.array([1, 1, 0, 1, 0])
>>> logits2 = jax.random.normal(jax.random.key(1), (5, 2))
>>> labels2 = jnp.array([0, 1, 1, 1, 1])
>>> batch_loss = jnp.array([1, 2, 3, 4])
>>> batch_loss2 = jnp.array([3, 2, 1, 0])
>>> metrics.compute()
{'accuracy': Array(nan, dtype=float32), 'loss': Array(nan, dtype=float32)}
>>> metrics.update(logits=logits, labels=labels, values=batch_loss)
>>> metrics.compute()
{'accuracy': Array(0.6, dtype=float32), 'loss': Array(2.5, dtype=float32)}
>>> metrics.update(logits=logits2, labels=labels2, values=batch_loss2)
>>> metrics.compute()
{'accuracy': Array(0.7, dtype=float32), 'loss': Array(2., dtype=float32)}
>>> metrics.reset()
>>> metrics.compute()
{'accuracy': Array(nan, dtype=float32), 'loss': Array(nan, dtype=float32)}
- __init__(**metrics)
Pass the metrics to the constructor as keyword arguments, e.g. MultiMetric(keyword1=Average(), keyword2=Accuracy(), ...).
Parameters:
**metrics – the keyword arguments whose names are used to access the corresponding Metric (e.g. metrics.accuracy) and to label its value in the output of compute().
- compute()
Compute and return the values of all underlying Metrics. The result is a dictionary that maps each name (the keyword arguments **metrics passed to the constructor) to the corresponding metric value.
- reset()
Reset all underlying Metrics.
- update(**updates)
Update all underlying Metrics in this MultiMetric in place. All **updates are passed to the update method of every underlying Metric.
Parameters:
**updates – the keyword arguments that will be passed to each underlying Metric's update method.
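The forwarding pattern of MultiMetric can be sketched in plain Python. Mean and Multi are hypothetical stand-ins, not Flax classes. Because every child receives every keyword argument, each child must ignore the keys it does not use. The documented Accuracy.update(*, logits, labels, **_) signature does this. Giving each Average-like child its own argname keeps two averages from reading the same key.

```python
import math


class Mean:
    """Hypothetical Average-like metric that reads kwargs[argname]."""

    def __init__(self, argname='values'):
        self.argname = argname
        self.reset()

    def reset(self):
        self.total, self.count = 0.0, 0

    def update(self, **kwargs):
        vals = kwargs[self.argname]  # other keys are ignored
        self.total += sum(vals)
        self.count += len(vals)

    def compute(self):
        return self.total / self.count if self.count else math.nan


class Multi:
    """Plain-Python sketch of the MultiMetric pattern (not Flax's code)."""

    def __init__(self, **metrics):
        self._names = list(metrics)
        for name, metric in metrics.items():
            setattr(self, name, metric)  # allows attribute access, e.g. multi.loss

    def update(self, **updates):
        # Every child receives all keyword arguments and picks what it needs.
        for name in self._names:
            getattr(self, name).update(**updates)

    def compute(self):
        return {name: getattr(self, name).compute() for name in self._names}

    def reset(self):
        for name in self._names:
            getattr(self, name).reset()


multi = Multi(loss=Mean('loss'), score=Mean('score'))
multi.update(loss=[1.0, 3.0], score=[0.25, 0.75])
print(multi.compute())  # {'loss': 2.0, 'score': 0.5}
```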