Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reduce false positive warnings when calling module methods in Fabric #17875

Merged
merged 4 commits into from
Jun 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion src/lightning/fabric/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,11 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed validation of parameters of `plugins.precision.MixedPrecision` ([#17687](https://github.com/Lightning-AI/lightning/pull/17687))

- Fixed an issue with hpu imports leading to performance degradation ([#17788](https://github.com/Lightning-AI/lightning/pull/17788))

- Fixed an issue with hpu imports leading to performance degradation ([#17788](https://github.com/Lightning-AI/lightning/pull/17788))


- Fixed the emission of a false-positive warning when calling a method on the Fabric-wrapped module that accepts no arguments ([#17875](https://github.com/Lightning-AI/lightning/pull/17875))


## [2.0.3] - 2023-06-07
Expand Down
6 changes: 5 additions & 1 deletion src/lightning/fabric/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,11 @@ def call_forward_module(*args: Any, **kwargs: Any) -> Any:
return call_forward_module

def _validate_method_access(self, name: str, attribute: Any) -> None:
if inspect.ismethod(attribute) and self._forward_module != self._original_module:
if (
inspect.ismethod(attribute)
and inspect.signature(attribute).parameters
and self._forward_module != self._original_module
):
warning_cache.warn(
f"You are calling the method `{type(self._original_module).__name__}.{name}()` from outside the"
" model. This will bypass the wrapper from the strategy and result in incorrect behavior in"
Expand Down
15 changes: 10 additions & 5 deletions tests/tests_fabric/test_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,12 @@ def test_fabric_module_method_lookup():
"""Test that access to methods warns about improper use when a wrapper from a strategy is involved."""

class OriginalModule(torch.nn.Module):
def method(self):
def method_no_args(self):
return 100

def method_with_args(self, arg, kwarg=1):
return 101

class ModuleWrapper(torch.nn.Module):
def __init__(self, module):
super().__init__()
Expand All @@ -91,16 +94,18 @@ def __init__(self, module):
fabric_module = _FabricModule(forward_module=original_module, precision=Mock(), original_module=original_module)
warning_cache.clear()
with no_warning_call(UserWarning):
assert fabric_module.method() == 100
assert fabric_module.method_with_args(0) == 101
assert not warning_cache

# Special case: original module wrapped by forward module: -> warn
# Special case: original module wrapped by forward module: -> warn if method accepts args
original_module = OriginalModule()
wrapped_module = ModuleWrapper(original_module)
fabric_module = _FabricModule(forward_module=wrapped_module, precision=Mock(), original_module=original_module)
warning_cache.clear()
with pytest.warns(UserWarning, match=r"You are calling the method `OriginalModule.method\(\)` from outside the"):
assert fabric_module.method() == 100
with no_warning_call(UserWarning):
assert fabric_module.method_no_args() == 100
with pytest.warns(UserWarning, match=r"You are calling the method `OriginalModule.method_with_args\(\)` from"):
assert fabric_module.method_with_args(0) == 101
warning_cache.clear()


Expand Down