Skip to content

Commit

Permalink
fix compatibility with peft and diffusers 0.26.1 (#626)
Browse files Browse the repository at this point in the history
  • Loading branch information
chengzeyi authored Feb 7, 2024
1 parent 18e87a7 commit 4da2270
Show file tree
Hide file tree
Showing 6 changed files with 25 additions and 19 deletions.
2 changes: 1 addition & 1 deletion onediff_diffusers_extensions/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
python_requires=">=3.7.0",
install_requires=[
"transformers>=4.27.1",
"diffusers>=0.24.0,<=0.25.1",
"diffusers>=0.24.0,<=0.26.2",
"accelerate",
"torch",
"onefx",
Expand Down
12 changes: 9 additions & 3 deletions src/infer_compiler_registry/register_diffusers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,15 @@
from diffusers.models.resnet import SpatioTemporalResBlock
from diffusers.models.transformer_temporal import TransformerSpatioTemporalModel
from diffusers.models.attention import TemporalBasicTransformerBlock
from diffusers.models.unet_spatio_temporal_condition import (
UNetSpatioTemporalConditionModel,
)

if diffusers_version >= version.parse("0.26.00"):
from diffusers.models.unets.unet_spatio_temporal_condition import (
UNetSpatioTemporalConditionModel,
)
else:
from diffusers.models.unet_spatio_temporal_condition import (
UNetSpatioTemporalConditionModel,
)

if diffusers_version >= version.parse("0.25.00"):
from diffusers.models.autoencoders.autoencoder_kl_temporal_decoder import (
Expand Down
6 changes: 1 addition & 5 deletions src/onediff/infer_compiler/transform/builtin_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import os
import importlib
import types
import inspect
from functools import singledispatch, partial
from collections import OrderedDict
from collections.abc import Iterable
Expand Down Expand Up @@ -160,10 +161,6 @@ def _(mod: type):


def default_converter(obj, verbose=False, *, proxy_cls=None):
# Higher versions of diffusers might use torch.nn.modules.Linear
if obj is torch.nn.Linear:
return flow.nn.Linear

if not is_need_mock(type(obj)):
return obj
try:
Expand Down Expand Up @@ -211,7 +208,6 @@ def init(self):
attr = getattr(proxy_md, k)
try:
self.__dict__[k] = torch2oflow(attr)

except Exception as e:
logger.error(f"convert {type(attr)} failed: {e}")
raise NotImplementedError(f"Unsupported type: {type(attr)}")
Expand Down
20 changes: 12 additions & 8 deletions src/onediff/infer_compiler/transform/custom_transform.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""A module for registering custom torch2oflow functions and classes."""
import inspect
import importlib.util
from pathlib import Path
from typing import Callable, Dict, List, Optional, Union
from ..import_tools import import_module_from_path
Expand Down Expand Up @@ -50,16 +51,19 @@ def import_module_safely(module_path, module_name):
# compiler_registry_path
registry_path = Path(__file__).parents[3] / "infer_compiler_registry"

import_module_safely(registry_path / "register_diffusers", "register_diffusers")
if importlib.util.find_spec("diffusers") is not None:
import_module_safely(registry_path / "register_diffusers", "register_diffusers")

import_module_safely(
registry_path / "register_onediff_quant", "register_onediff_quant"
)
if importlib.util.find_spec("onediff_quant") is not None:
import_module_safely(
registry_path / "register_onediff_quant", "register_onediff_quant"
)

import_module_safely(
registry_path / "register_diffusers_enterprise_lite",
"register_diffusers_enterprise_lite",
)
if importlib.util.find_spec("diffusers_enterprise_lite") is not None:
import_module_safely(
registry_path / "register_diffusers_enterprise_lite",
"register_diffusers_enterprise_lite",
)


def ensure_list(obj):
Expand Down
2 changes: 1 addition & 1 deletion src/onediff/infer_compiler/transform/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def __init__(self, debug_mode=False, tmp_dir="./output"):

def _setup_logger(self):
name = "ONEDIFF"
level = logging.DEBUG if self.debug_mode else logging.ERROR
level = logging.DEBUG if self.debug_mode else logging.WARNING
logger.configure_logging(name=name, file_name=None, level=level, log_dir=None)
self.logger = logger

Expand Down
2 changes: 1 addition & 1 deletion src/onediff/infer_compiler/utils/log_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def configure_logging(self, name, level, log_dir=None, file_name=None):

# Create a console formatter and add it to a console handler
console_formatter = ColorFormatter(
fmt="%(levelname)s [%(asctime)s] - %(message)s", datefmt="%Y-%m-%d %H:%M:%S"
fmt="%(levelname)s [%(asctime)s] %(filename)s:%(lineno)d - %(message)s", datefmt="%Y-%m-%d %H:%M:%S"
)
console_handler = logging.StreamHandler()
console_handler.setFormatter(console_formatter)
Expand Down

0 comments on commit 4da2270

Please sign in to comment.