diff --git a/examples/text_to_image_sdxl.py b/examples/text_to_image_sdxl.py index 67b194132..08b2bf483 100644 --- a/examples/text_to_image_sdxl.py +++ b/examples/text_to_image_sdxl.py @@ -108,3 +108,16 @@ num_inference_steps=args.n_steps, output_type=OUTPUT_TYPE, ).images + + +if args.run_multiple_resolutions: + print("Test run with another uncommon resolution...") + h = 544 + w = 408 + image = base( + prompt=args.prompt, + height=h, + width=w, + num_inference_steps=args.n_steps, + output_type=OUTPUT_TYPE, + ).images diff --git a/src/infer_compiler_registry/register_diffusers/__init__.py b/src/infer_compiler_registry/register_diffusers/__init__.py index 923b1f3e2..ddcc14756 100644 --- a/src/infer_compiler_registry/register_diffusers/__init__.py +++ b/src/infer_compiler_registry/register_diffusers/__init__.py @@ -10,6 +10,10 @@ from diffusers.models.attention_processor import LoRAAttnProcessor2_0 from diffusers.models.unet_2d_condition import UNet2DConditionModel from diffusers.models.transformer_2d import Transformer2DModel +if diffusers_version >= version.parse("0.25.00"): + from diffusers.models.upsampling import Upsample2D +else: + from diffusers.models.resnet import Upsample2D if diffusers_version >= version.parse("0.24.00"): from diffusers.models.resnet import SpatioTemporalResBlock from diffusers.models.transformer_temporal import TransformerSpatioTemporalModel @@ -26,6 +30,7 @@ from .attention_processor_oflow import LoRAAttnProcessor2_0 as LoRAAttnProcessorOflow from .unet_2d_condition_oflow import UNet2DConditionModel as UNet2DConditionModelOflow from .transformer_2d_oflow import Transformer2DModel as Transformer2DModelOflow +from .unet_2d_blocks_oflow import Upsample2D as Upsample2DOflow from .spatio_temporal_oflow import ( SpatioTemporalResBlock as SpatioTemporalResBlockOflow, ) @@ -57,5 +62,6 @@ torch2oflow_class_map.update({Transformer2DModel: Transformer2DModelOflow}) torch2oflow_class_map.update({UNet2DConditionModel: UNet2DConditionModelOflow}) 
+torch2oflow_class_map.update({Upsample2D: Upsample2DOflow}) register(torch2oflow_class_map=torch2oflow_class_map) diff --git a/src/infer_compiler_registry/register_diffusers/unet_2d_blocks_oflow.py b/src/infer_compiler_registry/register_diffusers/unet_2d_blocks_oflow.py new file mode 100644 index 000000000..d9b0f29c5 --- /dev/null +++ b/src/infer_compiler_registry/register_diffusers/unet_2d_blocks_oflow.py @@ -0,0 +1,106 @@ +import oneflow as torch +import oneflow.nn.functional as F +from oneflow import nn + +import importlib.metadata +from packaging import version +from typing import Any, Dict, List, Optional, Tuple, Union + +from onediff.infer_compiler.transform import transform_mgr + +transformed_diffusers = transform_mgr.transform_package("diffusers") +diffusers_version = version.parse(importlib.metadata.version("diffusers")) + +LoRACompatibleConv = transformed_diffusers.models.lora.LoRACompatibleConv + +try: + USE_PEFT_BACKEND = transformed_diffusers.utils.USE_PEFT_BACKEND +except Exception as e: + USE_PEFT_BACKEND = False + + +class Upsample2D(nn.Module): + """A 2D upsampling layer with an optional convolution. + """ + + def forward( + self, + hidden_states: torch.FloatTensor, + output_size: Optional[int] = None, + scale: float = 1.0, + ) -> torch.FloatTensor: + assert hidden_states.shape[1] == self.channels + + if diffusers_version >= version.parse("0.25.00"): + if self.norm is not None: + hidden_states = self.norm(hidden_states.permute(0, 2, 3, 1)).permute( + 0, 3, 1, 2 + ) + + if self.use_conv_transpose: + return self.conv(hidden_states) + + # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16 + # TODO(Suraj): Remove this cast once the issue is fixed in PyTorch + # https://github.com/pytorch/pytorch/issues/86679 + dtype = hidden_states.dtype + if dtype == torch.bfloat16: + hidden_states = hidden_states.to(torch.float32) + + # upsample_nearest_nhwc fails with large batch sizes. 
see https://github.com/huggingface/diffusers/issues/984 + if hidden_states.shape[0] >= 64: + hidden_states = hidden_states.contiguous() + + if diffusers_version >= version.parse("0.25.00"): + # if `output_size` is passed we force the interpolation output + # size and do not make use of `scale_factor=2` + if self.interpolate: + if output_size is None: + hidden_states = F.interpolate( + hidden_states, scale_factor=2.0, mode="nearest" + ) + else: + # Rewritten for the switching of uncommon resolutions. + # hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest") + hidden_states = F.interpolate_like( + hidden_states, like=output_size, mode="nearest" + ) + else: + if output_size is None: + hidden_states = F.interpolate( + hidden_states, scale_factor=2.0, mode="nearest" + ) + else: + # Rewritten for the switching of uncommon resolutions. + # hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest") + hidden_states = F.interpolate_like( + hidden_states, like=output_size, mode="nearest" + ) + + # If the input is bfloat16, we cast back to bfloat16 + if dtype == torch.bfloat16: + hidden_states = hidden_states.to(dtype) + + # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed + if self.use_conv: + if diffusers_version < version.parse("0.22.0"): + if self.name == "conv": + hidden_states = self.conv(hidden_states) + else: + hidden_states = self.Conv2d_0(hidden_states) + else: + if self.name == "conv": + if isinstance(self.conv, LoRACompatibleConv) and not USE_PEFT_BACKEND: + hidden_states = self.conv(hidden_states, scale) + else: + hidden_states = self.conv(hidden_states) + else: + if ( + isinstance(self.Conv2d_0, LoRACompatibleConv) + and not USE_PEFT_BACKEND + ): + hidden_states = self.Conv2d_0(hidden_states, scale) + else: + hidden_states = self.Conv2d_0(hidden_states) + + return hidden_states diff --git a/src/infer_compiler_registry/register_diffusers/unet_2d_condition_oflow.py 
b/src/infer_compiler_registry/register_diffusers/unet_2d_condition_oflow.py index 88c57e8ee..c37a64d68 100644 --- a/src/infer_compiler_registry/register_diffusers/unet_2d_condition_oflow.py +++ b/src/infer_compiler_registry/register_diffusers/unet_2d_condition_oflow.py @@ -100,7 +100,9 @@ def forward( default_overall_up_factor = 2 ** self.num_upsamplers # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` - forward_upsample_size = False + # forward_upsample_size = False + # interpolate through upsample_size + forward_upsample_size = True upsample_size = None for dim in sample.shape[-2:]: @@ -431,7 +433,9 @@ def forward( # if we have not reached the final block and need to forward the # upsample size, we do it here if not is_final_block and forward_upsample_size: - upsample_size = down_block_res_samples[-1].shape[2:] + # To support dynamic switching of special resolutions, pass a like tensor. + # upsample_size = down_block_res_samples[-1].shape[2:] + upsample_size = down_block_res_samples[-1] if ( hasattr(upsample_block, "has_cross_attention")