From 8a52a7e70367cc187f80e08b2507f00f894d3699 Mon Sep 17 00:00:00 2001
From: FengWen <109639975+ccssu@users.noreply.github.com>
Date: Mon, 5 Feb 2024 14:44:57 +0800
Subject: [PATCH] Supporting obj is not an object. (#621)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Test:

```python
from onediff.infer_compiler.transform import torch2oflow
import torch

mod = torch.nn.Linear
# Convert to oneflow
of_mod = torch2oflow(mod)
print(of_mod)

x = 9
print(f'{isinstance(x, type)=}')  # isinstance(x, type)=False
```
---
 .../infer_compiler/import_tools/importer.py |  2 +
 .../transform/builtin_transform.py          | 63 +++++++++++--------
 2 files changed, 40 insertions(+), 25 deletions(-)

diff --git a/src/onediff/infer_compiler/import_tools/importer.py b/src/onediff/infer_compiler/import_tools/importer.py
index dbac5f87a..1344e2168 100644
--- a/src/onediff/infer_compiler/import_tools/importer.py
+++ b/src/onediff/infer_compiler/import_tools/importer.py
@@ -15,6 +15,8 @@ def is_need_mock(cls) -> bool:
     assert isinstance(cls, (type, str))
     main_pkg = cls.__module__.split(".")[0]
     try:
+        if main_pkg == "torch":
+            return True
         pkgs = requires(main_pkg)
     except Exception as e:
         return True
diff --git a/src/onediff/infer_compiler/transform/builtin_transform.py b/src/onediff/infer_compiler/transform/builtin_transform.py
index 6be671d2f..cd0faf25b 100644
--- a/src/onediff/infer_compiler/transform/builtin_transform.py
+++ b/src/onediff/infer_compiler/transform/builtin_transform.py
@@ -1,4 +1,5 @@
 """Convert torch object to oneflow object."""
+
 import os
 import importlib
 import types
@@ -105,7 +106,10 @@ def __getattribute__(self, attribute):
         ):
             return flow.Generator()
         elif (
-            isinstance(self._oflow_proxy_submod, (torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d))
+            isinstance(
+                self._oflow_proxy_submod,
+                (torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d),
+            )
             and attribute == "channel_pos"
         ):
             return "channels_first"
@@ -148,6 +152,13 @@ def torch2oflow(mod, *args, **kwargs):
     return default_converter(mod, *args, **kwargs)
 
 
+@torch2oflow.register
+def _(mod: type):
+    if not is_need_mock(mod):
+        return mod
+    return proxy_class(mod)
+
+
 def default_converter(obj, verbose=False, *, proxy_cls=None):
     # Higher versions of diffusers might use torch.nn.modules.Linear
     if obj is torch.nn.Linear:
@@ -162,7 +173,7 @@ def init(self):
         for k, _ in obj.__dict__.items():
             attr = getattr(obj, k)
             self.__dict__[k] = torch2oflow(attr)
-    
+
     of_obj_cls = type(str(new_obj_cls), (new_obj_cls,), {"__init__": init})
     of_obj = of_obj_cls()
 
@@ -174,11 +185,12 @@ def init(self):
     #     raise NotImplementedError(f"Unsupported type: {obj}")
     return obj
 
+
 @torch2oflow.register
 def _(mod: torch.nn.Module, verbose=False):
     proxy_md = ProxySubmodule(mod)
     new_md_cls = proxy_class(type(mod))
-    
+
     def init(self):
         nonlocal proxy_md
 
@@ -223,7 +235,7 @@ def proxy_getattr(self, attr):
         str(new_md_cls), (new_md_cls,), {"__init__": init, "__getattr__": proxy_getattr}
     )
     of_mod = of_mod_cls()
-    
+
     if of_mod.training:
         of_mod.training = False
         if verbose:
@@ -244,7 +256,7 @@ def proxy_getattr(self, attr):
 def _(mod: torch.nn.BatchNorm1d, verbose=False):
     of_mod = torch2oflow.dispatch(torch.nn.Module)(mod, verbose)
     of_mod.channel_axis = 1
-    
+
     return of_mod
 
 
@@ -254,7 +266,7 @@ def _(mod: torch.nn.BatchNorm2d, verbose=False):
     of_mod = torch2oflow.dispatch(torch.nn.Module)(mod, verbose)
     if os.getenv("ONEFLOW_ENABLE_NHWC"):
         of_mod.channel_axis = 3
     else:
-        of_mod.channel_axis = 1 
+        of_mod.channel_axis = 1
 
     return of_mod
@@ -263,15 +275,15 @@ def _(mod: torch.nn.BatchNorm2d, verbose=False):
 def _(mod: torch.nn.BatchNorm3d, verbose=False):
     of_mod = torch2oflow.dispatch(torch.nn.Module)(mod, verbose)
     of_mod.channel_axis = 1
-    
+
     return of_mod
 
 
 @torch2oflow.register
 def _(mod: torch.nn.MaxPool1d, verbose=False):
     of_mod = torch2oflow.dispatch(torch.nn.Module)(mod, verbose)
-    of_mod.channel_pos = 'channels_first'
-    
+    of_mod.channel_pos = "channels_first"
+
     return of_mod
 
 
@@ -279,26 +291,26 @@ def _(mod: torch.nn.MaxPool1d, verbose=False):
 def _(mod: torch.nn.MaxPool2d, verbose=False):
     of_mod = torch2oflow.dispatch(torch.nn.Module)(mod, verbose)
     if os.getenv("ONEFLOW_ENABLE_NHWC"):
-        of_mod.channel_pos = 'channels_last'
+        of_mod.channel_pos = "channels_last"
     else:
-        of_mod.channel_pos = 'channels_first'
-    
+        of_mod.channel_pos = "channels_first"
+
     return of_mod
 
 
 @torch2oflow.register
 def _(mod: torch.nn.MaxPool3d, verbose=False):
     of_mod = torch2oflow.dispatch(torch.nn.Module)(mod, verbose)
-    of_mod.channel_pos = 'channels_first'
-    
+    of_mod.channel_pos = "channels_first"
+
     return of_mod
 
 
 @torch2oflow.register
 def _(mod: torch.nn.AvgPool1d, verbose=False):
     of_mod = torch2oflow.dispatch(torch.nn.Module)(mod, verbose)
-    of_mod.channel_pos = 'channels_first'
-    
+    of_mod.channel_pos = "channels_first"
+
     return of_mod
 
 
@@ -306,18 +318,18 @@ def _(mod: torch.nn.AvgPool1d, verbose=False):
 def _(mod: torch.nn.AvgPool2d, verbose=False):
     of_mod = torch2oflow.dispatch(torch.nn.Module)(mod, verbose)
     if os.getenv("ONEFLOW_ENABLE_NHWC"):
-        of_mod.channel_pos = 'channels_last'
+        of_mod.channel_pos = "channels_last"
     else:
-        of_mod.channel_pos = 'channels_first'
-    
+        of_mod.channel_pos = "channels_first"
+
     return of_mod
 
 
 @torch2oflow.register
 def _(mod: torch.nn.AvgPool3d, verbose=False):
     of_mod = torch2oflow.dispatch(torch.nn.Module)(mod, verbose)
-    of_mod.channel_pos = 'channels_first'
-    
+    of_mod.channel_pos = "channels_first"
+
     return of_mod
 
 
@@ -325,10 +337,10 @@ def _(mod: torch.nn.AvgPool3d, verbose=False):
 def _(mod: torch.nn.AdaptiveAvgPool2d, verbose=False):
     of_mod = torch2oflow.dispatch(torch.nn.Module)(mod, verbose)
     if os.getenv("ONEFLOW_ENABLE_NHWC"):
-        of_mod.channel_pos = 'channels_last'
+        of_mod.channel_pos = "channels_last"
     else:
-        of_mod.channel_pos = 'channels_first'
-    
+        of_mod.channel_pos = "channels_first"
+
     return of_mod
 
 
@@ -350,7 +362,7 @@ def _(mod: torch.nn.Sequential, verbose=False):
         of_mod_list.append(submod)
 
     of_mod_seq = proxy_class(type(mod))(*of_mod_list)
-    
+
     return of_mod_seq
 
 
@@ -419,6 +431,7 @@ def _(mod, verbose=False) -> Union[int, float, str, bool]:
 def _(mod: None, verbose=False):
     return mod
 
+
 @torch2oflow.register
 def _(mod: types.BuiltinFunctionType, verbose=False):
     if hasattr(mod, "__module__"):
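
Note (not part of the patch): the `@torch2oflow.register` / `torch2oflow.dispatch(...)` calls in the diff are the `functools.singledispatch` API, and singledispatch picks a handler from `type(arg)`. Passing a class such as `torch.nn.Linear` (whose type is the metaclass `type`) therefore hits the new `_(mod: type)` handler instead of the instance fallback, which is what the commit-message test demonstrates. Below is a minimal standalone sketch of that dispatch pattern under these assumptions; `convert` and `Linear` are hypothetical stand-ins, not onediff code.

```python
from functools import singledispatch


@singledispatch
def convert(obj):
    # Fallback: instances land here (stand-in for default_converter / the
    # torch.nn.Module handler in the real code).
    return f"converted instance of {type(obj).__name__}"


@convert.register
def _(obj: type):
    # A class dispatches here because type(SomeClass) is `type` (or a
    # metaclass derived from it) -- mirrors the patch's `_(mod: type)` handler.
    return f"converted class {obj.__name__}"


class Linear:  # hypothetical stand-in for torch.nn.Linear
    pass


print(convert(Linear))    # converted class Linear
print(convert(Linear()))  # converted instance of Linear
print(convert(9))         # converted instance of int -- matches isinstance(x, type)=False
```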