Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

replace reshape by flatten and unflatten to speedup svd #516

Merged
merged 5 commits into from
Jan 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
179 changes: 95 additions & 84 deletions examples/image_to_video.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Run with ONEFLOW_RUN_GRAPH_BY_VM=1 to get faster
MODEL = 'stabilityai/stable-video-diffusion-img2vid-xt'
MODEL = "stabilityai/stable-video-diffusion-img2vid-xt"
VARIANT = None
CUSTOM_PIPELINE = None
SCHEDULER = None
Expand All @@ -14,7 +14,7 @@
WIDTH = 1024
FPS = 7
DECODE_CHUNK_SIZE = 4
INPUT_IMAGE = 'https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/svd/rocket.png?download=true'
INPUT_IMAGE = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/svd/rocket.png?download=true"
EXTRA_CALL_KWARGS = None
ATTENTION_FP16_SCORE_ACCUM_MAX_M = 0

Expand All @@ -23,7 +23,7 @@
import argparse
import time
import json
from PIL import (Image, ImageDraw)
from PIL import Image, ImageDraw

import oneflow as flow
import torch
Expand All @@ -35,68 +35,71 @@

def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument('--model', type=str, default=MODEL)
parser.add_argument('--variant', type=str, default=VARIANT)
parser.add_argument('--custom-pipeline', type=str, default=CUSTOM_PIPELINE)
parser.add_argument('--scheduler', type=str, default=SCHEDULER)
parser.add_argument('--lora', type=str, default=LORA)
parser.add_argument('--controlnet', type=str, default=None)
parser.add_argument('--steps', type=int, default=STEPS)
parser.add_argument('--seed', type=int, default=SEED)
parser.add_argument('--warmups', type=int, default=WARMUPS)
parser.add_argument('--frames', type=int, default=FRAMES)
parser.add_argument('--batch', type=int, default=BATCH)
parser.add_argument('--height', type=int, default=HEIGHT)
parser.add_argument('--width', type=int, default=WIDTH)
parser.add_argument('--fps', type=int, default=FPS)
parser.add_argument('--decode-chunk-size',
type=int,
default=DECODE_CHUNK_SIZE)
parser.add_argument('--extra-call-kwargs',
type=str,
default=EXTRA_CALL_KWARGS)
parser.add_argument('--input-image', type=str, default=INPUT_IMAGE)
parser.add_argument('--control-image', type=str, default=None)
parser.add_argument('--output-video', type=str, default=None)
parser.add_argument('--compiler',
type=str,
default='oneflow',
choices=['none', 'oneflow', 'compile'])
parser.add_argument('--attention-fp16-score-accum-max-m',
type=int,
default=ATTENTION_FP16_SCORE_ACCUM_MAX_M)
parser.add_argument("--model", type=str, default=MODEL)
parser.add_argument("--variant", type=str, default=VARIANT)
parser.add_argument("--custom-pipeline", type=str, default=CUSTOM_PIPELINE)
parser.add_argument("--scheduler", type=str, default=SCHEDULER)
parser.add_argument("--lora", type=str, default=LORA)
parser.add_argument("--controlnet", type=str, default=None)
parser.add_argument("--steps", type=int, default=STEPS)
parser.add_argument("--seed", type=int, default=SEED)
parser.add_argument("--warmups", type=int, default=WARMUPS)
parser.add_argument("--frames", type=int, default=FRAMES)
parser.add_argument("--batch", type=int, default=BATCH)
parser.add_argument("--height", type=int, default=HEIGHT)
parser.add_argument("--width", type=int, default=WIDTH)
parser.add_argument("--fps", type=int, default=FPS)
parser.add_argument("--decode-chunk-size", type=int, default=DECODE_CHUNK_SIZE)
parser.add_argument("--extra-call-kwargs", type=str, default=EXTRA_CALL_KWARGS)
parser.add_argument("--input-image", type=str, default=INPUT_IMAGE)
parser.add_argument("--control-image", type=str, default=None)
parser.add_argument("--output-video", type=str, default=None)
parser.add_argument(
"--compiler",
type=str,
default="oneflow",
choices=["none", "oneflow", "compile"],
)
parser.add_argument(
"--attention-fp16-score-accum-max-m",
type=int,
default=ATTENTION_FP16_SCORE_ACCUM_MAX_M,
)
return parser.parse_args()


