Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support stable cascade #659

Merged
merged 15 commits into from
Mar 12, 2024
Merged
6 changes: 2 additions & 4 deletions benchmarks/image_to_video.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ def main():
pipe.unet = torch.compile(pipe.unet)
if hasattr(pipe, "controlnet"):
pipe.controlnet = torch.compile(pipe.controlnet)
# model.vae = torch.compile(model.vae)
pipe.vae = torch.compile(pipe.vae)
else:
raise ValueError(f"Unknown compiler: {args.compiler}")

Expand Down Expand Up @@ -249,12 +249,10 @@ def get_kwarg_inputs():
print("End warmup")

kwarg_inputs = get_kwarg_inputs()
iter_profiler = None
iter_profiler = IterationProfiler()
if "callback_on_step_end" in inspect.signature(pipe).parameters:
iter_profiler = IterationProfiler()
kwarg_inputs["callback_on_step_end"] = iter_profiler.callback_on_step_end
elif "callback" in inspect.signature(pipe).parameters:
iter_profiler = IterationProfiler()
kwarg_inputs["callback"] = iter_profiler.callback_on_step_end
begin = time.time()
output_frames = pipe(**kwarg_inputs).frames
Expand Down
29 changes: 9 additions & 20 deletions benchmarks/instant_id.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,15 +98,16 @@ def load_pipe(
controlnet, torch_dtype=torch.float16,
)
extra_kwargs["controlnet"] = controlnet
is_quantized_model = False
if os.path.exists(os.path.join(model_name, "calibrate_info.txt")):
is_quantized_model = True
from onediff.quantization import setup_onediff_quant
from onediff.quantization import QuantPipeline

