"""Early stopping."""

import math
from flax import struct

class EarlyStopping(struct.PyTreeNode):
    """Early stopping to avoid overfitting during training.

    The metric is treated as lower-is-better: an update counts as an
    improvement only when it beats the best value seen so far
    (``best_metric``) by more than ``min_delta``.

    The following example sets ``should_stop`` once the loss has failed to
    improve on its best value by more than 1e-3 for 3 consecutive epochs
    (``patience=2`` tolerates 2 non-improving updates; the stop is signalled
    on the next one)::

        early_stop = EarlyStopping(min_delta=1e-3, patience=2)
        for epoch in range(1, num_epochs + 1):
            rng, input_rng = jax.random.split(rng)
            optimizer, train_metrics = train_epoch(
                optimizer, train_ds, config.batch_size, epoch, input_rng)
            _, early_stop = early_stop.update(train_metrics['loss'])
            if early_stop.should_stop:
                print('Met early stopping criteria, breaking...')
                break

    Attributes:
        min_delta: Minimum decrease of the metric relative to ``best_metric``
            for an update to be considered an improvement.
        patience: Number of consecutive non-improving updates tolerated; the
            next non-improving update after that sets ``should_stop``.
        best_metric: Best (lowest) metric value seen so far; ``inf`` until
            the first update.
        patience_count: Number of consecutive non-improving updates since the
            last improving update.
        should_stop: Whether the training loop should stop to avoid
            overfitting. Once set, it stays set until ``reset()`` is called.
    """

    min_delta: float = 0
    patience: int = 0
    best_metric: float = float('inf')
    patience_count: int = 0
    should_stop: bool = False

    def reset(self) -> 'EarlyStopping':
        """Return a copy with the tracking state cleared.

        ``min_delta`` and ``patience`` are kept; ``best_metric``,
        ``patience_count`` and ``should_stop`` go back to their defaults.
        """
        return self.replace(best_metric=float('inf'),
                            patience_count=0,
                            should_stop=False)

    def update(self, metric) -> 'tuple[bool, EarlyStopping]':
        """Update the state based on ``metric``.

        Args:
            metric: The newly observed metric value (lower is better).

        Returns:
            A pair ``(has_improved, early_stop)``, where ``has_improved`` is
            True when ``metric`` improved on the previous ``best_metric`` by
            more than ``min_delta`` (or ``best_metric`` was still infinite,
            as it is before the first update), and ``early_stop`` is the
            updated ``EarlyStopping`` object.
        """
        # While best_metric is infinite (e.g. the initial inf), every update
        # counts as an improvement -- this also covers an inf first metric,
        # where ``inf - inf`` would be nan and compare False.
        if (math.isinf(self.best_metric)
                or self.best_metric - metric > self.min_delta):
            return True, self.replace(best_metric=metric, patience_count=0)

        # The count is compared *before* it is incremented, so stopping is
        # signalled on the (patience + 1)-th consecutive non-improving
        # update. OR-ing with the previous flag keeps should_stop sticky.
        should_stop = self.patience_count >= self.patience or self.should_stop
        return False, self.replace(patience_count=self.patience_count + 1,
                                   should_stop=should_stop)