diff --git a/src/lightning/fabric/CHANGELOG.md b/src/lightning/fabric/CHANGELOG.md index 52ee3f675fea3..4628184f4f603 100644 --- a/src/lightning/fabric/CHANGELOG.md +++ b/src/lightning/fabric/CHANGELOG.md @@ -120,7 +120,11 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed validation of parameters of `plugins.precision.MixedPrecision` ([#17687](https://github.com/Lightning-AI/lightning/pull/17687)) -- Fixed an issue with hpu imports leading to performance degradation ([#17788](https://github.com/Lightning-AI/lightning/pull/17788)) + +- Fixed an issue with hpu imports leading to performance degradation ([#17788](https://github.com/Lightning-AI/lightning/pull/17788)) + + +- Fixed the emission of a false-positive warning when calling a method on the Fabric-wrapped module that accepts no arguments ([#17875](https://github.com/Lightning-AI/lightning/pull/17875)) ## [2.0.3] - 2023-06-07 diff --git a/src/lightning/fabric/wrappers.py b/src/lightning/fabric/wrappers.py index 5375ac9c1ad27..e9dee26c88a7a 100644 --- a/src/lightning/fabric/wrappers.py +++ b/src/lightning/fabric/wrappers.py @@ -161,7 +161,11 @@ def call_forward_module(*args: Any, **kwargs: Any) -> Any: return call_forward_module def _validate_method_access(self, name: str, attribute: Any) -> None: - if inspect.ismethod(attribute) and self._forward_module != self._original_module: + if ( + inspect.ismethod(attribute) + and inspect.signature(attribute).parameters + and self._forward_module != self._original_module + ): warning_cache.warn( f"You are calling the method `{type(self._original_module).__name__}.{name}()` from outside the" " model. This will bypass the wrapper from the strategy and result in incorrect behavior in" diff --git a/tests/tests_fabric/test_wrappers.py b/tests/tests_fabric/test_wrappers.py index ec8bf01bee36c..79b0a33c576f5 100644 --- a/tests/tests_fabric/test_wrappers.py +++ b/tests/tests_fabric/test_wrappers.py @@ -78,9 +78,12 @@ def test_fabric_module_method_lookup(): """Test that access to methods warns about improper use when a wrapper from a strategy is involved.""" class OriginalModule(torch.nn.Module): - def method(self): + def method_no_args(self): return 100 + def method_with_args(self, arg, kwarg=1): + return 101 + class ModuleWrapper(torch.nn.Module): def __init__(self, module): super().__init__() @@ -91,16 +94,18 @@ def __init__(self, module): fabric_module = _FabricModule(forward_module=original_module, precision=Mock(), original_module=original_module) warning_cache.clear() with no_warning_call(UserWarning): - assert fabric_module.method() == 100 + assert fabric_module.method_with_args(0) == 101 assert not warning_cache - # Special case: original module wrapped by forward module: -> warn + # Special case: original module wrapped by forward module: -> warn if method accepts args original_module = OriginalModule() wrapped_module = ModuleWrapper(original_module) fabric_module = _FabricModule(forward_module=wrapped_module, precision=Mock(), original_module=original_module) warning_cache.clear() - with pytest.warns(UserWarning, match=r"You are calling the method `OriginalModule.method\(\)` from outside the"): - assert fabric_module.method() == 100 + with no_warning_call(UserWarning): + assert fabric_module.method_no_args() == 100 + with pytest.warns(UserWarning, match=r"You are calling the method `OriginalModule.method_with_args\(\)` from"): + assert fabric_module.method_with_args(0) == 101 warning_cache.clear()