From fb6a37d1a5e8f3f3eb4266e54b27eac3299215ae Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Wed, 27 Mar 2024 04:07:15 +0100 Subject: [PATCH 1/5] Fix monkeypatching of fabric modules --- src/lightning/fabric/wrappers.py | 2 +- tests/tests_fabric/test_wrappers.py | 7 +++++++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/src/lightning/fabric/wrappers.py b/src/lightning/fabric/wrappers.py index 11f1c67211e40..7c0f0d0720f08 100644 --- a/src/lightning/fabric/wrappers.py +++ b/src/lightning/fabric/wrappers.py @@ -263,7 +263,7 @@ def __setattr__(self, name: str, value: Any) -> None: original_has_attr = hasattr(original_module, name) # Can't use super().__getattr__ because nn.Module only checks _parameters, _buffers, and _modules # Can't use self.__getattr__ because it would pass through to the original module - fabric_has_attr = name in self.__dict__ + fabric_has_attr = hasattr(self, name) if not (original_has_attr or fabric_has_attr): setattr(original_module, name, value) diff --git a/tests/tests_fabric/test_wrappers.py b/tests/tests_fabric/test_wrappers.py index 3d6e47bffa8c5..512fa5861d5e9 100644 --- a/tests/tests_fabric/test_wrappers.py +++ b/tests/tests_fabric/test_wrappers.py @@ -170,6 +170,13 @@ def __init__(self): assert linear in fabric_module.modules() assert linear in original_module.modules() + # Check monkeypatching of methods + model = _FabricModule(Mock(), Mock()) + assert isinstance(model, _FabricModule) + original = id(model.forward) + model.forward = lambda *_: None + assert id(model.forward) != original + def test_fabric_module_state_dict_access(): """Test that state_dict access passes through to the original module.""" From ca34d079a202bb38735f4bd5c31f5e368e7719bf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Wed, 27 Mar 2024 04:10:37 +0100 Subject: [PATCH 2/5] CHANGELOG --- src/lightning/fabric/CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lightning/fabric/CHANGELOG.md b/src/lightning/fabric/CHANGELOG.md index 0c7af66bceb98..d88f2ec12827a 100644 --- a/src/lightning/fabric/CHANGELOG.md +++ b/src/lightning/fabric/CHANGELOG.md @@ -49,7 +49,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed an issue causing a TypeError when using `torch.compile` as a decorator ([#19627](https://github.com/Lightning-AI/pytorch-lightning/pull/19627)) -- +- Fixed issue where some model methods couldn't be monkeypatched after being Fabric wrapped ([#19705](https://github.com/Lightning-AI/pytorch-lightning/pull/19705)) - From 319ca42610e96a115b7a92cb4a8852b9b524836d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Wed, 27 Mar 2024 05:20:37 +0100 Subject: [PATCH 3/5] Alternative --- src/lightning/fabric/wrappers.py | 2 +- tests/tests_fabric/test_wrappers.py | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/src/lightning/fabric/wrappers.py b/src/lightning/fabric/wrappers.py index 7c0f0d0720f08..cefaa300a1611 100644 --- a/src/lightning/fabric/wrappers.py +++ b/src/lightning/fabric/wrappers.py @@ -263,7 +263,7 @@ def __setattr__(self, name: str, value: Any) -> None: original_has_attr = hasattr(original_module, name) # Can't use super().__getattr__ because nn.Module only checks _parameters, _buffers, and _modules # Can't use self.__getattr__ because it would pass through to the original module - fabric_has_attr = hasattr(self, name) + fabric_has_attr = name in self.__dict__ or name in type(self).__dict__ if not (original_has_attr or fabric_has_attr): setattr(original_module, name, value) diff --git a/tests/tests_fabric/test_wrappers.py b/tests/tests_fabric/test_wrappers.py index 512fa5861d5e9..7ec67d43646cc 100644 --- a/tests/tests_fabric/test_wrappers.py +++ b/tests/tests_fabric/test_wrappers.py @@ -155,6 +155,9 @@ def __init__(self): # Modify existing attribute on original_module fabric_module.attribute = 101 + # "attribute" is only in the original_module, so it shouldn't get set in the fabric_module + assert "attribute" not in fabric_module.__dict__ + assert fabric_module.attribute == 101 # returns it from original_module assert original_module.attribute == 101 # Check setattr of original_module @@ -172,7 +175,6 @@ def __init__(self): # Check monkeypatching of methods model = _FabricModule(Mock(), Mock()) - assert isinstance(model, _FabricModule) original = id(model.forward) model.forward = lambda *_: None assert id(model.forward) != original From d23fb1332998c12d85eae587fb43516179e187dd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Wed, 27 Mar 2024 05:38:47 +0100 Subject: [PATCH 4/5] dir() --- src/lightning/fabric/wrappers.py | 2 +- tests/tests_fabric/test_wrappers.py | 19 +++++++++++++++---- 2 files changed, 16 insertions(+), 5 deletions(-) diff --git a/src/lightning/fabric/wrappers.py b/src/lightning/fabric/wrappers.py index cefaa300a1611..093b355e2c376 100644 --- a/src/lightning/fabric/wrappers.py +++ b/src/lightning/fabric/wrappers.py @@ -263,7 +263,7 @@ def __setattr__(self, name: str, value: Any) -> None: original_has_attr = hasattr(original_module, name) # Can't use super().__getattr__ because nn.Module only checks _parameters, _buffers, and _modules # Can't use self.__getattr__ because it would pass through to the original module - fabric_has_attr = name in self.__dict__ or name in type(self).__dict__ + fabric_has_attr = name in dir(self) if not (original_has_attr or fabric_has_attr): setattr(original_module, name, value) diff --git a/tests/tests_fabric/test_wrappers.py b/tests/tests_fabric/test_wrappers.py index 7ec67d43646cc..a0e746538650f 100644 --- a/tests/tests_fabric/test_wrappers.py +++ b/tests/tests_fabric/test_wrappers.py @@ -174,10 +174,21 @@ def __init__(self): assert linear in original_module.modules() # Check monkeypatching of methods - model = _FabricModule(Mock(), Mock()) - original = id(model.forward) - model.forward = lambda *_: None - assert id(model.forward) != original + fabric_module = _FabricModule(Mock(), Mock()) + original = id(fabric_module.forward) + fabric_module.forward = lambda *_: None + assert id(fabric_module.forward) != original + # Check special methods + assert "__repr__" in dir(fabric_module) + assert "__repr__" not in fabric_module.__dict__ + assert "__repr__" not in _FabricModule.__dict__ + fabric_module.__repr__ = lambda *_: "test" + assert fabric_module.__repr__() == "test" + # needs to be monkeypatched on the class for `repr()` to take change + assert repr(fabric_module) == "_FabricModule()" + with mock.patch.object(_FabricModule, "__repr__", return_value="test"): + assert fabric_module.__repr__() == "test" + assert repr(fabric_module) == "test" def test_fabric_module_state_dict_access(): From a7c4cd69a17679a5eedb858bc56810d04f234297 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Wed, 27 Mar 2024 05:39:34 +0100 Subject: [PATCH 5/5] typo --- tests/tests_fabric/test_wrappers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/tests_fabric/test_wrappers.py b/tests/tests_fabric/test_wrappers.py index a0e746538650f..0923c601d51c3 100644 --- a/tests/tests_fabric/test_wrappers.py +++ b/tests/tests_fabric/test_wrappers.py @@ -184,7 +184,7 @@ def __init__(self): assert "__repr__" not in _FabricModule.__dict__ fabric_module.__repr__ = lambda *_: "test" assert fabric_module.__repr__() == "test" - # needs to be monkeypatched on the class for `repr()` to take change + # needs to be monkeypatched on the class for `repr()` to change assert repr(fabric_module) == "_FabricModule()" with mock.patch.object(_FabricModule, "__repr__", return_value="test"): assert fabric_module.__repr__() == "test"