def load_model(pipeline_cls,
model_name,
variant=None,
custom_pipeline=None,
scheduler=None,
lora=None,
controlnet=None):
def load_model(
pipeline_cls,
model_name,
variant=None,
custom_pipeline=None,
scheduler=None,
lora=None,
controlnet=None,
):
extra_kwargs = {}
if custom_pipeline is not None:
extra_kwargs['custom_pipeline'] = custom_pipeline
extra_kwargs["custom_pipeline"] = custom_pipeline
if variant is not None:
extra_kwargs['variant'] = variant
extra_kwargs["variant"] = variant
if controlnet is not None:
from diffusers import ControlNetModel
controlnet = ControlNetModel.from_pretrained(controlnet,
torch_dtype=torch.float16)
extra_kwargs['controlnet'] = controlnet
model = pipeline_cls.from_pretrained(model_name,
torch_dtype=torch.float16,
**extra_kwargs)

controlnet = ControlNetModel.from_pretrained(
controlnet, torch_dtype=torch.float16
)
extra_kwargs["controlnet"] = controlnet
model = pipeline_cls.from_pretrained(
model_name, torch_dtype=torch.float16, **extra_kwargs
)
if scheduler is not None:
scheduler_cls = getattr(importlib.import_module('diffusers'),
scheduler)
scheduler_cls = getattr(importlib.import_module("diffusers"), scheduler)
model.scheduler = scheduler_cls.from_config(model.scheduler.config)
if lora is not None:
model.load_lora_weights(lora)
model.fuse_lora()
model.safety_checker = None
model.to(torch.device('cuda'))
model.to(torch.device("cuda"))
return model


Expand All @@ -118,8 +121,9 @@ def compile_model(model, attention_fp16_score_accum_max_m=-1):
# | True | 0 | OK | 30.947s |
# | True | 2304 | OK | 30.820s |
set_boolean_env_var(
'ONEFLOW_ATTENTION_ALLOW_HALF_PRECISION_SCORE_ACCUMULATION_MAX_M',
attention_fp16_score_accum_max_m)
"ONEFLOW_ATTENTION_ALLOW_HALF_PRECISION_SCORE_ACCUMULATION_MAX_M",
attention_fp16_score_accum_max_m,
)

model.image_encoder = oneflow_compile(model.image_encoder)
model.unet = oneflow_compile(model.unet)
Expand All @@ -129,7 +133,6 @@ def compile_model(model, attention_fp16_score_accum_max_m=-1):


class IterationProfiler:

def __init__(self):
self.begin = None
self.end = None
Expand Down Expand Up @@ -169,19 +172,20 @@ def main():
controlnet=args.controlnet,
)

if args.compiler == 'none':
if args.compiler == "none":
pass
elif args.compiler == 'oneflow':
model = compile_model(model,
attention_fp16_score_accum_max_m=args.
attention_fp16_score_accum_max_m)
elif args.compiler == 'compile':
elif args.compiler == "oneflow":
model = compile_model(
model,
attention_fp16_score_accum_max_m=args.attention_fp16_score_accum_max_m,
)
elif args.compiler == "compile":
model.unet = torch.compile(model.unet)
if hasattr(model, 'controlnet'):
if hasattr(model, "controlnet"):
model.controlnet = torch.compile(model.controlnet)
# model.vae = torch.compile(model.vae)
else:
raise ValueError(f'Unknown compiler: {args.compiler}')
raise ValueError(f"Unknown compiler: {args.compiler}")

