From 651606fe8b57b32d4dcbb0dad6ddc716048d29fa Mon Sep 17 00:00:00 2001 From: lixiang007666 <88304454@qq.com> Date: Wed, 24 Apr 2024 03:07:49 +0000 Subject: [PATCH] Refine SVD diffusers example for graph save/load --- benchmarks/image_to_video.py | 30 ++++++++++++++++++++++++++---- 1 file changed, 26 insertions(+), 4 deletions(-) diff --git a/benchmarks/image_to_video.py b/benchmarks/image_to_video.py index a1c707d69..8cef18db2 100644 --- a/benchmarks/image_to_video.py +++ b/benchmarks/image_to_video.py @@ -48,6 +48,8 @@ def parse_args(): parser = argparse.ArgumentParser() parser.add_argument("--model", type=str, default=MODEL) + parser.add_argument("--save_graph", action="store_true") + parser.add_argument("--load_graph", action="store_true") parser.add_argument("--variant", type=str, default=VARIANT) parser.add_argument("--custom-pipeline", type=str, default=CUSTOM_PIPELINE) parser.add_argument("--scheduler", type=str, default=SCHEDULER) @@ -251,10 +253,22 @@ def get_kwarg_inputs(): return kwarg_inputs if args.warmups > 0: - print("Begin warmup") - for _ in range(args.warmups): - pipe(**get_kwarg_inputs()) - print("End warmup") + if args.load_graph: + print("Loading graphs to avoid compilation...") + start_t = time.time() + pipe.unet.load_graph("base_unet_compiled", run_warmup=True) + pipe.vae.decoder.load_graph("base_vae_compiled", run_warmup=True) + end_t = time.time() + print(f"Loading graph elapsed: {end_t - start_t} s") + print("Begin warmup") + for _ in range(args.warmups): + pipe(**get_kwarg_inputs()) + print("End warmup") + else: + print("Begin warmup") + for _ in range(args.warmups): + pipe(**get_kwarg_inputs()) + print("End warmup") kwarg_inputs = get_kwarg_inputs() iter_profiler = IterationProfiler() @@ -280,6 +294,14 @@ def get_kwarg_inputs(): else: print("Please set `--output-video` to save the output video") + if args.save_graph: + print("Saving graphs...") + start_t = time.time() + pipe.unet.save_graph("base_unet_compiled") + pipe.vae.decoder.save_graph("base_vae_compiled") + end_t = time.time() + print(f"save graphs elapsed: {end_t - start_t} s") + if __name__ == "__main__": main()