From 7330b8cf90b471093518f797de09b68e02cbbdf2 Mon Sep 17 00:00:00 2001 From: Haoyang Ma Date: Thu, 28 Dec 2023 11:10:56 +0800 Subject: [PATCH] fix an op convertion error (#469) Address this [todo](https://github.com/siliconflow/onediff/pull/458#discussion_r1436227401) --- recordconda | 0 .../transform/builtin_transform.py | 18 +++++++++--------- 2 files changed, 9 insertions(+), 9 deletions(-) create mode 100644 recordconda diff --git a/recordconda b/recordconda new file mode 100644 index 000000000..e69de29bb diff --git a/src/onediff/infer_compiler/transform/builtin_transform.py b/src/onediff/infer_compiler/transform/builtin_transform.py index c7f0c6c11..8ddc5c10b 100644 --- a/src/onediff/infer_compiler/transform/builtin_transform.py +++ b/src/onediff/infer_compiler/transform/builtin_transform.py @@ -9,7 +9,6 @@ import torch import oneflow as flow - from .manager import transform_mgr from ..utils.log_utils import logger from ..utils.patch_for_diffusers import diffusers_checker @@ -325,14 +324,15 @@ def _(mod: None, verbose=False): def _(mod: types.BuiltinFunctionType, verbose=False): if hasattr(mod, "__module__"): mod_name = None - #TODO: This solution is a compromise for now. - #TODO: Should register nn.linear later to solve it elegantly - if mod == torch._C._nn.linear: - return flow.nn.functional.linear - elif mod.__module__.startswith("torch._C._nn"): - mod_name = mod.__module__.replace( - "torch._C._nn", "oneflow._oneflow_internal._C" - ) + if mod.__module__.startswith("torch._C._nn"): + # The equivalence of mod inside torch._C._nn may be + # defined in flow.nn.functional + if getattr(flow.nn.functional, mod.__name__): + mod_name = "oneflow.nn.functional" + else: + mod_name = mod.__module__.replace( + "torch._C._nn", "oneflow._oneflow_internal._C" + ) elif mod.__module__.startswith("torch"): try: if getattr(torch.nn.functional, mod.__name__) == mod: