From 9bc8efdaca6742a54476f889763b801935489920 Mon Sep 17 00:00:00 2001 From: Romain Cledat Date: Mon, 24 Oct 2022 14:15:33 -0700 Subject: [PATCH 1/3] Fix `._orig` access for submodules for MF extensions This addresses an issue where submodules like `mymodule._orig.submodule` would not load properly. This will properly load things like `from .foofa import xyz` from a module or sub-module in the original location but currently does not work if the import path is absolute. --- metaflow/extension_support.py | 67 +++++++++++++++++++++++++++++++---- 1 file changed, 60 insertions(+), 7 deletions(-) diff --git a/metaflow/extension_support.py b/metaflow/extension_support.py index 829889d5960..38b0480721b 100644 --- a/metaflow/extension_support.py +++ b/metaflow/extension_support.py @@ -924,15 +924,22 @@ def exec_module(self, module): class _OrigLoader(Loader): - def __init__(self, fullname, orig_loader, previously_loaded_module=None): + def __init__( + self, + fullname, + orig_loader, + previously_loaded_module=None, + previously_loaded_parent_module=None, + ): self._fullname = fullname self._orig_loader = orig_loader self._previously_loaded_module = previously_loaded_module + self._previously_loaded_parent_module = previously_loaded_parent_module def create_module(self, spec): _ext_debug( - "Loading original module '%s' (will be loaded at '%s')" - % (spec.name, self._fullname) + "Loading original module '%s' (will be loaded at '%s'); spec is %s" + % (spec.name, self._fullname, str(spec)) ) self._orig_name = spec.name return self._orig_loader.create_module(spec) @@ -948,15 +955,20 @@ def exec_module(self, module): module.__spec__.name = self._fullname module.__orig_name__ = module.__name__ module.__name__ = self._fullname + module.__package__ = module.__spec__.parent # assumption since 3.6 sys.modules[self._fullname] = module del sys.modules[self._orig_name] finally: # At this point, the original module is loaded with the original name. We - # want to replace it with previously_loaded_module if it exists and, in both - # cases, change the name to fullname (which has _orig in it) + # want to replace it with previously_loaded_module if it exists. We + # also replace the parent properly if self._previously_loaded_module: sys.modules[self._orig_name] = self._previously_loaded_module + if self._previously_loaded_parent_module: + sys.modules[ + ".".join(self._orig_name.split(".")[:-1]) + ] = self._previously_loaded_parent_module class _LazyFinder(MetaPathFinder): @@ -974,9 +986,17 @@ def __init__(self, handled): # the over-ridden module self._temp_excluded_prefix = set() + # This is used to determine if we should be searching in _orig modules. Basically, + # when a relative import is done from a module in _orig, we want to search in + # the _orig "tree" + self._orig_search_paths = set() + def find_spec(self, fullname, path, target=None): # If we are trying to load a shadowed module (ending in ._orig), we don't # say we handle it + _ext_debug( + "Looking for %s in %s with target %s" % (fullname, str(path), target) + ) if any([fullname.startswith(e) for e in self._temp_excluded_prefix]): return None @@ -998,16 +1018,30 @@ def find_spec(self, fullname, path, target=None): orig_idx = -1 if orig_idx > -1 and ".".join(name_parts[:orig_idx]) in self._handled: orig_name = ".".join(name_parts[:orig_idx] + name_parts[orig_idx + 1 :]) + parent_name = None + if orig_idx != len(name_parts) - 1: + # We have a parent module under the _orig portion so for example, if + # we load mymodule._orig.orig_submodule, our parent is mymodule._orig. + # However, since mymodule is currently shadowed, we need to reset + # the parent module properly. We know it is already loaded (since modules + # are loaded hierarchically) + parent_name = ".".join( + name_parts[:orig_idx] + name_parts[orig_idx + 1 : -1] + ) _ext_debug("Looking for original module '%s'" % orig_name) prefix = ".".join(name_parts[:orig_idx]) self._temp_excluded_prefix.add(prefix) # We also have to remove the module temporarily while we look for the # new spec since otherwise it returns the spec of that loaded module. # module is also restored *after* we call `create_module` in the loader - # otherwise it just returns None + # otherwise it just returns None. We also swap out the parent module so that + # the search can start from there. loaded_module = sys.modules.get(orig_name) if loaded_module: del sys.modules[orig_name] + parent_module = sys.modules.get(parent_name) if parent_name else None + if parent_module: + sys.modules[parent_name] = sys.modules[".".join([parent_name, "_orig"])] # This finds the spec that would have existed had we not added all our # _LazyFinders @@ -1018,13 +1052,32 @@ def find_spec(self, fullname, path, target=None): if not spec: return None + if spec.submodule_search_locations: + self._orig_search_paths.update(spec.submodule_search_locations) + _ext_debug("Found original spec %s" % spec) # Change the spec - spec.loader = _OrigLoader(fullname, spec.loader, loaded_module) + spec.loader = _OrigLoader( + fullname, + spec.loader, + loaded_module, + parent_module, + ) return spec + for p in path or []: + if p in self._orig_search_paths: + # We need to look in some of the "_orig" modules + orig_override_name = ".".join( + name_parts[:-1] + ["_orig", name_parts[-1]] + ) + _ext_debug( + "Looking for %s as an original module: searching for %s" + % (fullname, orig_override_name) + ) + return importlib.util.find_spec(orig_override_name) if len(name_parts) > 1: # This checks for submodules of things we handle. We check for the most # specific submodule match and use that From 4665d5aa5d1616c2a09d82cab85dc7698aea08aa Mon Sep 17 00:00:00 2001 From: Tom Furmston Date: Tue, 15 Nov 2022 08:25:37 +0000 Subject: [PATCH 2/3] Add test that orig module is accessible (#1192) --- .../test_org/plugins/frameworks/__init__.py | 0 .../test_org/plugins/frameworks/pytorch.py | 5 +++++ .../test_org/plugins/mfextinit_test_org.py | 2 +- test/core/tests/extensions.py | 1 + 4 files changed, 7 insertions(+), 1 deletion(-) create mode 100644 test/core/metaflow_extensions/test_org/plugins/frameworks/__init__.py create mode 100644 test/core/metaflow_extensions/test_org/plugins/frameworks/pytorch.py diff --git a/test/core/metaflow_extensions/test_org/plugins/frameworks/__init__.py b/test/core/metaflow_extensions/test_org/plugins/frameworks/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/test/core/metaflow_extensions/test_org/plugins/frameworks/pytorch.py b/test/core/metaflow_extensions/test_org/plugins/frameworks/pytorch.py new file mode 100644 index 00000000000..23bfb96cdde --- /dev/null +++ b/test/core/metaflow_extensions/test_org/plugins/frameworks/pytorch.py @@ -0,0 +1,5 @@ +from metaflow.plugins.frameworks._orig.pytorch import PytorchParallelDecorator, setup_torch_distributed + + +class NewPytorchParallelDecorator(PytorchParallelDecorator): + pass diff --git a/test/core/metaflow_extensions/test_org/plugins/mfextinit_test_org.py b/test/core/metaflow_extensions/test_org/plugins/mfextinit_test_org.py index 69507dfddf6..4aa191affee 100644 --- a/test/core/metaflow_extensions/test_org/plugins/mfextinit_test_org.py +++ b/test/core/metaflow_extensions/test_org/plugins/mfextinit_test_org.py @@ -6,4 +6,4 @@ STEP_DECORATORS = [TestStepDecorator] -__mf_promote_submodules__ = ["nondecoplugin"] +__mf_promote_submodules__ = ["nondecoplugin", "frameworks"] diff --git a/test/core/tests/extensions.py b/test/core/tests/extensions.py index b760d605ea9..29e4cb472e9 100644 --- a/test/core/tests/extensions.py +++ b/test/core/tests/extensions.py @@ -16,6 +16,7 @@ def step_all(self): from metaflow.plugins.nondecoplugin import my_value from metaflow.exception import MetaflowTestException + from metaflow.plugins.frameworks.pytorch import NewPytorchParallelDecorator self.plugin_value = my_value self.tl_value = tl_value From 204f823a6b9287031c8d67c803da0081e1c85f1c Mon Sep 17 00:00:00 2001 From: Romain Cledat Date: Tue, 15 Nov 2022 09:44:59 -0800 Subject: [PATCH 3/3] Black format --- .../test_org/plugins/frameworks/pytorch.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/test/core/metaflow_extensions/test_org/plugins/frameworks/pytorch.py b/test/core/metaflow_extensions/test_org/plugins/frameworks/pytorch.py index 23bfb96cdde..0d9b69fa3c1 100644 --- a/test/core/metaflow_extensions/test_org/plugins/frameworks/pytorch.py +++ b/test/core/metaflow_extensions/test_org/plugins/frameworks/pytorch.py @@ -1,4 +1,7 @@ -from metaflow.plugins.frameworks._orig.pytorch import PytorchParallelDecorator, setup_torch_distributed +from metaflow.plugins.frameworks._orig.pytorch import ( + PytorchParallelDecorator, + setup_torch_distributed, +) class NewPytorchParallelDecorator(PytorchParallelDecorator):