setup_onediff_quant()
pipe = pipeline_cls.from_pretrained(
model_name, torch_dtype=torch.float16, **extra_kwargs
)
pipe = QuantPipeline.from_pretrained(
pipeline_cls, model_name, torch_dtype=torch.float16, **extra_kwargs
)
else:
pipe = pipeline_cls.from_pretrained(
model_name, torch_dtype=torch.float16, **extra_kwargs
)
if scheduler is not None:
scheduler_cls = getattr(importlib.import_module("diffusers"), scheduler)
pipe.scheduler = scheduler_cls.from_config(pipe.scheduler.config)
Expand All @@ -115,14 +116,6 @@ def load_pipe(
pipe.fuse_lora()
pipe.safety_checker = None
pipe.to(torch.device("cuda"))

# Replace quantizable modules by QuantModule.
if is_quantized_model:
from onediff.quantization import load_calibration_and_quantize_pipeline

load_calibration_and_quantize_pipeline(
os.path.join(model_name, "calibrate_info.txt"), pipe
)
return pipe


Expand Down Expand Up @@ -220,8 +213,6 @@ def main():

pipe.load_ip_adapter_instantid(face_adapter)

height = args.height
width = args.width
height = args.height or pipe.unet.config.sample_size * pipe.vae_scale_factor
width = args.width or pipe.unet.config.sample_size * pipe.vae_scale_factor

Expand Down Expand Up @@ -289,12 +280,10 @@ def get_kwarg_inputs():
# Let"s see it!
# Note: Progress bar might work incorrectly due to the async nature of CUDA.
kwarg_inputs = get_kwarg_inputs()
iter_profiler = None
iter_profiler = IterationProfiler()
if "callback_on_step_end" in inspect.signature(pipe).parameters:
iter_profiler = IterationProfiler()
kwarg_inputs["callback_on_step_end"] = iter_profiler.callback_on_step_end
elif "callback" in inspect.signature(pipe).parameters:
iter_profiler = IterationProfiler()
kwarg_inputs["callback"] = iter_profiler.callback_on_step_end
begin = time.time()
output_images = pipe(**kwarg_inputs).images
Expand Down
157 changes: 157 additions & 0 deletions benchmarks/patch_stable_cascade.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
import torch
from diffusers.pipelines.wuerstchen.modeling_wuerstchen_common import (
AttnBlock,
TimestepBlock,
WuerstchenLayerNorm,
)
from diffusers.pipelines.wuerstchen.modeling_wuerstchen_diffnext import ResBlockStageB


def patch_prior_fp16_overflow(prior, num_overflow_up_blocks=1):
for i in range(len(prior.up_blocks) - num_overflow_up_blocks, len(prior.up_blocks)):
prior.up_blocks[i].to(torch.bfloat16)
prior.up_upscalers[i].to(torch.bfloat16)
prior.up_repeat_mappers[i].to(torch.bfloat16)

def _up_decode(self, level_outputs, r_embed, clip):
x = level_outputs[0]
block_group = zip(self.up_blocks, self.up_upscalers, self.up_repeat_mappers)
for i, (up_block, upscaler, repmap) in enumerate(block_group):
if i == len(self.up_blocks) - num_overflow_up_blocks:
x = x.to(torch.bfloat16)
r_embed = r_embed.to(torch.bfloat16)
clip = clip.to(torch.bfloat16)
for j in range(len(repmap) + 1):
for k, block in enumerate(up_block):
block_class = block.__class__
if isinstance(block, ResBlockStageB):
skip = level_outputs[i] if k == 0 and i > 0 else None
if skip is not None and (
x.size(-1) != skip.size(-1) or x.size(-2) != skip.size(-2)
):
x = torch.nn.functional.interpolate(
x.float(),
skip.shape[-2:],
mode="bilinear",
align_corners=True,
)
if (
skip is not None
and i >= len(level_outputs) - num_overflow_up_blocks
):
skip = skip.to(torch.bfloat16)
x = block(x, skip)
elif isinstance(block, AttnBlock):
x = block(x, clip)
elif isinstance(block, TimestepBlock):
x = block(x, r_embed)
else:
x = block(x)
if j < len(repmap):
x = repmap[j](x)
x = upscaler(x)
return x

prior._up_decode = _up_decode.__get__(prior)

prior.clf.to(torch.bfloat16)
prior_clf_forward = prior.clf.forward

def new_prior_clf_forward(x):
return prior_clf_forward(x).to(torch.float16)

prior.clf.forward = new_prior_clf_forward

return prior


original_pixel_shuffle = torch.nn.functional.pixel_shuffle


def pixel_shuffle(input, upscale_factor):
# https://blog.csdn.net/ONE_SIX_MIX/article/details/103757856
# Thanks ChatGPT: https://chat.openai.com/share/1ba80104-406f-4bb8-8292-105f69e0452e

assert input.dim() >= 3, "Input tensor must have at least 3 dimensions"

# Separate batch_dims and original C, H, W dimensions
*batch_dims, channels, height, width = input.shape

assert (
channels % (upscale_factor ** 2) == 0
), "Number of channels must be divisible by the square of the upscale factor"

# Calculate new channels after applying upscale_factor
new_channels = channels // (upscale_factor ** 2)

# Reshape input to (*batch_dims, new_channels, upscale_factor, upscale_factor, height, width)
reshaped = input.reshape(
*batch_dims, new_channels, upscale_factor, upscale_factor, height, width
)

# Adjust permute to handle dynamic batch dimensions
permute_dims = list(range(len(batch_dims))) + [
len(batch_dims),
len(batch_dims) + 3,
len(batch_dims) + 1,
len(batch_dims) + 4,
len(batch_dims) + 2,
]
permuted = reshaped.permute(*permute_dims)

# Final reshape to get to the target shape
output = permuted.reshape(
*batch_dims, new_channels, height * upscale_factor, width * upscale_factor
)

return output


def pixel_unshuffle(input, downscale_factor):
# Thanks ChatGPT: https://chat.openai.com/share/1ba80104-406f-4bb8-8292-105f69e0452e

assert input.dim() >= 3, "Input tensor must have at least 3 dimensions"

# Separate batch_dims and original C, H, W dimensions
*batch_dims, channels, height, width = input.shape

# Ensure H and W are divisible by downscale_factor
assert (
height % downscale_factor == 0 and width % downscale_factor == 0
), "Height and Width must be divisible by the downscale factor"

# Reshape
reshaped = input.reshape(
*batch_dims,
channels,
height // downscale_factor,
downscale_factor,
width // downscale_factor,
downscale_factor,
)

# Adjust permutation indices for tensors with dimensions > 4
permute_dims = list(range(len(batch_dims))) + [
len(batch_dims),
len(batch_dims) + 2,
len(batch_dims) + 4,
len(batch_dims) + 1,
len(batch_dims) + 3,
]
permuted = reshaped.permute(*permute_dims)

# Final reshape
output = permuted.reshape(
*batch_dims,
channels * downscale_factor ** 2,
height // downscale_factor,
width // downscale_factor,
)

return output


# pixel_shuffle() and pixel_unshuffle() call Tensor.sizes() which is not supported by dynamo.
def patch_torch_compile():
torch.nn.functional.pixel_shuffle = pixel_shuffle
torch.nn.functional.pixel_unshuffle = pixel_unshuffle
133 changes: 133 additions & 0 deletions benchmarks/patch_stable_cascade_of.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
from typing import Optional
import oneflow as torch
import oneflow.nn as nn
import oneflow.nn.functional as F
from packaging import version
import importlib.metadata

from onediff.infer_compiler.transform import transform_mgr

diffusers_of = transform_mgr.transform_package("diffusers")
StableCascadeUnet_OF_CLS = (
diffusers_of.pipelines.stable_cascade.modeling_stable_cascade_common.StableCascadeUnet
)

ResBlockStageB = (
diffusers_of.pipelines.wuerstchen.modeling_wuerstchen_diffnext.ResBlockStageB
)
AttnBlock = diffusers_of.pipelines.wuerstchen.modeling_wuerstchen_common.AttnBlock
TimestepBlock = (
diffusers_of.pipelines.wuerstchen.modeling_wuerstchen_common.TimestepBlock
)

num_overflow_up_blocks = 1


class StableCascadeUnet_OF(StableCascadeUnet_OF_CLS):
def of_up_decode(self, level_outputs, r_embed, clip):
x = level_outputs[0]
block_group = zip(self.up_blocks, self.up_upscalers, self.up_repeat_mappers)
for i, (up_block, upscaler, repmap) in enumerate(block_group):
if i == len(self.up_blocks) - num_overflow_up_blocks:
x = x.to(torch.bfloat16)
r_embed = r_embed.to(torch.bfloat16)
clip = clip.to(torch.bfloat16)
for j in range(len(repmap) + 1):
for k, block in enumerate(up_block):
block_class = block.__class__
if isinstance(block, ResBlockStageB):
skip = level_outputs[i] if k == 0 and i > 0 else None
if skip is not None and (
x.size(-1) != skip.size(-1) or x.size(-2) != skip.size(-2)
):
x = torch.nn.functional.interpolate(
x.float(),
skip.shape[-2:],
mode="bilinear",
align_corners=True,
)
if (
skip is not None
and i >= len(level_outputs) - num_overflow_up_blocks
):
skip = skip.to(torch.bfloat16)
x = block(x, skip)
elif isinstance(block, AttnBlock):
x = block(x, clip)
elif isinstance(block, TimestepBlock):
x = block(x, r_embed)
else:
x = block(x)
if j < len(repmap):
x = repmap[j](x)
x = upscaler(x)
return x

def forward(
self,
x,
r,
clip_text_pooled,
clip_text=None,
clip_img=None,
effnet=None,
pixels=None,
sca=None,
crp=None,
):
if pixels is None:
pixels = x.new_zeros(x.size(0), 3, 8, 8)

# Process the conditioning embeddings
r_embed = self.gen_r_embedding(r)
for c in self.config.t_conds:
if c == "sca":
cond = sca
elif c == "crp":
cond = crp
else:
cond = None
t_cond = cond or torch.zeros_like(r)
r_embed = torch.cat([r_embed, self.gen_r_embedding(t_cond)], dim=1)
clip = self.gen_c_embeddings(
clip_txt_pooled=clip_text_pooled, clip_txt=clip_text, clip_img=clip_img
)

# Model Blocks
x = self.embedding(x)
if hasattr(self, "effnet_mapper") and effnet is not None:
x = x + self.effnet_mapper(
nn.functional.interpolate(
effnet, size=x.shape[-2:], mode="bilinear", align_corners=True
)
)
if hasattr(self, "pixels_mapper"):
x = x + nn.functional.interpolate(
self.pixels_mapper(pixels),
size=x.shape[-2:],
mode="bilinear",
align_corners=True,
)
level_outputs = self._down_encode(x, r_embed, clip)
x = self.of_up_decode(level_outputs, r_embed, clip)

return self.clf(x).to(torch.float16)


# diffusers.pipelines.stable_cascade.modeling_stable_cascade_common.StableCascadeUnet
from diffusers.pipelines.stable_cascade.modeling_stable_cascade_common import (
StableCascadeUnet,
)

# torch2oflow_class_map.update({StableCascadeUnet: StableCascadeUnetOflow})
from onediff.infer_compiler.transform import register
from contextlib import contextmanager


@contextmanager
def patch_oneflow_prior_fp16_overflow():
torch2oflow_class_map = {StableCascadeUnet: StableCascadeUnet_OF}
register(torch2oflow_class_map=torch2oflow_class_map)
yield
torch2oflow_class_map = {StableCascadeUnet: StableCascadeUnet_OF_CLS}
register(torch2oflow_class_map=torch2oflow_class_map)
Loading
Loading