diff --git a/CHANGELOG.md b/CHANGELOG.md index 23b7d535..38b5af2c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -18,7 +18,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed -- +- Fixed issue with `is_overridden` falsely returning True when the parent method is wrapped ([#149](https://github.com/Lightning-AI/utilities/pull/149)) ## [0.9.0] - 2023-06-29 diff --git a/src/lightning_utilities/core/overrides.py b/src/lightning_utilities/core/overrides.py index f3eb5610..0eb74dcf 100644 --- a/src/lightning_utilities/core/overrides.py +++ b/src/lightning_utilities/core/overrides.py @@ -12,7 +12,7 @@ def is_overridden(method_name: str, instance: object, parent: Type[object]) -> b instance_attr = getattr(instance, method_name, None) if instance_attr is None: return False - # `functools.wraps()` support + # `functools.wraps()` and `@contextmanager` support if hasattr(instance_attr, "__wrapped__"): instance_attr = instance_attr.__wrapped__ # `Mock(wraps=...)` support @@ -28,5 +28,8 @@ def is_overridden(method_name: str, instance: object, parent: Type[object]) -> b parent_attr = getattr(parent, method_name, None) if parent_attr is None: raise ValueError("The parent should define the method") + # `@contextmanager` support + if hasattr(parent_attr, "__wrapped__"): + parent_attr = parent_attr.__wrapped__ return instance_attr.__code__ != parent_attr.__code__ diff --git a/tests/unittests/core/test_overrides.py b/tests/unittests/core/test_overrides.py index 9c580791..d1136b60 100644 --- a/tests/unittests/core/test_overrides.py +++ b/tests/unittests/core/test_overrides.py @@ -1,3 +1,4 @@ +from contextlib import contextmanager from functools import partial, wraps from typing import Any, Callable from unittest.mock import Mock @@ -16,6 +17,16 @@ def training_step(self): ... +class Strategy: + @contextmanager + def model_sharded_context(): + ... + + +class SingleDeviceStrategy(Strategy): + ... + + def test_is_overridden(): assert not is_overridden("whatever", object(), parent=LightningModule) @@ -65,3 +76,6 @@ def bar(self): model = BoringModel() model.training_step = partial(model.training_step) assert is_overridden("training_step", model, parent=LightningModule) + + # `@contextmanager` support + assert not is_overridden("model_sharded_context", SingleDeviceStrategy(), Strategy)