Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix compatibility with peft and diffusers 0.26.1 #626

Merged
merged 7 commits into from
Feb 7, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion src/infer_compiler_registry/register_diffusers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,10 @@
from diffusers.models.resnet import SpatioTemporalResBlock
from diffusers.models.transformer_temporal import TransformerSpatioTemporalModel
from diffusers.models.attention import TemporalBasicTransformerBlock
from diffusers.models.unet_spatio_temporal_condition import UNetSpatioTemporalConditionModel
if diffusers_version >= version.parse("0.26.00"):
from diffusers.models.unets.unet_spatio_temporal_condition import UNetSpatioTemporalConditionModel
else:
from diffusers.models.unet_spatio_temporal_condition import UNetSpatioTemporalConditionModel

if diffusers_version >= version.parse("0.25.00"):
from diffusers.models.autoencoders.autoencoder_kl_temporal_decoder import TemporalDecoder
Expand Down
6 changes: 1 addition & 5 deletions src/onediff/infer_compiler/transform/builtin_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import os
import importlib
import types
import inspect
from functools import singledispatch, partial
from collections import OrderedDict
from collections.abc import Iterable
Expand Down Expand Up @@ -160,10 +161,6 @@ def _(mod: type):


def default_converter(obj, verbose=False, *, proxy_cls=None):
# Higher versions of diffusers might use torch.nn.modules.Linear
if obj is torch.nn.Linear:
return flow.nn.Linear

if not is_need_mock(type(obj)):
return obj
try:
Expand Down Expand Up @@ -211,7 +208,6 @@ def init(self):
attr = getattr(proxy_md, k)
try:
self.__dict__[k] = torch2oflow(attr)

except Exception as e:
logger.error(f"convert {type(attr)} failed: {e}")
raise NotImplementedError(f"Unsupported type: {type(attr)}")
Expand Down
20 changes: 12 additions & 8 deletions src/onediff/infer_compiler/transform/custom_transform.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""A module for registering custom torch2oflow functions and classes."""
import inspect
import importlib.util
from pathlib import Path
from typing import Callable, Dict, List, Optional, Union
from ..import_tools import import_module_from_path
Expand Down Expand Up @@ -50,16 +51,19 @@ def import_module_safely(module_path, module_name):
# compiler_registry_path
registry_path = Path(__file__).parents[3] / "infer_compiler_registry"

import_module_safely(registry_path / "register_diffusers", "register_diffusers")
if importlib.util.find_spec("diffusers") is not None:
import_module_safely(registry_path / "register_diffusers", "register_diffusers")

import_module_safely(
registry_path / "register_onediff_quant", "register_onediff_quant"
)
if importlib.util.find_spec("onediff_quant") is not None:
import_module_safely(
registry_path / "register_onediff_quant", "register_onediff_quant"
)

import_module_safely(
registry_path / "register_diffusers_enterprise_lite",
"register_diffusers_enterprise_lite",
)
if importlib.util.find_spec("diffusers_enterprise_lite") is not None:
import_module_safely(
registry_path / "register_diffusers_enterprise_lite",
"register_diffusers_enterprise_lite",
)


def ensure_list(obj):
Expand Down
2 changes: 1 addition & 1 deletion src/onediff/infer_compiler/transform/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def __init__(self, debug_mode=False, tmp_dir="./output"):

def _setup_logger(self):
name = "ONEDIFF"
level = logging.DEBUG if self.debug_mode else logging.ERROR
level = logging.DEBUG if self.debug_mode else logging.WARNING
logger.configure_logging(name=name, file_name=None, level=level, log_dir=None)
self.logger = logger

Expand Down
2 changes: 1 addition & 1 deletion src/onediff/infer_compiler/utils/log_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def configure_logging(self, name, level, log_dir=None, file_name=None):

# Create a console formatter and add it to a console handler
console_formatter = ColorFormatter(
fmt="%(levelname)s [%(asctime)s] - %(message)s", datefmt="%Y-%m-%d %H:%M:%S"
fmt="%(levelname)s [%(asctime)s] %(filename)s:%(lineno)d - %(message)s", datefmt="%Y-%m-%d %H:%M:%S"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

filename 为None 应该没必要加

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

有看打印出来的日志,filename不为none,还是有意义的

)
console_handler = logging.StreamHandler()
console_handler.setFormatter(console_formatter)
Expand Down
Loading