[Fix] dynamic switch for uncommon resolutions #573

Merged · 6 commits · Jan 28, 2024
13 changes: 13 additions & 0 deletions examples/text_to_image_sdxl.py
@@ -108,3 +108,16 @@
                num_inference_steps=args.n_steps,
                output_type=OUTPUT_TYPE,
            ).images


print("Test run with other another uncommon resolution...")
if args.run_multiple_resolutions:
h = 544
w = 408
image = base(
prompt=args.prompt,
height=h,
width=w,
num_inference_steps=args.n_steps,
output_type=OUTPUT_TYPE,
).images
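
The added block only demonstrates the fix if the same compiled pipeline can be driven at more than one size. As a rough illustration (not part of this diff), the loop below reuses the compiled `base` pipeline at a few uncommon resolutions; `base`, `args`, and `OUTPUT_TYPE` are assumed to be defined earlier in examples/text_to_image_sdxl.py, and the extra sizes are arbitrary multiples of 8.

```python
# Hypothetical extension of the test above: switch the same compiled pipeline
# between several uncommon resolutions without recompiling.
for h, w in [(544, 408), (408, 544), (536, 792)]:
    image = base(
        prompt=args.prompt,
        height=h,
        width=w,
        num_inference_steps=args.n_steps,
        output_type=OUTPUT_TYPE,
    ).images
```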
6 changes: 6 additions & 0 deletions src/infer_compiler_registry/register_diffusers/__init__.py
@@ -10,6 +10,10 @@
from diffusers.models.attention_processor import LoRAAttnProcessor2_0
from diffusers.models.unet_2d_condition import UNet2DConditionModel
from diffusers.models.transformer_2d import Transformer2DModel
if diffusers_version >= version.parse("0.25.00"):
    from diffusers.models.upsampling import Upsample2D
else:
    from diffusers.models.resnet import Upsample2D
if diffusers_version >= version.parse("0.24.00"):
    from diffusers.models.resnet import SpatioTemporalResBlock
    from diffusers.models.transformer_temporal import TransformerSpatioTemporalModel
@@ -26,6 +30,7 @@
from .attention_processor_oflow import LoRAAttnProcessor2_0 as LoRAAttnProcessorOflow
from .unet_2d_condition_oflow import UNet2DConditionModel as UNet2DConditionModelOflow
from .transformer_2d_oflow import Transformer2DModel as Transformer2DModelOflow
from .unet_2d_blocks_oflow import Upsample2D as Upsample2DOflow
from .spatio_temporal_oflow import (
    SpatioTemporalResBlock as SpatioTemporalResBlockOflow,
)
@@ -57,5 +62,6 @@

torch2oflow_class_map.update({Transformer2DModel: Transformer2DModelOflow})
torch2oflow_class_map.update({UNet2DConditionModel: UNet2DConditionModelOflow})
torch2oflow_class_map.update({Upsample2D: Upsample2DOflow})

register(torch2oflow_class_map=torch2oflow_class_map)
106 changes: 106 additions & 0 deletions src/infer_compiler_registry/register_diffusers/unet_2d_blocks_oflow.py
@@ -0,0 +1,106 @@
import oneflow as torch
import oneflow.nn.functional as F
from oneflow import nn

import importlib.metadata
from packaging import version
from typing import Any, Dict, List, Optional, Tuple, Union

from onediff.infer_compiler.transform import transform_mgr

transformed_diffusers = transform_mgr.transform_package("diffusers")
diffusers_version = version.parse(importlib.metadata.version("diffusers"))

LoRACompatibleConv = transformed_diffusers.models.lora.LoRACompatibleConv

try:
    USE_PEFT_BACKEND = transformed_diffusers.utils.USE_PEFT_BACKEND
except Exception as e:
    USE_PEFT_BACKEND = False


class Upsample2D(nn.Module):
    """A 2D upsampling layer with an optional convolution."""

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        output_size: Optional[int] = None,
        scale: float = 1.0,
    ) -> torch.FloatTensor:
        assert hidden_states.shape[1] == self.channels

        if diffusers_version >= version.parse("0.25.00"):
            if self.norm is not None:
                hidden_states = self.norm(hidden_states.permute(0, 2, 3, 1)).permute(
                    0, 3, 1, 2
                )

        if self.use_conv_transpose:
            return self.conv(hidden_states)

        # Cast to float32 because the 'upsample_nearest2d_out_frame' op does not support bfloat16
        # TODO(Suraj): Remove this cast once the issue is fixed in PyTorch
        # https://github.com/pytorch/pytorch/issues/86679
        dtype = hidden_states.dtype
        if dtype == torch.bfloat16:
            hidden_states = hidden_states.to(torch.float32)

        # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
        if hidden_states.shape[0] >= 64:
            hidden_states = hidden_states.contiguous()

        if diffusers_version >= version.parse("0.25.00"):
            # if `output_size` is passed we force the interpolation output
            # size and do not make use of `scale_factor=2`
            if self.interpolate:
                if output_size is None:
                    hidden_states = F.interpolate(
                        hidden_states, scale_factor=2.0, mode="nearest"
                    )
                else:
                    # Rewritten to support dynamic switching between uncommon resolutions.
                    # hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")
                    hidden_states = F.interpolate_like(
                        hidden_states, like=output_size, mode="nearest"
                    )
        else:
            if output_size is None:
                hidden_states = F.interpolate(
                    hidden_states, scale_factor=2.0, mode="nearest"
                )
            else:
                # Rewritten to support dynamic switching between uncommon resolutions.
                # hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")
                hidden_states = F.interpolate_like(
                    hidden_states, like=output_size, mode="nearest"
                )

        # If the input is bfloat16, we cast back to bfloat16
        if dtype == torch.bfloat16:
            hidden_states = hidden_states.to(dtype)

        # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
        if self.use_conv:
            if diffusers_version < version.parse("0.22.0"):
                if self.name == "conv":
                    hidden_states = self.conv(hidden_states)
                else:
                    hidden_states = self.Conv2d_0(hidden_states)
            else:
                if self.name == "conv":
                    if isinstance(self.conv, LoRACompatibleConv) and not USE_PEFT_BACKEND:
                        hidden_states = self.conv(hidden_states, scale)
                    else:
                        hidden_states = self.conv(hidden_states)
                else:
                    if (
                        isinstance(self.Conv2d_0, LoRACompatibleConv)
                        and not USE_PEFT_BACKEND
                    ):
                        hidden_states = self.Conv2d_0(hidden_states, scale)
                    else:
                        hidden_states = self.Conv2d_0(hidden_states)

        return hidden_states
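
The substantive change in this file is replacing `F.interpolate(..., size=output_size, ...)` with `F.interpolate_like(..., like=output_size, ...)`, so the output size is read from a reference tensor at run time rather than baked into the compiled graph. The snippet below is a minimal sketch of that difference, assuming oneflow exposes `F.interpolate_like` with the keyword signature used above; the tensors and shapes are invented for illustration.

```python
import oneflow as flow
import oneflow.nn.functional as F

# The target spatial size comes from a reference ("like") tensor instead of a
# static `size`, so one compiled graph can serve several resolutions.
x = flow.randn(1, 4, 34, 26)    # feature map at an odd intermediate resolution
ref = flow.randn(1, 4, 68, 51)  # skip tensor providing the target spatial size

y_static = F.interpolate(x, size=tuple(ref.shape[-2:]), mode="nearest")  # size fixed when traced
y_like = F.interpolate_like(x, like=ref, mode="nearest")                 # size follows `ref` at run time

assert tuple(y_static.shape) == tuple(y_like.shape) == tuple(ref.shape)
```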
@@ -100,7 +100,9 @@ def forward(
default_overall_up_factor = 2 ** self.num_upsamplers

# upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
forward_upsample_size = False
# forward_upsample_size = False
# always forward upsample_size so the upsamplers can interpolate to the skip tensor's size
forward_upsample_size = True
upsample_size = None

for dim in sample.shape[-2:]:
@@ -431,7 +433,9 @@
# if we have not reached the final block and need to forward the
# upsample size, we do it here
if not is_final_block and forward_upsample_size:
    upsample_size = down_block_res_samples[-1].shape[2:]
    # To support dynamic switching between uncommon resolutions, pass the skip tensor itself as the "like" reference.
    # upsample_size = down_block_res_samples[-1].shape[2:]
    upsample_size = down_block_res_samples[-1]

if (
    hasattr(upsample_block, "has_cross_attention")
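
Together with the Upsample2D rewrite, the two edits above change what gets forwarded: `forward_upsample_size` is now always True, and `upsample_size` is the down-block skip tensor itself rather than its `.shape[2:]`, so the upsampler can follow the skip's exact size through `interpolate_like`. The toy example below (not the real UNet; shapes chosen to mimic a 544x408 SDXL run, and again assuming oneflow's `F.interpolate_like` as used in this PR) shows why a plain 2x upsample is not enough at uncommon resolutions.

```python
import oneflow as flow
import oneflow.nn.functional as F

# For a 544x408 image the SDXL latent is 68x51; a stride-2 downsample gives 34x26,
# and blindly upsampling by 2x yields 68x52, one column too wide. Forwarding the
# skip tensor lets interpolate_like hit the exact target.
skip = flow.randn(1, 320, 68, 51)  # down-block residual at the uncommon resolution
x = flow.randn(1, 320, 34, 26)     # deeper feature map after a stride-2 downsample

by_scale = F.interpolate(x, scale_factor=2.0, mode="nearest")  # -> (..., 68, 52), mismatched
by_like = F.interpolate_like(x, like=skip, mode="nearest")     # -> (..., 68, 51), matches skip

print(tuple(by_scale.shape[-2:]), tuple(by_like.shape[-2:]))   # (68, 52) (68, 51)
```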