input_image = load_image(args.input_image)
input_image.resize((args.width, args.height), Image.LANCZOS)
Expand All @@ -190,16 +194,21 @@ def main():
if args.controlnet is None:
control_image = None
else:
control_image = Image.new('RGB', (args.width, args.height))
control_image = Image.new("RGB", (args.width, args.height))
draw = ImageDraw.Draw(control_image)
draw.ellipse((args.width // 4, args.height // 4,
args.width // 4 * 3, args.height // 4 * 3),
fill=(255, 255, 255))
draw.ellipse(
(
args.width // 4,
args.height // 4,
args.width // 4 * 3,
args.height // 4 * 3,
),
fill=(255, 255, 255),
)
del draw
else:
control_image = Image.open(args.control_image).convert('RGB')
control_image = control_image.resize((args.width, args.height),
Image.LANCZOS)
control_image = Image.open(args.control_image).convert("RGB")
control_image = control_image.resize((args.width, args.height), Image.LANCZOS)

def get_kwarg_inputs():
kwarg_inputs = dict(
Expand All @@ -211,43 +220,45 @@ def get_kwarg_inputs():
num_frames=args.frames,
fps=args.fps,
decode_chunk_size=args.decode_chunk_size,
generator=None if args.seed is None else torch.Generator(
device='cuda').manual_seed(args.seed),
**(dict() if args.extra_call_kwargs is None else json.loads(
args.extra_call_kwargs)),
**(
dict()
if args.extra_call_kwargs is None
else json.loads(args.extra_call_kwargs)
),
)
if control_image is not None:
kwarg_inputs['control_image'] = control_image
kwarg_inputs["control_image"] = control_image
return kwarg_inputs

if args.warmups > 0:
print('Begin warmup')
print("Begin warmup")
for _ in range(args.warmups):
model(**get_kwarg_inputs())
print('End warmup')
print("End warmup")

kwarg_inputs = get_kwarg_inputs()
iter_profiler = None
if 'callback_on_step_end' in inspect.signature(model).parameters:
if "callback_on_step_end" in inspect.signature(model).parameters:
iter_profiler = IterationProfiler()
kwarg_inputs[
'callback_on_step_end'] = iter_profiler.callback_on_step_end
kwarg_inputs["callback_on_step_end"] = iter_profiler.callback_on_step_end
if args.seed is not None:
torch.manual_seed(args.seed)
begin = time.time()
output_frames = model(**kwarg_inputs).frames
end = time.time()

print(f'Inference time: {end - begin:.3f}s')
print(f"Inference time: {end - begin:.3f}s")
iter_per_sec = iter_profiler.get_iter_per_sec()
if iter_per_sec is not None:
print(f'Iterations per second: {iter_per_sec:.3f}')
print(f"Iterations per second: {iter_per_sec:.3f}")
peak_mem = torch.cuda.max_memory_allocated()
print(f'Peak memory: {peak_mem / 1024**3:.3f}GiB')
print(f"Peak memory: {peak_mem / 1024**3:.3f}GiB")

if args.output_video is not None:
export_to_video(output_frames[0], args.output_video, fps=args.fps)
else:
print('Please set `--output-video` to save the output-video')
print("Please set `--output-video` to save the output-video")


if __name__ == '__main__':
if __name__ == "__main__":
main()
Original file line number Diff line number Diff line change
Expand Up @@ -147,16 +147,12 @@ def custom_forward(*inputs):
batch_size = batch_frames // num_frames
# sample = sample[None, :].reshape(batch_size, num_frames, channels, height, width).permute(0, 2, 1, 3, 4)
# Dynamic shape for VAE divide chunks
sample = (
sample[None, :]
.reshape(batch_size, -1, channels, height, width)
.permute(0, 2, 1, 3, 4)
)
sample = sample.unflatten(0, shape=(batch_size, -1)).permute(0, 2, 1, 3, 4)
sample = self.time_conv_out(sample)

# sample = sample.permute(0, 2, 1, 3, 4).reshape(batch_frames, channels, height, width)
# Dynamic shape for VAE divide chunks
sample = sample.permute(0, 2, 1, 3, 4).reshape(-1, channels, height, width)
sample = sample.permute(0, 2, 1, 3, 4).flatten(0, 1)

return sample

Expand Down Expand Up @@ -233,19 +229,16 @@ def forward(
# )
#
# Dynamic shape for VAE divide chunks
hidden_states_mix = (
hidden_states[None, :]
.reshape(batch_size, -1, channels, height, width)
.permute(0, 2, 1, 3, 4)
hidden_states_mix = hidden_states.unflatten(0, shape=(batch_size, -1)).permute(
0, 2, 1, 3, 4
)
hidden_states = (
hidden_states[None, :]
.reshape(batch_size, -1, channels, height, width)
.permute(0, 2, 1, 3, 4)
hidden_states = hidden_states.unflatten(0, shape=(batch_size, -1)).permute(
0, 2, 1, 3, 4
)

if temb is not None:
temb = temb.reshape(batch_size, num_frames, -1)
# temb = temb.reshape(batch_size, num_frames, -1)
temb = temb.unflatten(0, shape=(batch_size, -1))

hidden_states = self.temporal_res_block(hidden_states, temb)
hidden_states = self.time_mixer(
Expand All @@ -256,7 +249,5 @@ def forward(

# hidden_states = hidden_states.permute(0, 2, 1, 3, 4).reshape(batch_frames, channels, height, width)
# Dynamic shape for VAE divide chunks
hidden_states = hidden_states.permute(0, 2, 1, 3, 4).reshape(
-1, channels, height, width
)
hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1)
return hidden_states
Loading