Enhance debug of transform error #645

Merged (6 commits) on Feb 15, 2024.

Changes from all commits:
@@ -13,9 +13,7 @@
 from onediffx.deep_cache import StableDiffusionPipeline

 parser = argparse.ArgumentParser()
-parser.add_argument(
-    "--base", type=str, default="runwayml/stable-diffusion-v1-5"
-)
+parser.add_argument("--base", type=str, default="runwayml/stable-diffusion-v1-5")
 parser.add_argument("--variant", type=str, default="fp16")
 parser.add_argument(
     "--prompt",
@@ -36,10 +36,7 @@

 # SDXL turbo base: AutoPipelineForText2Image
 base = AutoPipelineForText2Image.from_pretrained(
-    args.base,
-    torch_dtype=torch.float16,
-    variant=args.variant,
-    use_safetensors=True,
+    args.base, torch_dtype=torch.float16, variant=args.variant, use_safetensors=True,
 )
 base.to("cuda")

@@ -568,12 +568,23 @@ def __init__(
             if kernel_size is None:
                 kernel_size = 4
             conv = nn.ConvTranspose2d(
-                channels, self.out_channels, kernel_size=kernel_size, stride=2, padding=padding, bias=bias
+                channels,
+                self.out_channels,
+                kernel_size=kernel_size,
+                stride=2,
+                padding=padding,
+                bias=bias,
             )
         elif use_conv:
             if kernel_size is None:
                 kernel_size = 3
-            conv = conv_cls(self.channels, self.out_channels, kernel_size=kernel_size, padding=padding, bias=bias)
+            conv = conv_cls(
+                self.channels,
+                self.out_channels,
+                kernel_size=kernel_size,
+                padding=padding,
+                bias=bias,
+            )

         # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
         if name == "conv":

@@ -590,7 +601,9 @@ def forward(
         assert hidden_states.shape[1] == self.channels

         if self.norm is not None:
-            hidden_states = self.norm(hidden_states.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
+            hidden_states = self.norm(hidden_states.permute(0, 2, 3, 1)).permute(
+                0, 3, 1, 2
+            )

         if self.use_conv_transpose:
             return self.conv(hidden_states)

@@ -610,7 +623,9 @@ def forward(
         # size and do not make use of `scale_factor=2`
         if self.interpolate:
             if output_size is None:
-                hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
+                hidden_states = F.interpolate(
+                    hidden_states, scale_factor=2.0, mode="nearest"
+                )
             else:
                 # Rewritten for the switching of uncommon resolutions.
                 # hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")

@@ -630,7 +645,10 @@
             else:
                 hidden_states = self.conv(hidden_states)
         else:
-            if isinstance(self.Conv2d_0, LoRACompatibleConv) and not USE_PEFT_BACKEND:
+            if (
+                isinstance(self.Conv2d_0, LoRACompatibleConv)
+                and not USE_PEFT_BACKEND
+            ):
                 hidden_states = self.Conv2d_0(hidden_states, scale)
             else:
                 hidden_states = self.Conv2d_0(hidden_states)
onediff_diffusers_extensions/setup.py (2 changes: 1 addition & 1 deletion)

@@ -12,7 +12,7 @@
     python_requires=">=3.7.0",
     install_requires=[
         "transformers>=4.27.1",
-        "diffusers>=0.24.0,<=0.26.2",
+        "diffusers>=0.24.0,<=0.25.1",
         "accelerate",
         "torch",
         "onefx",
src/onediff/infer_compiler/import_tools/importer.py (2 changes: 2 additions & 0 deletions)

@@ -7,6 +7,7 @@
 from pathlib import Path
 from importlib.metadata import requires
 from .format_utils import MockEntityNameFormatter
+from ..utils.log_utils import logger

 __all__ = ["import_module_from_path", "LazyMocker", "is_need_mock"]

@@ -19,6 +20,7 @@ def is_need_mock(cls) -> bool:
             return True
         pkgs = requires(main_pkg)
     except Exception as e:
+        logger.info(f"Error when checking need mock of package {main_pkg}: {e}")
         return True
     if pkgs:
         for pkg in pkgs:
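
The line added here gives the dependency check a breadcrumb before it silently falls back to mocking. A simplified, self-contained sketch of the same pattern follows; note that the real is_need_mock takes a class and derives main_pkg from it, and the final dependency test below is only illustrative:

import logging
from importlib.metadata import requires

logger = logging.getLogger(__name__)

def is_need_mock(main_pkg: str) -> bool:
    # If the package's dependency metadata cannot be read, log why at info
    # level and fall back to mocking rather than failing the import.
    try:
        pkgs = requires(main_pkg)
    except Exception as e:
        logger.info(f"Error when checking need mock of package {main_pkg}: {e}")
        return True
    # Illustrative stand-in for the real dependency inspection.
    return bool(pkgs) and any("torch" in p for p in pkgs)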
src/onediff/infer_compiler/transform/builtin_transform.py (14 changes: 10 additions & 4 deletions)

@@ -35,7 +35,10 @@ def wrapper(first_param, *args, **kwargs):
         nonlocal _warning_set

         before = first_param.__class__.__name__
-        result = dispatcher(first_param, *args, **kwargs)
+        try:
+            result = dispatcher(first_param, *args, **kwargs)
+        except Exception as e:
+            raise NotImplementedError(f"Transform failed of {type(first_param)}: {e}")
         after = result.__class__.__name__

         description = f"{before} transformed to {after}"

@@ -176,7 +179,6 @@ def init(self):
             return of_obj
         except Exception as e:
             logger.warning(f"Unsupported type: {type(obj)} {e=}")
-            # raise NotImplementedError(f"Unsupported type: {obj}")
             return obj

@@ -214,7 +216,8 @@ def proxy_getattr(self, attr):

     try:
         return super().__getattribute__(attr)
-    except:
+    except Exception as e:
+        logger.warning(f"{type(self)} getattr {attr} failed: {e}")
         if attr in self._modules:
             return self._modules[attr]
         if attr in self._parameters:

@@ -442,7 +445,10 @@ def _(mod: types.BuiltinFunctionType, verbose=False):
     try:
         if getattr(torch.nn.functional, mod.__name__) == mod:
             mod_name = "oneflow.nn.functional"
-    except:
+    except Exception as e:
+        logger.warning(
+            f"warning when get {mod.__name__} in torch.nn.functional: {e}"
+        )
         mod_name = mod.__module__.replace("torch", "oneflow")
     if mod_name is not None:
         m = importlib.import_module(mod_name)
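
This file is the core of the PR: a dispatch failure now surfaces as a NotImplementedError naming the type that failed to transform, and the bare except: clauses now log what went wrong. A minimal standalone sketch of the wrapper pattern follows, with illustrative names (debuggable_dispatch and the logger setup are assumptions for the sketch, not the module's actual API):

import logging

logger = logging.getLogger(__name__)

def debuggable_dispatch(dispatcher):
    # Wrap a torch->oneflow converter so a failure reports *which* type
    # could not be transformed, instead of a bare trace from inside dispatch.
    def wrapper(first_param, *args, **kwargs):
        before = first_param.__class__.__name__
        try:
            result = dispatcher(first_param, *args, **kwargs)
        except Exception as e:
            raise NotImplementedError(
                f"Transform failed of {type(first_param)}: {e}"
            ) from e
        logger.debug(f"{before} transformed to {result.__class__.__name__}")
        return result

    return wrapper

Chaining with "from e" preserves the original traceback, so the new error adds context without hiding the root cause.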
src/onediff/infer_compiler/transform/manager.py (3 changes: 2 additions & 1 deletion)

@@ -125,5 +125,6 @@ def transform_package(self, package_name):
                 "huggingface_hub.inference._text_generation"
             )

-        except ImportError:
+        except Exception as e:
+            logger.warning(f"Pydantic related warning: {e}.")
             pass
src/onediff/infer_compiler/utils/log_utils.py (3 changes: 2 additions & 1 deletion)

@@ -37,7 +37,8 @@ def configure_logging(self, name, level, log_dir=None, file_name=None):

         # Create a console formatter and add it to a console handler
         console_formatter = ColorFormatter(
-            fmt="%(levelname)s [%(asctime)s] %(filename)s:%(lineno)d - %(message)s", datefmt="%Y-%m-%d %H:%M:%S"
+            fmt="%(levelname)s [%(asctime)s] %(filename)s:%(lineno)d - %(message)s",
+            datefmt="%Y-%m-%d %H:%M:%S",
         )
         console_handler = logging.StreamHandler()
         console_handler.setFormatter(console_formatter)
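
Only the layout of the ColorFormatter call changed. For reference, here is what that fmt/datefmt pair produces with a plain stdlib logging.Formatter (the assumption being that ColorFormatter only layers coloring on top; the sample output line is illustrative):

import logging

formatter = logging.Formatter(
    fmt="%(levelname)s [%(asctime)s] %(filename)s:%(lineno)d - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)
handler = logging.StreamHandler()
handler.setFormatter(formatter)
log = logging.getLogger("demo")
log.addHandler(handler)
log.warning("transform fallback")
# e.g.: WARNING [2024-02-15 12:00:00] demo.py:12 - transform fallback
# (filename and line number reflect the calling script)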
src/onediff/infer_compiler/with_oneflow_compile.py (2 changes: 1 addition & 1 deletion)

@@ -33,7 +33,7 @@ def oneflow_module(self):

         logger.debug(f"Convert {type(self._torch_module)} ...")
         self._oneflow_module = torch2oflow(self._torch_module)
-        logger.debug(f"Convert {id(self._torch_module)=} done!")
+        logger.debug(f"Convert {type(self._torch_module)} done!")
         return self._oneflow_module

     @oneflow_module.deleter
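
This last change swaps an opaque object address for a readable class name, so the "..." and "done!" debug lines now bracket the conversion with matching text. A quick illustration of the difference:

import torch

m = torch.nn.Linear(4, 4)
print(f"Convert {id(m)=} done!")   # e.g. Convert id(m)=140231... done!  -- opaque
print(f"Convert {type(m)} done!")  # Convert <class 'torch.nn.modules.linear.Linear'> done!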