Support SVD dynamic shape [feat] (#564)
This PR does the following:

issue: siliconflow/sd-team#240

- [x] SVD acceleration supports dynamic resolution switching. 
- [x] Add example.

Run:
```
export ONEFLOW_RUN_GRAPH_BY_VM="1"
python3 examples/image_to_video.py --variant fp16 --output-video vm.mp4 --decode-chunk-size 4 --run_multiple_resolutions true
```
Output:
```
Loading pipeline components...: 100%|█████████████████████████████████████████████████████| 5/5 [00:00<00:00,  7.50it/s]
warmup:
100%|███████████████████████████████████████████████████████████████████████████████████| 25/25 [02:29<00:00,  5.99s/it]
infer:
100%|███████████████████████████████████████████████████████████████████████████████████| 25/25 [00:28<00:00,  1.15s/it]
test dynamic resolution switch:
100%|███████████████████████████████████████████████████████████████████████████████████| 25/25 [00:28<00:00,  1.15s/it]
```
<img width="1024" alt="image" src="https://github.com/siliconflow/onediff/assets/54010254/7204f043-1611-43d5-83ba-ced69482f3be">
lixiang007666 authored Jan 26, 2024
1 parent ea736c8 commit 07184c5
Showing 3 changed files with 882 additions and 2 deletions.
27 changes: 26 additions & 1 deletion benchmarks/image_to_video.py
@@ -26,10 +26,13 @@
import argparse
import time
import json
import random
from PIL import Image, ImageDraw

import oneflow as flow
import torch

from diffusers.utils import load_image, export_to_video
import oneflow as flow
from onediff.infer_compiler import oneflow_compile
from onediff.infer_compiler.utils import set_boolean_env_var

@@ -73,6 +76,11 @@ def parse_args():
        type=int,
        default=ATTENTION_FP16_SCORE_ACCUM_MAX_M,
    )
    parser.add_argument(
        "--run_multiple_resolutions",
        type=(lambda x: str(x).lower() in ["true", "1", "yes"]),
        default=False,
    )
    return parser.parse_args()
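The new `--run_multiple_resolutions` flag is parsed with a permissive string-to-bool lambda instead of `store_true`, so values like `true`, `1`, or `yes` (case-insensitive) enable it and anything else falls back to `False`. A standalone check of that behavior (the `to_bool` name is just for illustration):

```python
# Mirrors the lambda used for --run_multiple_resolutions above.
to_bool = lambda x: str(x).lower() in ["true", "1", "yes"]

assert to_bool("true") and to_bool("True") and to_bool("1") and to_bool("yes")
assert not to_bool("false") and not to_bool("0") and not to_bool("no")
```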


@@ -189,6 +197,11 @@ def callback_on_step_end(self, pipe, i, t, callback_kwargs={}):
        return callback_kwargs


def adjust_resolution():
    resolutions = [(1024, 512), (768, 576)]
    return random.choice(resolutions)


def main():
    args = parse_args()
    if args.deepcache:
@@ -229,6 +242,11 @@ def main():
    input_image = load_image(args.input_image)
    input_image.resize((width, height), Image.LANCZOS)

    if args.run_multiple_resolutions:
        new_width, new_height = adjust_resolution()
        otherResImg = load_image(args.input_image)
        otherResImg.resize((new_width, new_height), Image.LANCZOS)

    if args.control_image is None:
        if args.controlnet is None:
            control_image = None
@@ -301,6 +319,13 @@ def get_kwarg_inputs():
    print(f"CUDA Mem after: {cuda_mem_after_used / 1024:.3f}GiB")
    print(f"Host Mem after: {host_mem_after_used / 1024:.3f}GiB")

    if args.run_multiple_resolutions:
        kwarg_inputs['image'] = otherResImg
        kwarg_inputs['height'] = new_height
        kwarg_inputs['width'] = new_width
        print("Test run with multiple resolutions...")
        output_frames = pipe(**kwarg_inputs).frames

    if args.output_video is not None:
        export_to_video(output_frames[0], args.output_video, fps=args.fps)
    else:
9 changes: 9 additions & 0 deletions src/infer_compiler_registry/register_diffusers/__init__.py
@@ -12,6 +12,9 @@
from diffusers.models.transformer_2d import Transformer2DModel
if diffusers_version >= version.parse("0.24.00"):
    from diffusers.models.resnet import SpatioTemporalResBlock
    from diffusers.models.transformer_temporal import TransformerSpatioTemporalModel
    from diffusers.models.attention import TemporalBasicTransformerBlock
    from diffusers.models.unet_spatio_temporal_condition import UNetSpatioTemporalConditionModel

if diffusers_version >= version.parse("0.25.00"):
    from diffusers.models.autoencoders.autoencoder_kl_temporal_decoder import TemporalDecoder
@@ -27,6 +30,9 @@
    SpatioTemporalResBlock as SpatioTemporalResBlockOflow,
)
from .spatio_temporal_oflow import TemporalDecoder as TemporalDecoderOflow
from .spatio_temporal_oflow import TransformerSpatioTemporalModel as TransformerSpatioTemporalModelOflow
from .spatio_temporal_oflow import TemporalBasicTransformerBlock as TemporalBasicTransformerBlockOflow
from .spatio_temporal_oflow import UNetSpatioTemporalConditionModel as UNetSpatioTemporalConditionModelOflow

# For CI
if diffusers_version >= version.parse("0.24.00"):
@@ -37,6 +43,9 @@
        LoRAAttnProcessor2_0: LoRAAttnProcessorOflow,
        SpatioTemporalResBlock: SpatioTemporalResBlockOflow,
        TemporalDecoder: TemporalDecoderOflow,
        TransformerSpatioTemporalModel: TransformerSpatioTemporalModelOflow,
        TemporalBasicTransformerBlock: TemporalBasicTransformerBlockOflow,
        UNetSpatioTemporalConditionModel: UNetSpatioTemporalConditionModelOflow,
    }
else:
    torch2oflow_class_map = {
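The new entries above register the SVD-specific modules (spatio-temporal UNet, transformer, and attention block) in the torch-to-oneflow class map, which the infer compiler uses to look up the oneflow implementation to substitute for each torch class. The snippet below is only a hypothetical illustration of that lookup pattern, not onediff's actual conversion code; `TorchBlock`, `OflowBlock`, and `lookup_oflow_class` are made-up names.

```python
import torch

class TorchBlock(torch.nn.Module):   # stand-in for e.g. TemporalBasicTransformerBlock
    pass

class OflowBlock(TorchBlock):        # stand-in for TemporalBasicTransformerBlockOflow
    pass

# Same shape as torch2oflow_class_map in the diff: torch class -> oneflow-aware class.
torch2oflow_class_map = {TorchBlock: OflowBlock}

def lookup_oflow_class(module: torch.nn.Module) -> type:
    # Fall back to the module's own class when no counterpart is registered.
    return torch2oflow_class_map.get(type(module), type(module))

assert lookup_oflow_class(TorchBlock()) is OflowBlock
```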
