diff --git a/eole/trainer.py b/eole/trainer.py index 16f5cf16..c136f4d7 100644 --- a/eole/trainer.py +++ b/eole/trainer.py @@ -469,6 +469,7 @@ def validate(self, valid_iter, moving_average=None): # Update statistics. stats.update(metric_stats) + valid_model.decoder._clear_cache() if moving_average: for param_data, param in zip(model_params_data, self.model.parameters()):