Skip to content

Commit

Permalink
Supporting obj is not an object. (#621)
Browse files Browse the repository at this point in the history
测试:
```shell
from onediff.infer_compiler.transform import torch2oflow
import torch 
mod = torch.nn.Linear 
# Convert to oneflow
of_mod = torch2oflow(mod)
print(of_mod) # <class 'oneflow.nn.modules.linear.Linear'>
x = 9 
print(f'{isinstance(x, type)=}') # isinstance(x, type)=False
```
  • Loading branch information
ccssu authored Feb 5, 2024
1 parent 42069f7 commit 8a52a7e
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 25 deletions.
2 changes: 2 additions & 0 deletions src/onediff/infer_compiler/import_tools/importer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ def is_need_mock(cls) -> bool:
assert isinstance(cls, (type, str))
main_pkg = cls.__module__.split(".")[0]
try:
if main_pkg == "torch":
return True
pkgs = requires(main_pkg)
except Exception as e:
return True
Expand Down
63 changes: 38 additions & 25 deletions src/onediff/infer_compiler/transform/builtin_transform.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Convert torch object to oneflow object."""

import os
import importlib
import types
Expand Down Expand Up @@ -105,7 +106,10 @@ def __getattribute__(self, attribute):
):
return flow.Generator()
elif (
isinstance(self._oflow_proxy_submod, (torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d))
isinstance(
self._oflow_proxy_submod,
(torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d),
)
and attribute == "channel_pos"
):
return "channels_first"
Expand Down Expand Up @@ -148,6 +152,13 @@ def torch2oflow(mod, *args, **kwargs):
return default_converter(mod, *args, **kwargs)


@torch2oflow.register
def _(mod: type):
if not is_need_mock(mod):
return mod
return proxy_class(mod)


def default_converter(obj, verbose=False, *, proxy_cls=None):
# Higher versions of diffusers might use torch.nn.modules.Linear
if obj is torch.nn.Linear:
Expand All @@ -162,7 +173,7 @@ def init(self):
for k, _ in obj.__dict__.items():
attr = getattr(obj, k)
self.__dict__[k] = torch2oflow(attr)

of_obj_cls = type(str(new_obj_cls), (new_obj_cls,), {"__init__": init})
of_obj = of_obj_cls()

Expand All @@ -174,11 +185,12 @@ def init(self):
# raise NotImplementedError(f"Unsupported type: {obj}")
return obj


@torch2oflow.register
def _(mod: torch.nn.Module, verbose=False):
proxy_md = ProxySubmodule(mod)
new_md_cls = proxy_class(type(mod))

def init(self):
nonlocal proxy_md

Expand Down Expand Up @@ -223,7 +235,7 @@ def proxy_getattr(self, attr):
str(new_md_cls), (new_md_cls,), {"__init__": init, "__getattr__": proxy_getattr}
)
of_mod = of_mod_cls()

if of_mod.training:
of_mod.training = False
if verbose:
Expand All @@ -244,7 +256,7 @@ def proxy_getattr(self, attr):
def _(mod: torch.nn.BatchNorm1d, verbose=False):
of_mod = torch2oflow.dispatch(torch.nn.Module)(mod, verbose)
of_mod.channel_axis = 1

return of_mod


Expand All @@ -254,7 +266,7 @@ def _(mod: torch.nn.BatchNorm2d, verbose=False):
if os.getenv("ONEFLOW_ENABLE_NHWC"):
of_mod.channel_axis = 3
else:
of_mod.channel_axis = 1
of_mod.channel_axis = 1

return of_mod

Expand All @@ -263,72 +275,72 @@ def _(mod: torch.nn.BatchNorm2d, verbose=False):
def _(mod: torch.nn.BatchNorm3d, verbose=False):
of_mod = torch2oflow.dispatch(torch.nn.Module)(mod, verbose)
of_mod.channel_axis = 1

return of_mod


@torch2oflow.register
def _(mod: torch.nn.MaxPool1d, verbose=False):
of_mod = torch2oflow.dispatch(torch.nn.Module)(mod, verbose)
of_mod.channel_pos = 'channels_first'
of_mod.channel_pos = "channels_first"

return of_mod


@torch2oflow.register
def _(mod: torch.nn.MaxPool2d, verbose=False):
of_mod = torch2oflow.dispatch(torch.nn.Module)(mod, verbose)
if os.getenv("ONEFLOW_ENABLE_NHWC"):
of_mod.channel_pos = 'channels_last'
of_mod.channel_pos = "channels_last"
else:
of_mod.channel_pos = 'channels_first'
of_mod.channel_pos = "channels_first"

return of_mod


@torch2oflow.register
def _(mod: torch.nn.MaxPool3d, verbose=False):
of_mod = torch2oflow.dispatch(torch.nn.Module)(mod, verbose)
of_mod.channel_pos = 'channels_first'
of_mod.channel_pos = "channels_first"

return of_mod


@torch2oflow.register
def _(mod: torch.nn.AvgPool1d, verbose=False):
of_mod = torch2oflow.dispatch(torch.nn.Module)(mod, verbose)
of_mod.channel_pos = 'channels_first'
of_mod.channel_pos = "channels_first"

return of_mod


@torch2oflow.register
def _(mod: torch.nn.AvgPool2d, verbose=False):
of_mod = torch2oflow.dispatch(torch.nn.Module)(mod, verbose)
if os.getenv("ONEFLOW_ENABLE_NHWC"):
of_mod.channel_pos = 'channels_last'
of_mod.channel_pos = "channels_last"
else:
of_mod.channel_pos = 'channels_first'
of_mod.channel_pos = "channels_first"

return of_mod


@torch2oflow.register
def _(mod: torch.nn.AvgPool3d, verbose=False):
of_mod = torch2oflow.dispatch(torch.nn.Module)(mod, verbose)
of_mod.channel_pos = 'channels_first'
of_mod.channel_pos = "channels_first"

return of_mod


@torch2oflow.register
def _(mod: torch.nn.AdaptiveAvgPool2d, verbose=False):
of_mod = torch2oflow.dispatch(torch.nn.Module)(mod, verbose)
if os.getenv("ONEFLOW_ENABLE_NHWC"):
of_mod.channel_pos = 'channels_last'
of_mod.channel_pos = "channels_last"
else:
of_mod.channel_pos = 'channels_first'
of_mod.channel_pos = "channels_first"

return of_mod


Expand All @@ -350,7 +362,7 @@ def _(mod: torch.nn.Sequential, verbose=False):
of_mod_list.append(submod)

of_mod_seq = proxy_class(type(mod))(*of_mod_list)

return of_mod_seq


Expand Down Expand Up @@ -419,6 +431,7 @@ def _(mod, verbose=False) -> Union[int, float, str, bool]:
def _(mod: None, verbose=False):
return mod


@torch2oflow.register
def _(mod: types.BuiltinFunctionType, verbose=False):
if hasattr(mod, "__module__"):
Expand Down

0 comments on commit 8a52a7e

Please sign in to comment.