From 15bf048fbf2480bc955d288508ca172e4fbfc915 Mon Sep 17 00:00:00 2001
From: lixiang007666 <88304454@qq.com>
Date: Sun, 4 Feb 2024 16:15:19 +0800
Subject: [PATCH 1/2] [Fix] DeepCache uncommon resolution base

---
 .../models/fast_unet_2d_condition.py          |   8 +-
 .../deep_cache/models/unet_2d_blocks.py       | 131 +++++++++++++++++-
 .../deep_cache/models/unet_2d_condition.py    |   9 +-
 3 files changed, 143 insertions(+), 5 deletions(-)

diff --git a/onediff_diffusers_extensions/onediffx/deep_cache/models/fast_unet_2d_condition.py b/onediff_diffusers_extensions/onediffx/deep_cache/models/fast_unet_2d_condition.py
index 57c938542..aa210bc0c 100644
--- a/onediff_diffusers_extensions/onediffx/deep_cache/models/fast_unet_2d_condition.py
+++ b/onediff_diffusers_extensions/onediffx/deep_cache/models/fast_unet_2d_condition.py
@@ -66,7 +66,9 @@ def forward(
         default_overall_up_factor = 2 ** self.unet_module.num_upsamplers

         # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
-        forward_upsample_size = False
+        # forward_upsample_size = False
+        # Always forward the upsample size so the up blocks can interpolate to uncommon resolutions.
+        forward_upsample_size = True
         upsample_size = None

         if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
@@ -340,7 +342,9 @@ def forward(
             # if we have not reached the final block and need to forward the
             # upsample size, we do it here
             if not is_final_block and forward_upsample_size:
-                upsample_size = down_block_res_samples[-1].shape[2:]
+                # To support dynamic switching between uncommon resolutions, pass a reference tensor instead of a static size.
+                # upsample_size = down_block_res_samples[-1].shape[2:]
+                upsample_size = down_block_res_samples[-1]

             if (
                 hasattr(upsample_block, "has_cross_attention")
diff --git a/onediff_diffusers_extensions/onediffx/deep_cache/models/unet_2d_blocks.py b/onediff_diffusers_extensions/onediffx/deep_cache/models/unet_2d_blocks.py
index 9ee49a1b4..2d09c41b8 100644
--- a/onediff_diffusers_extensions/onediffx/deep_cache/models/unet_2d_blocks.py
+++ b/onediff_diffusers_extensions/onediffx/deep_cache/models/unet_2d_blocks.py
@@ -39,7 +39,6 @@
     KDownsample2D,
     KUpsample2D,
     ResnetBlock2D,
-    Upsample2D,
 )

 from diffusers.models.transformer_2d import Transformer2DModel
@@ -65,6 +64,12 @@
     KCrossAttnUpBlock2D,
 )

+LoRACompatibleConv = diffusers.models.lora.LoRACompatibleConv
+
+try:
+    USE_PEFT_BACKEND = diffusers.utils.USE_PEFT_BACKEND
+except Exception as e:
+    USE_PEFT_BACKEND = False

 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

@@ -490,6 +495,130 @@ def get_up_block(
     raise ValueError(f"{up_block_type} does not exist.")


+class Upsample2D(nn.Module):
+    """A 2D upsampling layer with an optional convolution.
+
+    Parameters:
+        channels (`int`):
+            number of channels in the inputs and outputs.
+        use_conv (`bool`, default `False`):
+            option to use a convolution.
+        use_conv_transpose (`bool`, default `False`):
+            option to use a convolution transpose.
+        out_channels (`int`, optional):
+            number of output channels. Defaults to `channels`.
+        name (`str`, default `conv`):
+            name of the upsampling 2D layer.
+    """
+
+    def __init__(
+        self,
+        channels: int,
+        use_conv: bool = False,
+        use_conv_transpose: bool = False,
+        out_channels: Optional[int] = None,
+        name: str = "conv",
+        kernel_size: Optional[int] = None,
+        padding=1,
+        norm_type=None,
+        eps=None,
+        elementwise_affine=None,
+        bias=True,
+        interpolate=True,
+    ):
+        super().__init__()
+        self.channels = channels
+        self.out_channels = out_channels or channels
+        self.use_conv = use_conv
+        self.use_conv_transpose = use_conv_transpose
+        self.name = name
+        self.interpolate = interpolate
+        conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
+
+        if norm_type == "ln_norm":
+            self.norm = nn.LayerNorm(channels, eps, elementwise_affine)
+        elif norm_type == "rms_norm":
+            self.norm = RMSNorm(channels, eps, elementwise_affine)
+        elif norm_type is None:
+            self.norm = None
+        else:
+            raise ValueError(f"unknown norm_type: {norm_type}")
+
+        conv = None
+        if use_conv_transpose:
+            if kernel_size is None:
+                kernel_size = 4
+            conv = nn.ConvTranspose2d(
+                channels, self.out_channels, kernel_size=kernel_size, stride=2, padding=padding, bias=bias
+            )
+        elif use_conv:
+            if kernel_size is None:
+                kernel_size = 3
+            conv = conv_cls(self.channels, self.out_channels, kernel_size=kernel_size, padding=padding, bias=bias)
+
+        # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
+        if name == "conv":
+            self.conv = conv
+        else:
+            self.Conv2d_0 = conv
+
+    def forward(
+        self,
+        hidden_states: torch.FloatTensor,
+        output_size: Optional[int] = None,
+        scale: float = 1.0,
+    ) -> torch.FloatTensor:
+        assert hidden_states.shape[1] == self.channels
+
+        if self.norm is not None:
+            hidden_states = self.norm(hidden_states.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
+
+        if self.use_conv_transpose:
+            return self.conv(hidden_states)
+
+        # Cast to float32 as the 'upsample_nearest2d_out_frame' op does not support bfloat16
+        # TODO(Suraj): Remove this cast once the issue is fixed in PyTorch
+        # https://github.com/pytorch/pytorch/issues/86679
+        dtype = hidden_states.dtype
+        if dtype == torch.bfloat16:
+            hidden_states = hidden_states.to(torch.float32)
+
+        # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
+        if hidden_states.shape[0] >= 64:
+            hidden_states = hidden_states.contiguous()
+
+        # if `output_size` is passed we force the interpolation output
+        # size and do not make use of `scale_factor=2`
+        if self.interpolate:
+            if output_size is None:
+                hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
+            else:
+                # Rewritten to support switching between uncommon resolutions.
+                # hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")
+                hidden_states = F.interpolate_like(
+                    hidden_states, like=output_size, mode="nearest"
+                )
+
+        # If the input is bfloat16, we cast back to bfloat16
+        if dtype == torch.bfloat16:
+            hidden_states = hidden_states.to(dtype)
+
+        # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
+        if self.use_conv:
+            if self.name == "conv":
+                if isinstance(self.conv, LoRACompatibleConv) and not USE_PEFT_BACKEND:
+                    hidden_states = self.conv(hidden_states, scale)
+                else:
+                    hidden_states = self.conv(hidden_states)
+            else:
+                if isinstance(self.Conv2d_0, LoRACompatibleConv) and not USE_PEFT_BACKEND:
+                    hidden_states = self.Conv2d_0(hidden_states, scale)
+                else:
+                    hidden_states = self.Conv2d_0(hidden_states)
+
+        return hidden_states
+
+
 class CrossAttnDownBlock2D(nn.Module):
     def __init__(
         self,
diff --git a/onediff_diffusers_extensions/onediffx/deep_cache/models/unet_2d_condition.py b/onediff_diffusers_extensions/onediffx/deep_cache/models/unet_2d_condition.py
index 4ef05dbb5..4848841fc 100644
--- a/onediff_diffusers_extensions/onediffx/deep_cache/models/unet_2d_condition.py
+++ b/onediff_diffusers_extensions/onediffx/deep_cache/models/unet_2d_condition.py
@@ -789,7 +789,10 @@ def forward(
         default_overall_up_factor = 2**self.num_upsamplers

         # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
-        forward_upsample_size = False
+        # forward_upsample_size = False
+        # Always forward the upsample size so the up blocks can interpolate to uncommon resolutions.
+        forward_upsample_size = True
+
         upsample_size = None

         if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
@@ -1033,7 +1036,9 @@ def forward(
             # if we have not reached the final block and need to forward the
             # upsample size, we do it here
             if not is_final_block and forward_upsample_size:
-                upsample_size = down_block_res_samples[-1].shape[2:]
+                # To support dynamic switching between uncommon resolutions, pass a reference tensor instead of a static size.
+                # upsample_size = down_block_res_samples[-1].shape[2:]
+                upsample_size = down_block_res_samples[-1]

             if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
                 sample, current_record_f = upsample_block(

From 77c3ff0f09c549fb02abce7136602b24762f4ff5 Mon Sep 17 00:00:00 2001
From: lixiang007666 <88304454@qq.com>
Date: Sun, 4 Feb 2024 16:26:16 +0800
Subject: [PATCH 2/2] Refine

---
 .../onediffx/deep_cache/models/unet_3d_blocks.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/onediff_diffusers_extensions/onediffx/deep_cache/models/unet_3d_blocks.py b/onediff_diffusers_extensions/onediffx/deep_cache/models/unet_3d_blocks.py
index 0e5a124f1..553ad8577 100644
--- a/onediff_diffusers_extensions/onediffx/deep_cache/models/unet_3d_blocks.py
+++ b/onediff_diffusers_extensions/onediffx/deep_cache/models/unet_3d_blocks.py
@@ -12,8 +12,8 @@
     ResnetBlock2D,
     SpatioTemporalResBlock,
     TemporalConvLayer,
-    Upsample2D,
 )
+from .unet_2d_blocks import Upsample2D
 from diffusers.models.transformer_2d import Transformer2DModel
 from diffusers.models.transformer_temporal import (
     TransformerSpatioTemporalModel,
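The core of both patches is that `upsample_size` is now the skip-connection tensor itself rather than its `.shape[2:]`, so the rewritten `Upsample2D.forward` resolves the target size from that reference tensor at run time via `F.interpolate_like` instead of from a Python-side size tuple. A minimal plain-PyTorch sketch of the idea; the `interpolate_like` helper below is a hypothetical eager stand-in for the op used in the patch, and the shapes are illustrative only:

import torch
import torch.nn.functional as F

def interpolate_like(x: torch.Tensor, like: torch.Tensor, mode: str = "nearest") -> torch.Tensor:
    # Read the target spatial size from the reference tensor at call time,
    # so no fixed output size has to be recorded elsewhere.
    return F.interpolate(x, size=like.shape[-2:], mode=mode)

# With an odd-sized feature map (e.g. 63x63), one stride-2 downsample yields
# 32x32, and upsampling with scale_factor=2 would give 64x64, which no longer
# matches the 63x63 skip connection. Interpolating "like" the skip tensor
# restores the exact size.
skip = torch.randn(1, 320, 63, 63)    # skip connection from the down path
sample = torch.randn(1, 640, 32, 32)  # feature map to be upsampled
up = interpolate_like(sample, like=skip)
assert up.shape[-2:] == skip.shape[-2:]

Passing the tensor rather than its shape keeps Python-side size constants out of the traced computation, which is presumably what enables the dynamic switching between uncommon resolutions that the new comments describe.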