From 4fc92d26c9cef4d072d40723dafa23afada47d23 Mon Sep 17 00:00:00 2001 From: strint Date: Thu, 15 Feb 2024 11:08:53 +0000 Subject: [PATCH 01/17] add log --- onediff_diffusers_extensions/examples/text_to_image_sdxl.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/onediff_diffusers_extensions/examples/text_to_image_sdxl.py b/onediff_diffusers_extensions/examples/text_to_image_sdxl.py index dbc993c2f..b7a11ceff 100644 --- a/onediff_diffusers_extensions/examples/text_to_image_sdxl.py +++ b/onediff_diffusers_extensions/examples/text_to_image_sdxl.py @@ -11,6 +11,9 @@ from onediff.infer_compiler import oneflow_compile from onediff.schedulers import EulerDiscreteScheduler from diffusers import StableDiffusionXLPipeline +import diffusers + +diffusers.logging.set_verbosity_info() parser = argparse.ArgumentParser() parser.add_argument( From b8ba7708a2f5af412dc6981b2775c2e6c00ef2a5 Mon Sep 17 00:00:00 2001 From: strint Date: Thu, 15 Feb 2024 12:42:48 +0000 Subject: [PATCH 02/17] add example --- .../examples/text_to_image_sdxl.py | 3 - .../examples/text_to_image_sdxl_reuse_pipe.py | 135 ++++++++++++++++++ 2 files changed, 135 insertions(+), 3 deletions(-) create mode 100644 onediff_diffusers_extensions/examples/text_to_image_sdxl_reuse_pipe.py diff --git a/onediff_diffusers_extensions/examples/text_to_image_sdxl.py b/onediff_diffusers_extensions/examples/text_to_image_sdxl.py index b7a11ceff..dbc993c2f 100644 --- a/onediff_diffusers_extensions/examples/text_to_image_sdxl.py +++ b/onediff_diffusers_extensions/examples/text_to_image_sdxl.py @@ -11,9 +11,6 @@ from onediff.infer_compiler import oneflow_compile from onediff.schedulers import EulerDiscreteScheduler from diffusers import StableDiffusionXLPipeline -import diffusers - -diffusers.logging.set_verbosity_info() parser = argparse.ArgumentParser() parser.add_argument( diff --git a/onediff_diffusers_extensions/examples/text_to_image_sdxl_reuse_pipe.py b/onediff_diffusers_extensions/examples/text_to_image_sdxl_reuse_pipe.py new file mode 100644 index 000000000..f1a5bf266 --- /dev/null +++ b/onediff_diffusers_extensions/examples/text_to_image_sdxl_reuse_pipe.py @@ -0,0 +1,135 @@ +""" +Torch run example: python examples/text_to_image_sdxl.py +Compile to oneflow graph example: python examples/text_to_image_sdxl.py +""" +import os +import argparse + +import oneflow as flow +import torch + +from onediff.infer_compiler import oneflow_compile +from onediff.schedulers import EulerDiscreteScheduler +from diffusers import StableDiffusionXLPipeline +import diffusers + +diffusers.logging.set_verbosity_info() + +parser = argparse.ArgumentParser() +parser.add_argument( + "--base", type=str, default="stabilityai/stable-diffusion-xl-base-1.0" +) +parser.add_argument("--variant", type=str, default="fp16") +parser.add_argument( + "--prompt", + type=str, + default="street style, detailed, raw photo, woman, face, shot on CineStill 800T", +) +parser.add_argument("--height", type=int, default=1024) +parser.add_argument("--width", type=int, default=1024) +parser.add_argument("--n_steps", type=int, default=30) +parser.add_argument("--saved_image", type=str, required=False, default="sdxl-out.png") +parser.add_argument("--seed", type=int, default=1) +parser.add_argument( + "--compile_unet", + type=(lambda x: str(x).lower() in ["true", "1", "yes"]), + default=True, +) +parser.add_argument( + "--compile_vae", + type=(lambda x: str(x).lower() in ["true", "1", "yes"]), + default=True, +) +parser.add_argument( + "--run_multiple_resolutions", + type=(lambda x: str(x).lower() 
in ["true", "1", "yes"]), + default=True, +) +args = parser.parse_args() + +# Normal SDXL pipeline init. +OUTPUT_TYPE = "pil" + +# SDXL base: StableDiffusionXLPipeline +scheduler = EulerDiscreteScheduler.from_pretrained(args.base, subfolder="scheduler") +base = StableDiffusionXLPipeline.from_pretrained( + args.base, + scheduler=scheduler, + torch_dtype=torch.float16, + variant=args.variant, + use_safetensors=True, +) +base.to("cuda") + +new_base = StableDiffusionXLPipeline.from_pretrained( + "dataautogpt3/OpenDalleV1.1", + scheduler=scheduler, + torch_dtype=torch.float16, + variant=args.variant, + use_safetensors=True, +) +new_base.to("cuda") + +# Compile unet with oneflow +if args.compile_unet: + print("Compiling unet with oneflow.") + base.unet = oneflow_compile(base.unet) + +# Compile vae with oneflow +if args.compile_vae: + print("Compiling vae with oneflow.") + base.vae.decoder = oneflow_compile(base.vae.decoder) + +# Warmup with run +# Will do compilatioin in the first run +print("Warmup with running graphs...") +torch.manual_seed(args.seed) +image = base( + prompt=args.prompt, + height=args.height, + width=args.width, + num_inference_steps=args.n_steps, + output_type=OUTPUT_TYPE, +).images + +# Normal SDXL run +print("Normal SDXL run...") +torch.manual_seed(args.seed) +image = base( + prompt=args.prompt, + height=args.height, + width=args.width, + num_inference_steps=args.n_steps, + output_type=OUTPUT_TYPE, +).images +image[0].save(f"h{args.height}-w{args.width}-{args.saved_image}") + + +# Should have no compilation for these new input shape +print("Test run with multiple resolutions...") +if args.run_multiple_resolutions: + sizes = [960, 720, 896, 768] + if "CI" in os.environ: + sizes = [360] + for h in sizes: + for w in sizes: + image = base( + prompt=args.prompt, + height=h, + width=w, + num_inference_steps=args.n_steps, + output_type=OUTPUT_TYPE, + ).images + + +# print("Test run with other another uncommon resolution...") +# if args.run_multiple_resolutions: +# h = 544 +# w = 408 +# image = base( +# prompt=args.prompt, +# height=h, +# width=w, +# num_inference_steps=args.n_steps, +# output_type=OUTPUT_TYPE, +# ).images From 1ef944770b2bfafd8c5c2db9a1caaa0e9231d96d Mon Sep 17 00:00:00 2001 From: strint Date: Thu, 15 Feb 2024 13:10:41 +0000 Subject: [PATCH 03/17] add demo --- .../examples/text_to_image_sdxl_reuse_pipe.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/onediff_diffusers_extensions/examples/text_to_image_sdxl_reuse_pipe.py b/onediff_diffusers_extensions/examples/text_to_image_sdxl_reuse_pipe.py index f1a5bf266..d0364a6aa 100644 --- a/onediff_diffusers_extensions/examples/text_to_image_sdxl_reuse_pipe.py +++ b/onediff_diffusers_extensions/examples/text_to_image_sdxl_reuse_pipe.py @@ -9,6 +9,7 @@ import torch from onediff.infer_compiler import oneflow_compile +from onediff.infer_compiler import oneflow_compiler_config from onediff.schedulers import EulerDiscreteScheduler from diffusers import StableDiffusionXLPipeline import diffusers @@ -70,6 +71,7 @@ ) new_base.to("cuda") +oneflow_compiler_config.mlir_enable_inference_optimization = False # Compile unet with oneflow if args.compile_unet: print("Compiling unet with oneflow.") @@ -92,6 +94,13 @@ output_type=OUTPUT_TYPE, ).images +# Update the unet and vae +# load_state_dict(state_dict, strict=True, assign=False), assign is False means copying them inplace into the module’s current parameters and buffers. 
+# Reference: https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.load_state_dict +base.unet.load_state_dict(new_base.unet.state_dict()) +base.vae.decoder.load_state_dict(new_base.vae.decoder.state_dict()) +del new_base + # Normal SDXL run print("Normal SDXL run...") torch.manual_seed(args.seed) From 6cfb4429303157652c6f3cbda9e797cbb400cc47 Mon Sep 17 00:00:00 2001 From: strint Date: Thu, 15 Feb 2024 14:23:43 +0000 Subject: [PATCH 04/17] add check update --- .../examples/text_to_image_sdxl_reuse_pipe.py | 24 ++++++++++++++++--- 1 file changed, 21 insertions(+), 3 deletions(-) diff --git a/onediff_diffusers_extensions/examples/text_to_image_sdxl_reuse_pipe.py b/onediff_diffusers_extensions/examples/text_to_image_sdxl_reuse_pipe.py index d0364a6aa..2d4fc6f7e 100644 --- a/onediff_diffusers_extensions/examples/text_to_image_sdxl_reuse_pipe.py +++ b/onediff_diffusers_extensions/examples/text_to_image_sdxl_reuse_pipe.py @@ -12,9 +12,8 @@ from onediff.infer_compiler import oneflow_compiler_config from onediff.schedulers import EulerDiscreteScheduler from diffusers import StableDiffusionXLPipeline -import diffusers - -diffusers.logging.set_verbosity_info() +# import diffusers +# diffusers.logging.set_verbosity_info() parser = argparse.ArgumentParser() parser.add_argument( @@ -94,6 +93,19 @@ output_type=OUTPUT_TYPE, ).images +# import numpy as np +# base_state_dict = base.unet.state_dict() +# new_state_dict = new_base.unet.state_dict() +# for k, w in base_state_dict.items(): +# if k in new_state_dict: +# if not np.allclose(w.detach().cpu().numpy(), new_state_dict[k].detach().cpu().numpy(), atol=1e-3): +# print(f"Parameter {k} is different.") + +w = base.unet.add_embedding.linear_1.weight.detach().cpu().numpy() +new_w = new_base.unet.add_embedding.linear_1.weight.detach().cpu().numpy() +import numpy as np +assert not np.allclose(w, new_w, atol=1e-3) + # Update the unet and vae # load_state_dict(state_dict, strict=True, assign=False), assign is False means copying them inplace into the module’s current parameters and buffers. 
# Reference: https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.load_state_dict @@ -101,6 +113,12 @@ base.vae.decoder.load_state_dict(new_base.vae.decoder.state_dict()) del new_base +print("check whether the weights are updated") +updated_w = base.unet.add_embedding.linear_1.weight.detach().cpu().numpy() +assert np.allclose(updated_w, new_w, atol=1e-3) +updated_w_oflow = base.unet.add_embedding.linear_1.oneflow_module.weight.detach().cpu().numpy() +assert np.allclose(updated_w_oflow, new_w, atol=1e-3) + # Normal SDXL run print("Normal SDXL run...") torch.manual_seed(args.seed) From 7a920db689f4ad25f7165a8dbacb031674a755b4 Mon Sep 17 00:00:00 2001 From: WangYi Date: Wed, 20 Mar 2024 17:52:44 +0800 Subject: [PATCH 05/17] update example --- .../examples/text_to_image_sdxl_reuse_pipe.py | 92 ++++++++++++------- 1 file changed, 58 insertions(+), 34 deletions(-) diff --git a/onediff_diffusers_extensions/examples/text_to_image_sdxl_reuse_pipe.py b/onediff_diffusers_extensions/examples/text_to_image_sdxl_reuse_pipe.py index 2d4fc6f7e..1c4cd23ad 100644 --- a/onediff_diffusers_extensions/examples/text_to_image_sdxl_reuse_pipe.py +++ b/onediff_diffusers_extensions/examples/text_to_image_sdxl_reuse_pipe.py @@ -61,25 +61,21 @@ ) base.to("cuda") -new_base = StableDiffusionXLPipeline.from_pretrained( - "dataautogpt3/OpenDalleV1.1", - scheduler=scheduler, - torch_dtype=torch.float16, - variant=args.variant, - use_safetensors=True, -) -new_base.to("cuda") oneflow_compiler_config.mlir_enable_inference_optimization = False # Compile unet with oneflow if args.compile_unet: print("Compiling unet with oneflow.") - base.unet = oneflow_compile(base.unet) + compiled_unet = oneflow_compile(base.unet) + compiled_unet_eager = base.unet + base.unet = compiled_unet # Compile vae with oneflow if args.compile_vae: print("Compiling vae with oneflow.") - base.vae.decoder = oneflow_compile(base.vae.decoder) + compiled_decoder = oneflow_compile(base.vae.decoder) + compiled_decoder_eager = base.vae.decoder + base.vae.decoder = compiled_decoder # Warmup with run # Will do compilatioin in the first run @@ -90,46 +86,73 @@ height=args.height, width=args.width, num_inference_steps=args.n_steps, + generator=torch.manual_seed(0), output_type=OUTPUT_TYPE, ).images +del base -# import numpy as np -# base_state_dict = base.unet.state_dict() -# new_state_dict = new_base.unet.state_dict() -# for k, w in base_state_dict.items(): -# if k in new_state_dict: -# if not np.allclose(w.detach().cpu().numpy(), new_state_dict[k].detach().cpu().numpy(), atol=1e-3): -# print(f"Parameter {k} is different.") +torch.cuda.empty_cache() + +print("loading new base") +new_base = StableDiffusionXLPipeline.from_single_file( + "dataautogpt3/OpenDalleV1.1", + scheduler=scheduler, + torch_dtype=torch.float16, + variant=args.variant, + use_safetensors=True, +) +new_base.to("cuda") + +print("New base running by torch backend") +image = new_base( + prompt=args.prompt, + height=args.height, + width=args.width, + num_inference_steps=args.n_steps, + generator=torch.manual_seed(0), + output_type=OUTPUT_TYPE, +).images +image[0].save(f"new_base_without_graph_h{args.height}-w{args.width}-{args.saved_image}") +image_eager = image[0] -w = base.unet.add_embedding.linear_1.weight.detach().cpu().numpy() -new_w = new_base.unet.add_embedding.linear_1.weight.detach().cpu().numpy() -import numpy as np -assert not np.allclose(w, new_w, atol=1e-3) # Update the unet and vae # load_state_dict(state_dict, strict=True, assign=False), assign is False means 
copying them inplace into the module’s current parameters and buffers. # Reference: https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.load_state_dict -base.unet.load_state_dict(new_base.unet.state_dict()) -base.vae.decoder.load_state_dict(new_base.vae.decoder.state_dict()) -del new_base +print("Loading state_dict of new base into compiled graph") +compiled_unet_eager.load_state_dict(new_base.unet.state_dict()) +compiled_decoder_eager.load_state_dict(new_base.vae.decoder.state_dict()) -print("check whether the weights are updated") -updated_w = base.unet.add_embedding.linear_1.weight.detach().cpu().numpy() -assert np.allclose(updated_w, new_w, atol=1e-3) -updated_w_oflow = base.unet.add_embedding.linear_1.oneflow_module.weight.detach().cpu().numpy() -assert np.allclose(updated_w_oflow, new_w, atol=1e-3) +new_base.unet = compiled_unet +new_base.vae.decoder = compiled_decoder + +torch.cuda.empty_cache() +# print("check whether the weights are updated") +# updated_w = base.unet.add_embedding.linear_1.weight.detach().cpu().numpy() +# assert np.allclose(updated_w, new_w, atol=1e-3) +# updated_w_oflow = base.unet.add_embedding.linear_1.oneflow_module.weight.detach().cpu().numpy() +# assert np.allclose(updated_w_oflow, new_w, atol=1e-3) # Normal SDXL run -print("Normal SDXL run...") -torch.manual_seed(args.seed) -image = base( +print("Re-use the compiled graph") +image = new_base( prompt=args.prompt, height=args.height, width=args.width, num_inference_steps=args.n_steps, + generator=torch.manual_seed(0), output_type=OUTPUT_TYPE, ).images -image[0].save(f"h{args.height}-w{args.width}-{args.saved_image}") +image[0].save(f"new_base_reuse_graph_h{args.height}-w{args.width}-{args.saved_image}") +image_graph = image[0] + +from skimage.metrics import structural_similarity +import numpy as np + +ssim = structural_similarity( + np.array(image_eager), np.array(image_graph), channel_axis=-1, data_range=255 +) +print(f"ssim between naive torch and re-used graph is {ssim}") # Should have no compilation for these new input shape @@ -140,11 +163,12 @@ sizes = [360] for h in sizes: for w in sizes: - image = base( + image = new_base( prompt=args.prompt, height=h, width=w, num_inference_steps=args.n_steps, + generator=torch.manual_seed(0), output_type=OUTPUT_TYPE, ).images From 23dae2dabeccea49e2f304a1a47d03e059ef1db1 Mon Sep 17 00:00:00 2001 From: WangYi Date: Wed, 20 Mar 2024 17:53:53 +0800 Subject: [PATCH 06/17] update --- .../examples/text_to_image_sdxl_reuse_pipe.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/onediff_diffusers_extensions/examples/text_to_image_sdxl_reuse_pipe.py b/onediff_diffusers_extensions/examples/text_to_image_sdxl_reuse_pipe.py index 1c4cd23ad..8c5346a34 100644 --- a/onediff_diffusers_extensions/examples/text_to_image_sdxl_reuse_pipe.py +++ b/onediff_diffusers_extensions/examples/text_to_image_sdxl_reuse_pipe.py @@ -127,11 +127,6 @@ new_base.vae.decoder = compiled_decoder torch.cuda.empty_cache() -# print("check whether the weights are updated") -# updated_w = base.unet.add_embedding.linear_1.weight.detach().cpu().numpy() -# assert np.allclose(updated_w, new_w, atol=1e-3) -# updated_w_oflow = base.unet.add_embedding.linear_1.oneflow_module.weight.detach().cpu().numpy() -# assert np.allclose(updated_w_oflow, new_w, atol=1e-3) # Normal SDXL run print("Re-use the compiled graph") From 0ff9a6b785ba08c1f60da3da70d04d2d3c33b043 Mon Sep 17 00:00:00 2001 From: Wang Yi <53533850+marigoold@users.noreply.github.com> Date: Wed, 20 Mar 2024 17:59:09 +0800 
Subject: [PATCH 07/17] Update text_to_image_sdxl_reuse_pipe.py --- .../examples/text_to_image_sdxl_reuse_pipe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onediff_diffusers_extensions/examples/text_to_image_sdxl_reuse_pipe.py b/onediff_diffusers_extensions/examples/text_to_image_sdxl_reuse_pipe.py index 8c5346a34..e2f7a7e34 100644 --- a/onediff_diffusers_extensions/examples/text_to_image_sdxl_reuse_pipe.py +++ b/onediff_diffusers_extensions/examples/text_to_image_sdxl_reuse_pipe.py @@ -94,7 +94,7 @@ torch.cuda.empty_cache() print("loading new base") -new_base = StableDiffusionXLPipeline.from_single_file( +new_base = StableDiffusionXLPipeline.from_pretrained( "dataautogpt3/OpenDalleV1.1", scheduler=scheduler, torch_dtype=torch.float16, From d0801f3c74c6bb7497f95b1418464e16f85ceaec Mon Sep 17 00:00:00 2001 From: WangYi Date: Sat, 23 Mar 2024 21:30:28 +0800 Subject: [PATCH 08/17] add example to ci --- .github/workflows/examples.yml | 2 ++ .../examples/text_to_image_sdxl_reuse_pipe.py | 30 ++++++++++++++----- 2 files changed, 25 insertions(+), 7 deletions(-) diff --git a/.github/workflows/examples.yml b/.github/workflows/examples.yml index 9e02c49c7..d8c0c118b 100644 --- a/.github/workflows/examples.yml +++ b/.github/workflows/examples.yml @@ -231,6 +231,8 @@ jobs: run: docker exec -w /src/onediff/onediff_diffusers_extensions ${{ env.CONTAINER_NAME }} python3 examples/text_to_image_sdxl_turbo.py --compile true --base /share_nfs/hf_models/sdxl-turbo - if: matrix.test-suite == 'diffusers_examples' run: docker exec -e ONEFLOW_MLIR_ENABLE_INFERENCE_OPTIMIZATION=0 ${{ env.CONTAINER_NAME }} python3 -m pytest -v onediff_diffusers_extensions/tests/test_lora.py + - if: matrix.test-suite == 'diffusers_examples' + run: docker exec -e ONEFLOW_MLIR_ENABLE_INFERENCE_OPTIMIZATION=0 ${{ env.CONTAINER_NAME }} python3 text_to_image_sdxl_reuse_pipe.py --base /share_nfs/hf_models/stable-diffusion-xl-base-1.0 --new_base /share_nfs/hf_models/sdxl_lightning_8step.safetensors --guidance_scale 0. 
--n_steps 8 - name: Shutdown docker for ComfyUI Test if: matrix.test-suite == 'comfy' diff --git a/onediff_diffusers_extensions/examples/text_to_image_sdxl_reuse_pipe.py b/onediff_diffusers_extensions/examples/text_to_image_sdxl_reuse_pipe.py index e2f7a7e34..8a4acf8b6 100644 --- a/onediff_diffusers_extensions/examples/text_to_image_sdxl_reuse_pipe.py +++ b/onediff_diffusers_extensions/examples/text_to_image_sdxl_reuse_pipe.py @@ -19,6 +19,9 @@ parser.add_argument( "--base", type=str, default="stabilityai/stable-diffusion-xl-base-1.0" ) +parser.add_argument( + "--new_base", type=str, default="dataautogpt3/OpenDalleV1.1", +) parser.add_argument("--variant", type=str, default="fp16") parser.add_argument( "--prompt", @@ -28,6 +31,7 @@ parser.add_argument("--height", type=int, default=1024) parser.add_argument("--width", type=int, default=1024) parser.add_argument("--n_steps", type=int, default=30) +parser.add_argument("--guidance_scale", type=float, default=7.5) parser.add_argument("--saved_image", type=str, required=False, default="sdxl-out.png") parser.add_argument("--seed", type=int, default=1) parser.add_argument( @@ -88,19 +92,29 @@ num_inference_steps=args.n_steps, generator=torch.manual_seed(0), output_type=OUTPUT_TYPE, + guidance_scale=args.guidance_scale, ).images del base torch.cuda.empty_cache() print("loading new base") -new_base = StableDiffusionXLPipeline.from_pretrained( - "dataautogpt3/OpenDalleV1.1", - scheduler=scheduler, - torch_dtype=torch.float16, - variant=args.variant, - use_safetensors=True, -) +if str(args.new_base).endswith(".safetensors"): + new_base = StableDiffusionXLPipeline.from_single_file( + args.new_base, + scheduler=scheduler, + torch_dtype=torch.float16, + variant=args.variant, + use_safetensors=True, + ) +else: + new_base = StableDiffusionXLPipeline.from_pretrained( + args.new_base, + scheduler=scheduler, + torch_dtype=torch.float16, + variant=args.variant, + use_safetensors=True, + ) new_base.to("cuda") print("New base running by torch backend") @@ -111,6 +125,7 @@ num_inference_steps=args.n_steps, generator=torch.manual_seed(0), output_type=OUTPUT_TYPE, + guidance_scale=args.guidance_scale, ).images image[0].save(f"new_base_without_graph_h{args.height}-w{args.width}-{args.saved_image}") image_eager = image[0] @@ -137,6 +152,7 @@ num_inference_steps=args.n_steps, generator=torch.manual_seed(0), output_type=OUTPUT_TYPE, + guidance_scale=args.guidance_scale, ).images image[0].save(f"new_base_reuse_graph_h{args.height}-w{args.width}-{args.saved_image}") image_graph = image[0] From 5fcc919054e00c3904d4b1ac34a9850761fde737 Mon Sep 17 00:00:00 2001 From: WangYi Date: Sat, 23 Mar 2024 21:59:37 +0800 Subject: [PATCH 09/17] add _torch_module member to DeployableModule --- onediff_diffusers_extensions/README.md | 26 +++++++++++++++++++ .../examples/text_to_image_sdxl_reuse_pipe.py | 6 ++--- .../infer_compiler/with_oneflow_compile.py | 1 + 3 files changed, 29 insertions(+), 4 deletions(-) diff --git a/onediff_diffusers_extensions/README.md b/onediff_diffusers_extensions/README.md index ed6c88eb4..fd2d308d1 100644 --- a/onediff_diffusers_extensions/README.md +++ b/onediff_diffusers_extensions/README.md @@ -430,6 +430,32 @@ We tested the performance of `set_adapters`, still using the five LoRA models me - While traversing the submodules of the model, we observed that the `getattr` time overhead of OneDiff's `DeployableModule` is high. 
Because the parameters of DeployableModule share the same address as the PyTorch module it wraps, we choose to traverse `DeployableModule._torch_module`, greatly improving traversal efficiency.
+## Compiled graph re-use
+
+When switching models, if the new model has the same structure as the old one, you can re-use the previously compiled graph: the new model does not need to be compiled again, which significantly reduces the time needed to switch models.
+
+Here is pseudo code; for detailed usage, please refer to [text_to_image_sdxl_reuse_pipe](./examples/text_to_image_sdxl_reuse_pipe.py):
+
+```python
+base = StableDiffusionPipeline(...)
+compiled_unet = oneflow_compile(base.unet)
+base.unet = compiled_unet
+# This step needs some time to compile the UNet
+base(prompt)
+
+new_base = StableDiffusionPipeline(...)
+# Re-use the compiled graph by loading the new state dict into the `_torch_module` member of the object returned by `oneflow_compile`
+compiled_unet._torch_module.load_state_dict(new_base.unet.state_dict())
+# After loading the new state dict into `compiled_unet._torch_module`, the weights of `compiled_unet` are updated too
+new_base.unet = compiled_unet
+# This step doesn't need additional time to compile the UNet again because
+# new_base.unet is already compiled
+new_base(prompt)
+```
+
+> Note: Please make sure that your PyTorch version is **at least 2.1.0** and set the environment variable `ONEFLOW_MLIR_ENABLE_INFERENCE_OPTIMIZATION` to **0**. This feature is not supported for quantized models.
+
+
## Quantization

**Note**: Quantization feature is only supported by **OneDiff Enterprise**.
diff --git a/onediff_diffusers_extensions/examples/text_to_image_sdxl_reuse_pipe.py b/onediff_diffusers_extensions/examples/text_to_image_sdxl_reuse_pipe.py
index 8a4acf8b6..767009c27 100644
--- a/onediff_diffusers_extensions/examples/text_to_image_sdxl_reuse_pipe.py
+++ b/onediff_diffusers_extensions/examples/text_to_image_sdxl_reuse_pipe.py
@@ -71,14 +71,12 @@
if args.compile_unet:
    print("Compiling unet with oneflow.")
    compiled_unet = oneflow_compile(base.unet)
-    compiled_unet_eager = base.unet
    base.unet = compiled_unet

# Compile vae with oneflow
if args.compile_vae:
    print("Compiling vae with oneflow.")
    compiled_decoder = oneflow_compile(base.vae.decoder)
-    compiled_decoder_eager = base.vae.decoder
    base.vae.decoder = compiled_decoder

# Warmup with run
# Will do compilation in the first run
print("Warmup with running graphs...")
torch.manual_seed(args.seed)
image = base(
@@ -135,8 +133,8 @@
# load_state_dict(state_dict, strict=True, assign=False), assign is False means copying them in place into the module’s current parameters and buffers.
# Reference: https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.load_state_dict print("Loading state_dict of new base into compiled graph") -compiled_unet_eager.load_state_dict(new_base.unet.state_dict()) -compiled_decoder_eager.load_state_dict(new_base.vae.decoder.state_dict()) +compiled_unet._torch_module.load_state_dict(new_base.unet.state_dict()) +compiled_decoder._torch_module.load_state_dict(new_base.vae.decoder.state_dict()) new_base.unet = compiled_unet new_base.vae.decoder = compiled_decoder diff --git a/src/onediff/infer_compiler/with_oneflow_compile.py b/src/onediff/infer_compiler/with_oneflow_compile.py index d94bcc3c3..b3e4d2906 100644 --- a/src/onediff/infer_compiler/with_oneflow_compile.py +++ b/src/onediff/infer_compiler/with_oneflow_compile.py @@ -205,6 +205,7 @@ def __init__( get_mixed_dual_module(torch_module.__class__)(torch_module, oneflow_module), ) object.__setattr__(self, "_modules", torch_module._modules) + object.__setattr__(self, "_torch_module", torch_module) self._deployable_module_use_graph = use_graph self._deployable_module_enable_dynamic = dynamic self._deployable_module_options = options From 6f29f03e190caf703f0fa0825a8540be1b32b2c8 Mon Sep 17 00:00:00 2001 From: Wang Yi <53533850+marigoold@users.noreply.github.com> Date: Sat, 23 Mar 2024 22:01:30 +0800 Subject: [PATCH 10/17] Update text_to_image_sdxl_reuse_pipe.py --- .../examples/text_to_image_sdxl_reuse_pipe.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/onediff_diffusers_extensions/examples/text_to_image_sdxl_reuse_pipe.py b/onediff_diffusers_extensions/examples/text_to_image_sdxl_reuse_pipe.py index 767009c27..54ae3fd48 100644 --- a/onediff_diffusers_extensions/examples/text_to_image_sdxl_reuse_pipe.py +++ b/onediff_diffusers_extensions/examples/text_to_image_sdxl_reuse_pipe.py @@ -1,7 +1,3 @@ -""" -Torch run example: python examples/text_to_image_sdxl.py -Compile to oneflow graph example: python examples/text_to_image_sdxl.py -""" import os import argparse From 83fb07aeb15431d538fbd5b891f410b404e025f5 Mon Sep 17 00:00:00 2001 From: WangYi Date: Sat, 23 Mar 2024 22:39:13 +0800 Subject: [PATCH 11/17] fix ci yaml bug --- .github/workflows/examples.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/examples.yml b/.github/workflows/examples.yml index d8c0c118b..70aabeaa4 100644 --- a/.github/workflows/examples.yml +++ b/.github/workflows/examples.yml @@ -232,7 +232,7 @@ jobs: - if: matrix.test-suite == 'diffusers_examples' run: docker exec -e ONEFLOW_MLIR_ENABLE_INFERENCE_OPTIMIZATION=0 ${{ env.CONTAINER_NAME }} python3 -m pytest -v onediff_diffusers_extensions/tests/test_lora.py - if: matrix.test-suite == 'diffusers_examples' - run: docker exec -e ONEFLOW_MLIR_ENABLE_INFERENCE_OPTIMIZATION=0 ${{ env.CONTAINER_NAME }} python3 text_to_image_sdxl_reuse_pipe.py --base /share_nfs/hf_models/stable-diffusion-xl-base-1.0 --new_base /share_nfs/hf_models/sdxl_lightning_8step.safetensors --guidance_scale 0. --n_steps 8 + run: docker exec -w /src/onediff/onediff_diffusers_extensions -e ONEFLOW_MLIR_ENABLE_INFERENCE_OPTIMIZATION=0 ${{ env.CONTAINER_NAME }} python3 examples/text_to_image_sdxl_reuse_pipe.py --base /share_nfs/hf_models/stable-diffusion-xl-base-1.0 --new_base /share_nfs/hf_models/sdxl_lightning_8step.safetensors --guidance_scale 0. 
--n_steps 8 - name: Shutdown docker for ComfyUI Test if: matrix.test-suite == 'comfy' From 804d70af8d55570c0501623f9ae0950e04567112 Mon Sep 17 00:00:00 2001 From: WangYi Date: Sun, 24 Mar 2024 08:55:36 +0800 Subject: [PATCH 12/17] add missing dependancy omegaconf --- onediff_diffusers_extensions/setup.py | 1 + 1 file changed, 1 insertion(+) diff --git a/onediff_diffusers_extensions/setup.py b/onediff_diffusers_extensions/setup.py index ba6d36b82..a14e275f0 100644 --- a/onediff_diffusers_extensions/setup.py +++ b/onediff_diffusers_extensions/setup.py @@ -26,6 +26,7 @@ def get_version(): "accelerate", "torch", "onefx", + "omegaconf", ], classifiers=[ "Development Status :: 5 - Production/Stable", From eb5185a54c8c6ee9991a50fbda1e43ecc03caf68 Mon Sep 17 00:00:00 2001 From: WangYi Date: Sun, 24 Mar 2024 08:59:31 +0800 Subject: [PATCH 13/17] fix bug of nccl symbol not found --- .../examples/text_to_image_sdxl_reuse_pipe.py | 1 - 1 file changed, 1 deletion(-) diff --git a/onediff_diffusers_extensions/examples/text_to_image_sdxl_reuse_pipe.py b/onediff_diffusers_extensions/examples/text_to_image_sdxl_reuse_pipe.py index 54ae3fd48..3fbbebd1d 100644 --- a/onediff_diffusers_extensions/examples/text_to_image_sdxl_reuse_pipe.py +++ b/onediff_diffusers_extensions/examples/text_to_image_sdxl_reuse_pipe.py @@ -1,7 +1,6 @@ import os import argparse -import oneflow as flow import torch from onediff.infer_compiler import oneflow_compile From 2cb29fef49c6bf1c7150985c8350e26650de787b Mon Sep 17 00:00:00 2001 From: WangYi Date: Sun, 24 Mar 2024 10:01:08 +0800 Subject: [PATCH 14/17] use another sdxl model --- .github/workflows/examples.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/examples.yml b/.github/workflows/examples.yml index 70aabeaa4..932d92ebb 100644 --- a/.github/workflows/examples.yml +++ b/.github/workflows/examples.yml @@ -232,7 +232,7 @@ jobs: - if: matrix.test-suite == 'diffusers_examples' run: docker exec -e ONEFLOW_MLIR_ENABLE_INFERENCE_OPTIMIZATION=0 ${{ env.CONTAINER_NAME }} python3 -m pytest -v onediff_diffusers_extensions/tests/test_lora.py - if: matrix.test-suite == 'diffusers_examples' - run: docker exec -w /src/onediff/onediff_diffusers_extensions -e ONEFLOW_MLIR_ENABLE_INFERENCE_OPTIMIZATION=0 ${{ env.CONTAINER_NAME }} python3 examples/text_to_image_sdxl_reuse_pipe.py --base /share_nfs/hf_models/stable-diffusion-xl-base-1.0 --new_base /share_nfs/hf_models/sdxl_lightning_8step.safetensors --guidance_scale 0. 
--n_steps 8 + run: docker exec -w /src/onediff/onediff_diffusers_extensions -e ONEFLOW_MLIR_ENABLE_INFERENCE_OPTIMIZATION=0 ${{ env.CONTAINER_NAME }} python3 examples/text_to_image_sdxl_reuse_pipe.py --base /share_nfs/hf_models/stable-diffusion-xl-base-1.0 --new_base /data/home/wangyi/models/base/SDXLRonghua_v40.safetensors - name: Shutdown docker for ComfyUI Test if: matrix.test-suite == 'comfy' From 2c429870eed18176b76935df1b322fd093fb56eb Mon Sep 17 00:00:00 2001 From: Wang Yi <53533850+marigoold@users.noreply.github.com> Date: Sun, 24 Mar 2024 10:31:48 +0800 Subject: [PATCH 15/17] Update examples.yml --- .github/workflows/examples.yml | 2 -- 1 file changed, 2 deletions(-) diff --git a/.github/workflows/examples.yml b/.github/workflows/examples.yml index 932d92ebb..9e02c49c7 100644 --- a/.github/workflows/examples.yml +++ b/.github/workflows/examples.yml @@ -231,8 +231,6 @@ jobs: run: docker exec -w /src/onediff/onediff_diffusers_extensions ${{ env.CONTAINER_NAME }} python3 examples/text_to_image_sdxl_turbo.py --compile true --base /share_nfs/hf_models/sdxl-turbo - if: matrix.test-suite == 'diffusers_examples' run: docker exec -e ONEFLOW_MLIR_ENABLE_INFERENCE_OPTIMIZATION=0 ${{ env.CONTAINER_NAME }} python3 -m pytest -v onediff_diffusers_extensions/tests/test_lora.py - - if: matrix.test-suite == 'diffusers_examples' - run: docker exec -w /src/onediff/onediff_diffusers_extensions -e ONEFLOW_MLIR_ENABLE_INFERENCE_OPTIMIZATION=0 ${{ env.CONTAINER_NAME }} python3 examples/text_to_image_sdxl_reuse_pipe.py --base /share_nfs/hf_models/stable-diffusion-xl-base-1.0 --new_base /data/home/wangyi/models/base/SDXLRonghua_v40.safetensors - name: Shutdown docker for ComfyUI Test if: matrix.test-suite == 'comfy' From 39ced0b93646a865ee91f94813b5299f7d445b85 Mon Sep 17 00:00:00 2001 From: Wang Yi <53533850+marigoold@users.noreply.github.com> Date: Sun, 24 Mar 2024 10:32:05 +0800 Subject: [PATCH 16/17] Update setup.py --- onediff_diffusers_extensions/setup.py | 1 - 1 file changed, 1 deletion(-) diff --git a/onediff_diffusers_extensions/setup.py b/onediff_diffusers_extensions/setup.py index a14e275f0..ba6d36b82 100644 --- a/onediff_diffusers_extensions/setup.py +++ b/onediff_diffusers_extensions/setup.py @@ -26,7 +26,6 @@ def get_version(): "accelerate", "torch", "onefx", - "omegaconf", ], classifiers=[ "Development Status :: 5 - Production/Stable", From 80708871a5b6839dca916d7aeb86fa8e8991f11a Mon Sep 17 00:00:00 2001 From: WangYi Date: Sun, 24 Mar 2024 16:44:20 +0800 Subject: [PATCH 17/17] another model --- .github/workflows/examples.yml | 2 ++ onediff_diffusers_extensions/setup.py | 1 + 2 files changed, 3 insertions(+) diff --git a/.github/workflows/examples.yml b/.github/workflows/examples.yml index 9e02c49c7..abc6efa5c 100644 --- a/.github/workflows/examples.yml +++ b/.github/workflows/examples.yml @@ -231,6 +231,8 @@ jobs: run: docker exec -w /src/onediff/onediff_diffusers_extensions ${{ env.CONTAINER_NAME }} python3 examples/text_to_image_sdxl_turbo.py --compile true --base /share_nfs/hf_models/sdxl-turbo - if: matrix.test-suite == 'diffusers_examples' run: docker exec -e ONEFLOW_MLIR_ENABLE_INFERENCE_OPTIMIZATION=0 ${{ env.CONTAINER_NAME }} python3 -m pytest -v onediff_diffusers_extensions/tests/test_lora.py + - if: matrix.test-suite == 'diffusers_examples' + run: docker exec -w /src/onediff/onediff_diffusers_extensions -e ONEFLOW_MLIR_ENABLE_INFERENCE_OPTIMIZATION=0 ${{ env.CONTAINER_NAME }} python3 examples/text_to_image_sdxl_reuse_pipe.py --base /share_nfs/hf_models/stable-diffusion-xl-base-1.0 
--new_base /share_nfs/hf_models/dataautogpt3-OpenDalleV1.1 - name: Shutdown docker for ComfyUI Test if: matrix.test-suite == 'comfy' diff --git a/onediff_diffusers_extensions/setup.py b/onediff_diffusers_extensions/setup.py index ba6d36b82..a14e275f0 100644 --- a/onediff_diffusers_extensions/setup.py +++ b/onediff_diffusers_extensions/setup.py @@ -26,6 +26,7 @@ def get_version(): "accelerate", "torch", "onefx", + "omegaconf", ], classifiers=[ "Development Status :: 5 - Production/Stable",
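For reference, below is a minimal, self-contained sketch of the graph re-use flow that the patches above implement in `examples/text_to_image_sdxl_reuse_pipe.py`. It mirrors the README pseudo code; the model identifiers, prompt, and step count are illustrative placeholders, and it assumes onediff's `oneflow_compile` together with diffusers' `StableDiffusionXLPipeline`, as used in the example script.

```python
# Sketch of the compiled-graph re-use flow from the example script above.
# Model names, prompt, and step count are illustrative placeholders.
# Assumes PyTorch >= 2.1 and ONEFLOW_MLIR_ENABLE_INFERENCE_OPTIMIZATION=0,
# as the README note added above requires.
import torch
from diffusers import StableDiffusionXLPipeline
from onediff.infer_compiler import oneflow_compile

base = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")

# Compile once; the first call pays the compilation cost.
compiled_unet = oneflow_compile(base.unet)
base.unet = compiled_unet
base(prompt="a warmup prompt", num_inference_steps=8)

# Load a structurally identical checkpoint into the wrapped torch module.
# load_state_dict copies the new weights in place, so the compiled graph
# picks them up without recompiling.
new_base = StableDiffusionXLPipeline.from_pretrained(
    "dataautogpt3/OpenDalleV1.1", torch_dtype=torch.float16
).to("cuda")
compiled_unet._torch_module.load_state_dict(new_base.unet.state_dict())
new_base.unet = compiled_unet

# Runs with the re-used graph; no second compilation is needed.
image = new_base(prompt="street style, raw photo", num_inference_steps=8).images[0]
image.save("reuse-graph.png")
```

Loading the new weights through `compiled_unet._torch_module`, rather than compiling the new pipeline's UNet from scratch, works because the `DeployableModule` wrapper shares parameter storage with the torch module it wraps, so the already-built oneflow graph serves subsequent calls with the updated weights.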
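The re-use also relies on `torch.nn.Module.load_state_dict` keeping its default `assign=False` behavior, i.e. copying the incoming tensors in place into the module's existing parameters and buffers, as the comments in the example note. Below is a small standalone illustration of that semantics using plain PyTorch only; the `nn.Linear` modules are placeholders, not part of the example.

```python
# Demonstrates that load_state_dict(..., assign=False) updates the existing
# parameter storage in place, so any object still holding a reference to the
# old parameter (here param_ref, standing in for the compiled graph) sees the
# new values. With assign=True the parameter tensors would be replaced, and
# such references would go stale.
import torch
import torch.nn as nn

dst = nn.Linear(4, 4)   # stands in for the already-compiled module
src = nn.Linear(4, 4)   # stands in for the newly loaded checkpoint

param_ref = dst.weight                  # reference held "inside the graph"
storage_before = dst.weight.data_ptr()

dst.load_state_dict(src.state_dict())   # assign=False is the default

assert dst.weight.data_ptr() == storage_before   # same tensor storage
assert torch.equal(param_ref, src.weight)        # the old reference sees the new weights
print("in-place update confirmed")
```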