diff --git a/metaflow/extension_support.py b/metaflow/extension_support.py index 829889d5960..38b0480721b 100644 --- a/metaflow/extension_support.py +++ b/metaflow/extension_support.py @@ -924,15 +924,22 @@ def exec_module(self, module): class _OrigLoader(Loader): - def __init__(self, fullname, orig_loader, previously_loaded_module=None): + def __init__( + self, + fullname, + orig_loader, + previously_loaded_module=None, + previously_loaded_parent_module=None, + ): self._fullname = fullname self._orig_loader = orig_loader self._previously_loaded_module = previously_loaded_module + self._previously_loaded_parent_module = previously_loaded_parent_module def create_module(self, spec): _ext_debug( - "Loading original module '%s' (will be loaded at '%s')" - % (spec.name, self._fullname) + "Loading original module '%s' (will be loaded at '%s'); spec is %s" + % (spec.name, self._fullname, str(spec)) ) self._orig_name = spec.name return self._orig_loader.create_module(spec) @@ -948,15 +955,20 @@ def exec_module(self, module): module.__spec__.name = self._fullname module.__orig_name__ = module.__name__ module.__name__ = self._fullname + module.__package__ = module.__spec__.parent # assumption since 3.6 sys.modules[self._fullname] = module del sys.modules[self._orig_name] finally: # At this point, the original module is loaded with the original name. We - # want to replace it with previously_loaded_module if it exists and, in both - # cases, change the name to fullname (which has _orig in it) + # want to replace it with previously_loaded_module if it exists. We + # also replace the parent properly if self._previously_loaded_module: sys.modules[self._orig_name] = self._previously_loaded_module + if self._previously_loaded_parent_module: + sys.modules[ + ".".join(self._orig_name.split(".")[:-1]) + ] = self._previously_loaded_parent_module class _LazyFinder(MetaPathFinder): @@ -974,9 +986,17 @@ def __init__(self, handled): # the over-ridden module self._temp_excluded_prefix = set() + # This is used to determine if we should be searching in _orig modules. Basically, + # when a relative import is done from a module in _orig, we want to search in + # the _orig "tree" + self._orig_search_paths = set() + def find_spec(self, fullname, path, target=None): # If we are trying to load a shadowed module (ending in ._orig), we don't # say we handle it + _ext_debug( + "Looking for %s in %s with target %s" % (fullname, str(path), target) + ) if any([fullname.startswith(e) for e in self._temp_excluded_prefix]): return None @@ -998,16 +1018,30 @@ def find_spec(self, fullname, path, target=None): orig_idx = -1 if orig_idx > -1 and ".".join(name_parts[:orig_idx]) in self._handled: orig_name = ".".join(name_parts[:orig_idx] + name_parts[orig_idx + 1 :]) + parent_name = None + if orig_idx != len(name_parts) - 1: + # We have a parent module under the _orig portion so for example, if + # we load mymodule._orig.orig_submodule, our parent is mymodule._orig. + # However, since mymodule is currently shadowed, we need to reset + # the parent module properly. We know it is already loaded (since modules + # are loaded hierarchically) + parent_name = ".".join( + name_parts[:orig_idx] + name_parts[orig_idx + 1 : -1] + ) _ext_debug("Looking for original module '%s'" % orig_name) prefix = ".".join(name_parts[:orig_idx]) self._temp_excluded_prefix.add(prefix) # We also have to remove the module temporarily while we look for the # new spec since otherwise it returns the spec of that loaded module. # module is also restored *after* we call `create_module` in the loader - # otherwise it just returns None + # otherwise it just returns None. We also swap out the parent module so that + # the search can start from there. loaded_module = sys.modules.get(orig_name) if loaded_module: del sys.modules[orig_name] + parent_module = sys.modules.get(parent_name) if parent_name else None + if parent_module: + sys.modules[parent_name] = sys.modules[".".join([parent_name, "_orig"])] # This finds the spec that would have existed had we not added all our # _LazyFinders @@ -1018,13 +1052,32 @@ def find_spec(self, fullname, path, target=None): if not spec: return None + if spec.submodule_search_locations: + self._orig_search_paths.update(spec.submodule_search_locations) + _ext_debug("Found original spec %s" % spec) # Change the spec - spec.loader = _OrigLoader(fullname, spec.loader, loaded_module) + spec.loader = _OrigLoader( + fullname, + spec.loader, + loaded_module, + parent_module, + ) return spec + for p in path or []: + if p in self._orig_search_paths: + # We need to look in some of the "_orig" modules + orig_override_name = ".".join( + name_parts[:-1] + ["_orig", name_parts[-1]] + ) + _ext_debug( + "Looking for %s as an original module: searching for %s" + % (fullname, orig_override_name) + ) + return importlib.util.find_spec(orig_override_name) if len(name_parts) > 1: # This checks for submodules of things we handle. We check for the most # specific submodule match and use that diff --git a/test/core/metaflow_extensions/test_org/plugins/frameworks/__init__.py b/test/core/metaflow_extensions/test_org/plugins/frameworks/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/test/core/metaflow_extensions/test_org/plugins/frameworks/pytorch.py b/test/core/metaflow_extensions/test_org/plugins/frameworks/pytorch.py new file mode 100644 index 00000000000..0d9b69fa3c1 --- /dev/null +++ b/test/core/metaflow_extensions/test_org/plugins/frameworks/pytorch.py @@ -0,0 +1,8 @@ +from metaflow.plugins.frameworks._orig.pytorch import ( + PytorchParallelDecorator, + setup_torch_distributed, +) + + +class NewPytorchParallelDecorator(PytorchParallelDecorator): + pass diff --git a/test/core/metaflow_extensions/test_org/plugins/mfextinit_test_org.py b/test/core/metaflow_extensions/test_org/plugins/mfextinit_test_org.py index 69507dfddf6..4aa191affee 100644 --- a/test/core/metaflow_extensions/test_org/plugins/mfextinit_test_org.py +++ b/test/core/metaflow_extensions/test_org/plugins/mfextinit_test_org.py @@ -6,4 +6,4 @@ STEP_DECORATORS = [TestStepDecorator] -__mf_promote_submodules__ = ["nondecoplugin"] +__mf_promote_submodules__ = ["nondecoplugin", "frameworks"] diff --git a/test/core/tests/extensions.py b/test/core/tests/extensions.py index b760d605ea9..29e4cb472e9 100644 --- a/test/core/tests/extensions.py +++ b/test/core/tests/extensions.py @@ -16,6 +16,7 @@ def step_all(self): from metaflow.plugins.nondecoplugin import my_value from metaflow.exception import MetaflowTestException + from metaflow.plugins.frameworks.pytorch import NewPytorchParallelDecorator self.plugin_value = my_value self.tl_value = tl_value