diff --git a/onediff_diffusers_extensions/setup.py b/onediff_diffusers_extensions/setup.py index 3c68b0107..e1350cd69 100644 --- a/onediff_diffusers_extensions/setup.py +++ b/onediff_diffusers_extensions/setup.py @@ -12,7 +12,7 @@ python_requires=">=3.7.0", install_requires=[ "transformers>=4.27.1", - "diffusers>=0.24.0,<=0.25.1", + "diffusers>=0.24.0,<=0.26.2", "accelerate", "torch", "onefx", diff --git a/src/infer_compiler_registry/register_diffusers/__init__.py b/src/infer_compiler_registry/register_diffusers/__init__.py index f0130d576..fc07f98ab 100644 --- a/src/infer_compiler_registry/register_diffusers/__init__.py +++ b/src/infer_compiler_registry/register_diffusers/__init__.py @@ -25,9 +25,15 @@ from diffusers.models.resnet import SpatioTemporalResBlock from diffusers.models.transformer_temporal import TransformerSpatioTemporalModel from diffusers.models.attention import TemporalBasicTransformerBlock - from diffusers.models.unet_spatio_temporal_condition import ( - UNetSpatioTemporalConditionModel, - ) + + if diffusers_version >= version.parse("0.26.00"): + from diffusers.models.unets.unet_spatio_temporal_condition import ( + UNetSpatioTemporalConditionModel, + ) + else: + from diffusers.models.unet_spatio_temporal_condition import ( + UNetSpatioTemporalConditionModel, + ) if diffusers_version >= version.parse("0.25.00"): from diffusers.models.autoencoders.autoencoder_kl_temporal_decoder import ( diff --git a/src/onediff/infer_compiler/transform/builtin_transform.py b/src/onediff/infer_compiler/transform/builtin_transform.py index cd0faf25b..0043a7df1 100644 --- a/src/onediff/infer_compiler/transform/builtin_transform.py +++ b/src/onediff/infer_compiler/transform/builtin_transform.py @@ -3,6 +3,7 @@ import os import importlib import types +import inspect from functools import singledispatch, partial from collections import OrderedDict from collections.abc import Iterable @@ -160,10 +161,6 @@ def _(mod: type): def default_converter(obj, verbose=False, *, proxy_cls=None): - # Higher versions of diffusers might use torch.nn.modules.Linear - if obj is torch.nn.Linear: - return flow.nn.Linear - if not is_need_mock(type(obj)): return obj try: @@ -211,7 +208,6 @@ def init(self): attr = getattr(proxy_md, k) try: self.__dict__[k] = torch2oflow(attr) - except Exception as e: logger.error(f"convert {type(attr)} failed: {e}") raise NotImplementedError(f"Unsupported type: {type(attr)}") diff --git a/src/onediff/infer_compiler/transform/custom_transform.py b/src/onediff/infer_compiler/transform/custom_transform.py index 32eb3a774..0d0e71f59 100644 --- a/src/onediff/infer_compiler/transform/custom_transform.py +++ b/src/onediff/infer_compiler/transform/custom_transform.py @@ -1,5 +1,6 @@ """A module for registering custom torch2oflow functions and classes.""" import inspect +import importlib.util from pathlib import Path from typing import Callable, Dict, List, Optional, Union from ..import_tools import import_module_from_path @@ -50,16 +51,19 @@ def import_module_safely(module_path, module_name): # compiler_registry_path registry_path = Path(__file__).parents[3] / "infer_compiler_registry" - import_module_safely(registry_path / "register_diffusers", "register_diffusers") + if importlib.util.find_spec("diffusers") is not None: + import_module_safely(registry_path / "register_diffusers", "register_diffusers") - import_module_safely( - registry_path / "register_onediff_quant", "register_onediff_quant" - ) + if importlib.util.find_spec("onediff_quant") is not None: + import_module_safely( + registry_path / "register_onediff_quant", "register_onediff_quant" + ) - import_module_safely( - registry_path / "register_diffusers_enterprise_lite", - "register_diffusers_enterprise_lite", - ) + if importlib.util.find_spec("diffusers_enterprise_lite") is not None: + import_module_safely( + registry_path / "register_diffusers_enterprise_lite", + "register_diffusers_enterprise_lite", + ) def ensure_list(obj): diff --git a/src/onediff/infer_compiler/transform/manager.py b/src/onediff/infer_compiler/transform/manager.py index 672e7fb49..03c2c8322 100644 --- a/src/onediff/infer_compiler/transform/manager.py +++ b/src/onediff/infer_compiler/transform/manager.py @@ -26,7 +26,7 @@ def __init__(self, debug_mode=False, tmp_dir="./output"): def _setup_logger(self): name = "ONEDIFF" - level = logging.DEBUG if self.debug_mode else logging.ERROR + level = logging.DEBUG if self.debug_mode else logging.WARNING logger.configure_logging(name=name, file_name=None, level=level, log_dir=None) self.logger = logger diff --git a/src/onediff/infer_compiler/utils/log_utils.py b/src/onediff/infer_compiler/utils/log_utils.py index cfb6ac42f..9e9892209 100644 --- a/src/onediff/infer_compiler/utils/log_utils.py +++ b/src/onediff/infer_compiler/utils/log_utils.py @@ -37,7 +37,7 @@ def configure_logging(self, name, level, log_dir=None, file_name=None): # Create a console formatter and add it to a console handler console_formatter = ColorFormatter( - fmt="%(levelname)s [%(asctime)s] - %(message)s", datefmt="%Y-%m-%d %H:%M:%S" + fmt="%(levelname)s [%(asctime)s] %(filename)s:%(lineno)d - %(message)s", datefmt="%Y-%m-%d %H:%M:%S" ) console_handler = logging.StreamHandler() console_handler.setFormatter(console_formatter)