Skip to content

Commit

Permalink
Refine SVD diffusers example for graph save/load (#836)
Browse files Browse the repository at this point in the history
This PR is done:

- [x] Refine SVD diffusers example for graph save/load options.



![image](https://github.com/siliconflow/onediff/assets/54010254/c5a95b96-5c60-44b8-842a-a51d2ffe59e4)
  • Loading branch information
lixiang007666 authored Apr 24, 2024
1 parent b9d0ae8 commit 85d2b61
Showing 1 changed file with 26 additions and 4 deletions.
30 changes: 26 additions & 4 deletions benchmarks/image_to_video.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("--model", type=str, default=MODEL)
parser.add_argument("--save_graph", action="store_true")
parser.add_argument("--load_graph", action="store_true")
parser.add_argument("--variant", type=str, default=VARIANT)
parser.add_argument("--custom-pipeline", type=str, default=CUSTOM_PIPELINE)
parser.add_argument("--scheduler", type=str, default=SCHEDULER)
Expand Down Expand Up @@ -251,10 +253,22 @@ def get_kwarg_inputs():
return kwarg_inputs

if args.warmups > 0:
print("Begin warmup")
for _ in range(args.warmups):
pipe(**get_kwarg_inputs())
print("End warmup")
if args.load_graph:
print("Loading graphs to avoid compilation...")
start_t = time.time()
pipe.unet.load_graph("base_unet_compiled", run_warmup=True)
pipe.vae.decoder.load_graph("base_vae_compiled", run_warmup=True)
end_t = time.time()
print(f"Loading graph elapsed: {end_t - start_t} s")
print("Begin warmup")
for _ in range(args.warmups):
pipe(**get_kwarg_inputs())
print("End warmup")
else:
print("Begin warmup")
for _ in range(args.warmups):
pipe(**get_kwarg_inputs())
print("End warmup")

kwarg_inputs = get_kwarg_inputs()
iter_profiler = IterationProfiler()
Expand All @@ -280,6 +294,14 @@ def get_kwarg_inputs():
else:
print("Please set `--output-video` to save the output video")

if args.save_graph:
print("Saving graphs...")
start_t = time.time()
pipe.unet.save_graph("base_unet_compiled")
pipe.vae.decoder.save_graph("base_vae_compiled")
end_t = time.time()
print(f"save graphs elapsed: {end_t - start_t} s")


if __name__ == "__main__":
main()

0 comments on commit 85d2b61

Please sign in to comment.