diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 8f80d7fa42f791..5a2888a12aafab 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -859,6 +859,10 @@ def flatten_output(output): f"serialized model {i}th output doesn't match model {i}th output for {model_class}", ) + # Avoid memory leak. Without this, each call increase RAM usage by ~20MB. + # (Even with this call, there are still memory leak by ~0.04MB) + self.clear_torch_jit_class_registry() + def test_headmasking(self): if not self.test_head_masking: return