Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Fix] deepcache uncommon resolution base #620

Merged
merged 6 commits into from
Feb 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,9 @@ def forward(
default_overall_up_factor = 2 ** self.unet_module.num_upsamplers

# upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
forward_upsample_size = False
# forward_upsample_size = False
# interpolate through upsample_size
forward_upsample_size = True
upsample_size = None

if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
Expand Down Expand Up @@ -341,7 +343,9 @@ def forward(
# if we have not reached the final block and need to forward the
# upsample size, we do it here
if not is_final_block and forward_upsample_size:
upsample_size = down_block_res_samples[-1].shape[2:]
# To support dynamic switching of special resolutions, pass a like tensor.
# upsample_size = down_block_res_samples[-1].shape[2:]
upsample_size = down_block_res_samples[-1]

if (
hasattr(upsample_block, "has_cross_attention")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@
KDownsample2D,
KUpsample2D,
ResnetBlock2D,
Upsample2D,
)
from diffusers.models.transformer_2d import Transformer2DModel

Expand Down Expand Up @@ -69,6 +68,12 @@
KCrossAttnUpBlock2D,
)

LoRACompatibleConv = diffusers.models.lora.LoRACompatibleConv

try:
USE_PEFT_BACKEND = diffusers.utils.USE_PEFT_BACKEND
except Exception as e:
USE_PEFT_BACKEND = False

logger = logging.get_logger(__name__) # pylint: disable=invalid-name

Expand Down Expand Up @@ -509,6 +514,130 @@ def get_up_block(
raise ValueError(f"{up_block_type} does not exist.")


class Upsample2D(nn.Module):
"""A 2D upsampling layer with an optional convolution.

Parameters:
channels (`int`):
number of channels in the inputs and outputs.
use_conv (`bool`, default `False`):
option to use a convolution.
use_conv_transpose (`bool`, default `False`):
option to use a convolution transpose.
out_channels (`int`, optional):
number of output channels. Defaults to `channels`.
name (`str`, default `conv`):
name of the upsampling 2D layer.
"""

def __init__(
self,
channels: int,
use_conv: bool = False,
use_conv_transpose: bool = False,
out_channels: Optional[int] = None,
name: str = "conv",
kernel_size: Optional[int] = None,
padding=1,
norm_type=None,
eps=None,
elementwise_affine=None,
bias=True,
interpolate=True,
):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.use_conv_transpose = use_conv_transpose
self.name = name
self.interpolate = interpolate
conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv

if norm_type == "ln_norm":
self.norm = nn.LayerNorm(channels, eps, elementwise_affine)
elif norm_type == "rms_norm":
self.norm = RMSNorm(channels, eps, elementwise_affine)
elif norm_type is None:
self.norm = None
else:
raise ValueError(f"unknown norm_type: {norm_type}")

conv = None
if use_conv_transpose:
if kernel_size is None:
kernel_size = 4
conv = nn.ConvTranspose2d(
channels, self.out_channels, kernel_size=kernel_size, stride=2, padding=padding, bias=bias
)
elif use_conv:
if kernel_size is None:
kernel_size = 3
conv = conv_cls(self.channels, self.out_channels, kernel_size=kernel_size, padding=padding, bias=bias)

# TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
if name == "conv":
self.conv = conv
else:
self.Conv2d_0 = conv

def forward(
self,
hidden_states: torch.FloatTensor,
output_size: Optional[int] = None,
scale: float = 1.0,
) -> torch.FloatTensor:
assert hidden_states.shape[1] == self.channels

if self.norm is not None:
hidden_states = self.norm(hidden_states.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)

if self.use_conv_transpose:
return self.conv(hidden_states)

# Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16
# TODO(Suraj): Remove this cast once the issue is fixed in PyTorch
# https://github.com/pytorch/pytorch/issues/86679
dtype = hidden_states.dtype
if dtype == torch.bfloat16:
hidden_states = hidden_states.to(torch.float32)

# upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
if hidden_states.shape[0] >= 64:
hidden_states = hidden_states.contiguous()

# if `output_size` is passed we force the interpolation output
# size and do not make use of `scale_factor=2`
if self.interpolate:
if output_size is None:
hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
else:
# Rewritten for the switching of uncommon resolutions.
# hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")
hidden_states = F.interpolate_like(
hidden_states, like=output_size, mode="nearest"
)

# If the input is bfloat16, we cast back to bfloat16
if dtype == torch.bfloat16:
hidden_states = hidden_states.to(dtype)

# TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
if self.use_conv:
if self.name == "conv":
if isinstance(self.conv, LoRACompatibleConv) and not USE_PEFT_BACKEND:
hidden_states = self.conv(hidden_states, scale)
else:
hidden_states = self.conv(hidden_states)
else:
if isinstance(self.Conv2d_0, LoRACompatibleConv) and not USE_PEFT_BACKEND:
hidden_states = self.Conv2d_0(hidden_states, scale)
else:
hidden_states = self.Conv2d_0(hidden_states)

return hidden_states


class CrossAttnDownBlock2D(nn.Module):
def __init__(
self,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -885,7 +885,10 @@ def forward(
default_overall_up_factor = 2 ** self.num_upsamplers

# upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
forward_upsample_size = False
# forward_upsample_size = False
# interpolate through upsample_size
forward_upsample_size = True

upsample_size = None

if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
Expand Down Expand Up @@ -1179,7 +1182,9 @@ def forward(
# if we have not reached the final block and need to forward the
# upsample size, we do it here
if not is_final_block and forward_upsample_size:
upsample_size = down_block_res_samples[-1].shape[2:]
# To support dynamic switching of special resolutions, pass a like tensor.
# upsample_size = down_block_res_samples[-1].shape[2:]
upsample_size = down_block_res_samples[-1]

if (
hasattr(upsample_block, "has_cross_attention")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
ResnetBlock2D,
SpatioTemporalResBlock,
TemporalConvLayer,
Upsample2D,
)
from .unet_2d_blocks import Upsample2D
from diffusers.models.transformer_2d import Transformer2DModel
from diffusers.models.transformer_temporal import (
TransformerSpatioTemporalModel,
Expand Down
Loading