Skip to content

Commit

Permalink
added has_improved field to EarlyStopping
Browse files Browse the repository at this point in the history
  • Loading branch information
chiamp committed Oct 6, 2023
1 parent 97e2fea commit 21c41a5
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 16 deletions.
19 changes: 11 additions & 8 deletions flax/training/early_stopping.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class EarlyStopping(struct.PyTreeNode):
rng, input_rng = jax.random.split(rng)
optimizer, train_metrics = train_epoch(
optimizer, train_ds, config.batch_size, epoch, input_rng)
_, early_stop = early_stop.update(train_metrics['loss'])
early_stop = early_stop.update(train_metrics['loss'])
if early_stop.should_stop:
print('Met early stopping criteria, breaking...')
break
Expand All @@ -43,35 +43,38 @@ class EarlyStopping(struct.PyTreeNode):
patience_count: Number of steps since last improving update.
should_stop: Whether the training loop should stop to avoid
overfitting.
has_improved: Whether the metric has improved greater or
equal to the min_delta in the last `.update` call.
"""

min_delta: float = 0
patience: int = 0
best_metric: float = float('inf')
patience_count: int = 0
should_stop: bool = False
has_improved: bool = False

def reset(self):
return self.replace(
best_metric=float('inf'), patience_count=0, should_stop=False
best_metric=float('inf'), patience_count=0, should_stop=False, has_improved=False
)

def update(self, metric):
"""Update the state based on metric.
Returns:
A pair (has_improved, early_stop), where `has_improved` is True when there
was an improvement greater than `min_delta` from the previous
`best_metric` and `early_stop` is the updated `EarlyStop` object.
The updated EarlyStopping class. The `.has_improved` attribute is True
when there was an improvement greater than `min_delta` from the previous
`best_metric`.
"""

if (
math.isinf(self.best_metric)
or self.best_metric - metric > self.min_delta
):
return True, self.replace(best_metric=metric, patience_count=0)
return self.replace(best_metric=metric, patience_count=0, has_improved=True)
else:
should_stop = self.patience_count >= self.patience or self.should_stop
return False, self.replace(
patience_count=self.patience_count + 1, should_stop=should_stop
return self.replace(
patience_count=self.patience_count + 1, should_stop=should_stop, has_improved=False
)
16 changes: 8 additions & 8 deletions tests/early_stopping_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ def test_update(self):
improve_steps = 0
for step in range(10):
metric = 1.0
did_improve, es = es.update(metric)
if not did_improve:
es = es.update(metric)
if not es.has_improved:
improve_steps += 1
if es.should_stop:
break
Expand All @@ -51,15 +51,15 @@ def test_patience(self):
patient_es = early_stopping.EarlyStopping(min_delta=0, patience=6)
for step in range(10):
metric = 1.0
did_improve, es = es.update(metric)
es = es.update(metric)
if es.should_stop:
break

self.assertEqual(step, 1)

for patient_step in range(10):
metric = 1.0
did_improve, patient_es = patient_es.update(metric)
patient_es = patient_es.update(metric)
if patient_es.should_stop:
break

Expand All @@ -72,7 +72,7 @@ def test_delta(self):
metric = 1.0
for step in range(100):
metric -= 1e-4
did_improve, es = es.update(metric)
es = es.update(metric)
if es.should_stop:
break

Expand All @@ -81,7 +81,7 @@ def test_delta(self):
metric = 1.0
for step in range(100):
metric -= 1e-4
did_improve, delta_es = delta_es.update(metric)
delta_es = delta_es.update(metric)
if delta_es.should_stop:
break

Expand All @@ -102,8 +102,8 @@ def test_delta(self):
improvement_steps = 0
for step in range(10):
metric = metrics[step]
did_improve, delta_patient_es = delta_patient_es.update(metric)
if did_improve:
delta_patient_es = delta_patient_es.update(metric)
if delta_patient_es.has_improved:
improvement_steps += 1
if delta_patient_es.should_stop:
break
Expand Down

0 comments on commit 21c41a5

Please sign in to comment.