diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md index 3fbb29791..b53258564 100644 --- a/.github/ISSUE_TEMPLATE/bug_report.md +++ b/.github/ISSUE_TEMPLATE/bug_report.md @@ -11,7 +11,7 @@ assignees: strint #### A clear and concise description of what the bug is. ### Your environment -#### OS +#### OS #### OneDiff git commit id diff --git a/.github/workflows/examples.yml b/.github/workflows/examples.yml index c1a6ed3b9..76afb4b37 100644 --- a/.github/workflows/examples.yml +++ b/.github/workflows/examples.yml @@ -242,7 +242,7 @@ jobs: SDXL_BASE: ${{ env.SDXL_BASE }} UNET_INT8: ${{ env.UNET_INT8 }} SILICON_ONEDIFF_LICENSE_KEY: ${{ secrets.SILICON_ONEDIFF_LICENSE_KEY }} - + - name: Setup docker for WebUI Test if: matrix.test-suite == 'webui' run: | @@ -351,7 +351,7 @@ jobs: run: | docker exec -w /src/onediff ${{ env.CONTAINER_NAME }} python3 onediff_diffusers_extensions/examples/text_to_image_sd_enterprise.py --model /share_nfs/hf_models/stable-diffusion-v1-5-int8 --width 512 --height 512 --saved_image /src/onediff/output_enterprise_sd.png docker exec -w /src/onediff ${{ env.CONTAINER_NAME }} python3 tests/test_quantitative_quality.py - + - name: Install Requirements for WebUI if: matrix.test-suite == 'webui' run: | diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 000000000..c0c773de1 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,30 @@ +exclude: '((generator.py)|(generated/.*))$' +repos: + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.0.1 + hooks: + - id: trailing-whitespace + - id: end-of-file-fixer + - id: check-docstring-first + - id: check-toml + - id: check-yaml + exclude: packaging/.* + args: + - --allow-multiple-documents + - id: mixed-line-ending + args: [--fix=lf] + - id: end-of-file-fixer + + - repo: https://github.com/omnilib/ufmt + rev: v1.3.3 + hooks: + - id: ufmt + additional_dependencies: + - black == 22.3.0 + - usort == 1.0.2 + +# - repo: https://github.com/PyCQA/pydocstyle +# rev: 6.1.1 +# hooks: +# - id: pydocstyle +# diff --git a/README.md b/README.md index a43d3bc7c..5d4deea51 100644 --- a/README.md +++ b/README.md @@ -60,7 +60,7 @@ For example: + [onediff Enterprise Edition](#onediff-enterprise-edition) -## Documentation +## Documentation onediff is the abbreviation of "**one** line of code to accelerate **diff**usion models". ### Use with HF diffusers and ComfyUI @@ -180,8 +180,20 @@ python3 -m pip install --pre onediff - From source ``` git clone https://github.com/siliconflow/onediff.git +``` +``` cd onediff && python3 -m pip install -e . ``` +Or install for development: +``` +# install for dev +cd onediff && python3 -m pip install -e '.[dev]' + +# code formatting and linting +pip3 install pre-commit +pre-commit install +pre-commit run --all-files +``` > **_NOTE:_** If you intend to utilize plugins for ComfyUI/StableDiffusion-WebUI, we highly recommend installing OneDiff from the source rather than PyPI. This is necessary as you'll need to manually copy (or create a soft link) for the relevant code into the extension folder of these UIs/Libs. 
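The dev-setup addition above pins ufmt to black 22.3.0 and usort 1.0.2, and those two tools account for essentially all of the mechanical churn in the Python hunks that follow: usort regroups and alphabetizes imports, while black reflows call sites. One black behavior worth knowing when reading the reflowed hunks is the "magic trailing comma": a call that already ends in a trailing comma is kept expanded, one argument per line, rather than collapsed back onto a single line. A minimal, runnable sketch (the argument mirrors `benchmarks/docker/main.py`):

```python
import argparse

parser = argparse.ArgumentParser()
# The trailing comma after the last argument tells black 22.x to keep this
# call expanded, one argument per line -- the layout applied throughout
# this diff -- instead of collapsing it back onto a single line.
parser.add_argument(
    "-y",
    "--yaml",
    type=str,
    default="config/community-default.yaml",
)
args = parser.parse_args([])
print(args.yaml)  # config/community-default.yaml
```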
diff --git a/README_ENTERPRISE.md b/README_ENTERPRISE.md index 8b19c5930..c7f48073e 100644 --- a/README_ENTERPRISE.md +++ b/README_ENTERPRISE.md @@ -122,7 +122,7 @@ Ensure that you have installed [OneDiff ComfyUI Nodes](onediff_comfy_nodes/READM For more information and to **access the model files and Workflow below**, please visit [Hugging Face - stable-diffusion-v1-5-onediff-enterprise-v1](https://huggingface.co/siliconflow/stable-diffusion-v1-5-onediff-comfy-enterprise-v1/tree/main). -
+
Download the required model files 1. Download the [`v1-5-pruned.safetensors`](https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned.safetensors) file and place it in the `ComfyUI/models/checkpoints/` directory: @@ -149,7 +149,7 @@ Click the links below to view the workflow images, or load them directly into Co For more information and to **access the model files and Workflow below**, please visit [Hugging Face - stable-diffusion-v2-1-onediff-enterprise](https://huggingface.co/siliconflow/stable-diffusion-v2-1-onediff-comfy-enterprise/tree/main). -
+
Download the required model files 1. Download the [`v2-1_768-ema-pruned.zip`](https://huggingface.co/siliconflow/stable-diffusion-v2-1-onediff-comfy-enterprise/blob/main/v2-1_768-ema-pruned.zip) file and unzip, then place the .safetensors in the `ComfyUI/models/checkpoints/` directory: @@ -166,7 +166,7 @@ wget https://huggingface.co/siliconflow/stable-diffusion-v2-1-onediff-comfy-ente
-Click the links below to view the workflow images, or load them directly into ComfyUI. +Click the links below to view the workflow images, or load them directly into ComfyUI. - Workflow: [SD 2.1](https://huggingface.co/siliconflow/stable-diffusion-v2-1-onediff-comfy-enterprise/blob/main/onediff_stable_diffusion_2_1.png) @@ -174,7 +174,7 @@ Click the links below to view the workflow images, or load them directly into Co For model details and to **access the model files and Workflow below**, please visit [Hugging Face - sdxl-base-1.0-onediff-comfy-enterprise-v1](https://huggingface.co/siliconflow/sdxl-base-1.0-onediff-comfy-enterprise-v1/tree/main). -
+
Download the required model files 1. Download the [`sd_xl_base_1.0.safetensors`](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/resolve/main/sd_xl_base_1.0.safetensors) file and place it in the `ComfyUI/models/checkpoints/` directory: @@ -373,20 +373,20 @@ To download the necessary models, please visit the [siliconflow/stable-video-dif Run [image_to_video.py](benchmarks/image_to_video.py): ```bash -python3 benchmarks/image_to_video.py \ - --model $model_path \ - --input-image path/to/input_image.jpg \ - --output-video path/to/output_image.mp4 +python3 benchmarks/image_to_video.py \ + --model $model_path \ + --input-image path/to/input_image.jpg \ + --output-video path/to/output_image.mp4 ``` #### SVD + DeepCache ```bash -python3 benchmarks/image_to_video.py \ - --model $model_path \ - --deepcache \ - --input-image path/to/input_image.jpg \ - --output-video path/to/output_image.mp4 +python3 benchmarks/image_to_video.py \ + --model $model_path \ + --deepcache \ + --input-image path/to/input_image.jpg \ + --output-video path/to/output_image.mp4 ``` ## Quantitative model diff --git a/benchmarks/docker/_utils.py b/benchmarks/docker/_utils.py index 3c93322e5..0a8e35edf 100644 --- a/benchmarks/docker/_utils.py +++ b/benchmarks/docker/_utils.py @@ -1,12 +1,13 @@ import hashlib import os import subprocess -import yaml -from git import Repo +import yaml from _logger import logger +from git import Repo + def load_yaml(*, file): if not os.path.exists(file): diff --git a/benchmarks/docker/main.py b/benchmarks/docker/main.py index a59d0d4fd..31d74c921 100644 --- a/benchmarks/docker/main.py +++ b/benchmarks/docker/main.py @@ -1,21 +1,21 @@ import argparse -from datetime import datetime import os import sys +from datetime import datetime from pathlib import Path ONEDIFFBOX_ROOT = Path(os.path.abspath(__file__)).parents[0] sys.path.insert(0, str(ONEDIFFBOX_ROOT)) +from _logger import logger from _utils import ( - calculate_sha256, - setup_repo, - load_yaml, - generate_docker_file, build_image, + calculate_sha256, gen_docker_compose_yaml, + generate_docker_file, + load_yaml, + setup_repo, ) -from _logger import logger def parse_args(): @@ -23,13 +23,22 @@ def parse_args(): formatted_datetime = datetime.now().strftime("%Y%m%d-%H%M") parser.add_argument( - "-y", "--yaml", type=str, default="config/community-default.yaml", + "-y", + "--yaml", + type=str, + default="config/community-default.yaml", ) parser.add_argument( - "-i", "--image", type=str, default="onediff", + "-i", + "--image", + type=str, + default="onediff", ) parser.add_argument( - "-t", "--tag", type=str, default=f"benchmark", + "-t", + "--tag", + type=str, + default=f"benchmark", ) parser.add_argument( "-o", @@ -39,10 +48,17 @@ def parse_args(): help="the output directory of Dockerfile and Docker-compose file", ) parser.add_argument( - "-c", "--context", type=str, default=".", help="the path to build context", + "-c", + "--context", + type=str, + default=".", + help="the path to build context", ) parser.add_argument( - "-q", "--quiet", action="store_true", help="quiet mode", + "-q", + "--quiet", + action="store_true", + help="quiet mode", ) args = parser.parse_args() return args @@ -77,7 +93,10 @@ def parse_args(): envs = image_config.pop("envs", []) volumes = image_config.pop( - "volumes", ["$BENCHMARK_MODEL_PATH:/benchmark_model:ro",], + "volumes", + [ + "$BENCHMARK_MODEL_PATH:/benchmark_model:ro", + ], ) compose_file, run_command = gen_docker_compose_yaml( f"onediff-benchmark-{version}", image_name, envs, volumes, args.output 
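Several of the benchmark files below pin `import oneflow as flow` with a `# usort: skip` comment. usort normally re-sorts imports into standard-library, third-party, and first-party blocks, alphabetized within each; the skip directive leaves the marked import exactly where it is, which these scripts appear to use so that oneflow is imported ahead of torch. A stdlib-only sketch, assuming usort's documented skip behavior:

```python
# Running `usort format` on this file would normally reorder these three
# imports alphabetically (json, os, sys). The skip comment pins `os` in
# place, and imports are not re-sorted across the pinned line.
import sys
import os  # usort: skip
import json

print(json.dumps({"python": list(sys.version_info[:2]), "cwd": os.getcwd()}))
```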
diff --git a/benchmarks/image_to_video.py b/benchmarks/image_to_video.py index fcc9e19cd..2fc24469d 100644 --- a/benchmarks/image_to_video.py +++ b/benchmarks/image_to_video.py @@ -30,19 +30,20 @@ CACHE_INTERVAL = 3 CACHE_BRANCH = 0 -import os +import argparse import importlib import inspect -import argparse -import time import json +import os import random +import time + from PIL import Image, ImageDraw -import oneflow as flow +import oneflow as flow # usort: skip import torch -from onediffx import compile_pipe, OneflowCompileOptions -from diffusers.utils import load_image, export_to_video +from diffusers.utils import export_to_video, load_image +from onediffx import compile_pipe, OneflowCompileOptions def parse_args(): @@ -84,10 +85,14 @@ def parse_args(): default=ATTENTION_FP16_SCORE_ACCUM_MAX_M, ) parser.add_argument( - "--alter-height", type=int, default=ALTER_HEIGHT, + "--alter-height", + type=int, + default=ALTER_HEIGHT, ) parser.add_argument( - "--alter-width", type=int, default=ALTER_WIDTH, + "--alter-width", + type=int, + default=ALTER_WIDTH, ) return parser.parse_args() @@ -110,7 +115,8 @@ def load_pipe( from diffusers import ControlNetModel controlnet = ControlNetModel.from_pretrained( - controlnet, torch_dtype=torch.float16, + controlnet, + torch_dtype=torch.float16, ) extra_kwargs["controlnet"] = controlnet if os.path.exists(os.path.join(model_name, "calibrate_info.txt")): @@ -218,7 +224,12 @@ def main(): control_image = Image.new("RGB", (width, height)) draw = ImageDraw.Draw(control_image) draw.ellipse( - (width // 4, height // 4, width // 4 * 3, height // 4 * 3,), + ( + width // 4, + height // 4, + width // 4 * 3, + height // 4 * 3, + ), fill=(255, 255, 255), ) del draw diff --git a/benchmarks/instant_id.py b/benchmarks/instant_id.py index ace93b869..f3b1eec3d 100644 --- a/benchmarks/instant_id.py +++ b/benchmarks/instant_id.py @@ -26,22 +26,23 @@ CACHE_LAYER_ID = 0 CACHE_BLOCK_ID = 0 -import sys -import os +import argparse import importlib import inspect -import argparse -import time import json -import torch -from PIL import Image, ImageDraw -import numpy as np +import os +import sys +import time + import cv2 -from huggingface_hub import snapshot_download +import numpy as np +import torch from diffusers.utils import load_image +from huggingface_hub import snapshot_download from insightface.app import FaceAnalysis +from PIL import Image, ImageDraw -import oneflow as flow +import oneflow as flow # usort: skip from onediffx import compile_pipe @@ -97,7 +98,8 @@ def load_pipe( from diffusers import ControlNetModel controlnet = ControlNetModel.from_pretrained( - controlnet, torch_dtype=torch.float16, + controlnet, + torch_dtype=torch.float16, ) extra_kwargs["controlnet"] = controlnet if os.path.exists(os.path.join(model_name, "calibrate_info.txt")): @@ -183,13 +185,14 @@ def main(): if args.repo is None: custom_pipeline = args.custom_pipeline from diffusers import DiffusionPipeline + pipeline_cls = DiffusionPipeline else: sys.path.insert(0, args.repo) from pipeline_stable_diffusion_xl_instantid import ( - StableDiffusionXLInstantIDPipeline as pipeline_cls, draw_kps, + StableDiffusionXLInstantIDPipeline as pipeline_cls, ) if os.path.exists(args.controlnet): diff --git a/benchmarks/patch_stable_cascade.py b/benchmarks/patch_stable_cascade.py index cb5a02f8d..b640c96ec 100644 --- a/benchmarks/patch_stable_cascade.py +++ b/benchmarks/patch_stable_cascade.py @@ -78,11 +78,11 @@ def pixel_shuffle(input, upscale_factor): *batch_dims, channels, height, width = input.shape assert ( - 
channels % (upscale_factor ** 2) == 0 + channels % (upscale_factor**2) == 0 ), "Number of channels must be divisible by the square of the upscale factor" # Calculate new channels after applying upscale_factor - new_channels = channels // (upscale_factor ** 2) + new_channels = channels // (upscale_factor**2) # Reshape input to (*batch_dims, new_channels, upscale_factor, upscale_factor, height, width) reshaped = input.reshape( @@ -143,7 +143,7 @@ def pixel_unshuffle(input, downscale_factor): # Final reshape output = permuted.reshape( *batch_dims, - channels * downscale_factor ** 2, + channels * downscale_factor**2, height // downscale_factor, width // downscale_factor, ) diff --git a/benchmarks/patch_stable_cascade_of.py b/benchmarks/patch_stable_cascade_of.py index 454a17344..2431b156a 100644 --- a/benchmarks/patch_stable_cascade_of.py +++ b/benchmarks/patch_stable_cascade_of.py @@ -1,11 +1,12 @@ +import importlib.metadata from typing import Optional + import oneflow as torch import oneflow.nn as nn import oneflow.nn.functional as F -from packaging import version -import importlib.metadata from onediff.infer_compiler.backends.oneflow.transform import transform_mgr +from packaging import version diffusers_of = transform_mgr.transform_package("diffusers") StableCascadeUnet_OF_CLS = ( @@ -114,6 +115,8 @@ def forward( return self.clf(x).to(torch.float16) +from contextlib import contextmanager + # diffusers.pipelines.stable_cascade.modeling_stable_cascade_common.StableCascadeUnet from diffusers.pipelines.stable_cascade.modeling_stable_cascade_common import ( StableCascadeUnet, @@ -121,7 +124,6 @@ def forward( # torch2oflow_class_map.update({StableCascadeUnet: StableCascadeUnetOflow}) from onediff.infer_compiler.backends.oneflow.transform import register -from contextlib import contextmanager @contextmanager diff --git a/benchmarks/run_image_to_video_benchmark.sh b/benchmarks/run_image_to_video_benchmark.sh index 4e516c871..7bcd81982 100755 --- a/benchmarks/run_image_to_video_benchmark.sh +++ b/benchmarks/run_image_to_video_benchmark.sh @@ -24,7 +24,7 @@ while getopts 'm:w:c:o:h' opt; do o) OUTPUT_FILE=$OPTARG ;; - + ?|h) echo "Usage: $(basename $0) [-m model_dir] [-w warmups] [-c compiler] [-o output_file]" echo " -m model_dir: the directory of the models, if not set, use HF models" diff --git a/benchmarks/run_sdxl_light_benchmarks.sh b/benchmarks/run_sdxl_light_benchmarks.sh index 6138a6aff..f16faf48b 100755 --- a/benchmarks/run_sdxl_light_benchmarks.sh +++ b/benchmarks/run_sdxl_light_benchmarks.sh @@ -43,5 +43,3 @@ TXT2IMG_ONEFLOW_OUTPUT_FILE=${OUTPUT_DIR}/sdxl_light_oneflow.md BENCHMARK_RESULT_TEXT="${BENCHMARK_RESULT_TEXT}\n\n### Text to Image (OneFlow)\n\n$(cat ${TXT2IMG_ONEFLOW_OUTPUT_FILE})\n\n" echo -e "${BENCHMARK_RESULT_TEXT}" > ${OUTPUT_FILE} - - diff --git a/benchmarks/run_text_to_image_benchmark.sh b/benchmarks/run_text_to_image_benchmark.sh index 512aab42f..f646b0b63 100755 --- a/benchmarks/run_text_to_image_benchmark.sh +++ b/benchmarks/run_text_to_image_benchmark.sh @@ -24,7 +24,7 @@ while getopts 'm:w:c:o:h' opt; do o) OUTPUT_FILE=$OPTARG ;; - + ?|h) echo "Usage: $(basename $0) [-m model_dir] [-w warmups] [-c compiler] [-o output_file]" echo " -m model_dir: the directory of the models, if not set, use HF models" diff --git a/benchmarks/run_text_to_image_benchmark_trt.sh b/benchmarks/run_text_to_image_benchmark_trt.sh index cf6231d2a..1aafb48d5 100755 --- a/benchmarks/run_text_to_image_benchmark_trt.sh +++ b/benchmarks/run_text_to_image_benchmark_trt.sh @@ -34,7 +34,7 @@ while 
getopts 'm:w:p:o:d:v:p:h' opt; do v) TRT_VERSION=$OPTARG ;; - + ?|h) echo "Usage: $(basename $0) [-m model_dir] [-w warmups] [-p prompt] [-o output_file] [-d work_dir] [-v trt_version] [-h]" echo " -m model_dir: the directory of the models, if not set, use HF models" @@ -102,7 +102,7 @@ esac python3 -m pip install torch torchvision --index-url https://download.pytorch.org/whl/cu${TORCH_CUDA_TAG} cd $TRT_REPO_DIR/demo/Diffusion -python3 -m pip install -r requirements.txt +python3 -m pip install -r requirements.txt if [ ! -z "${MODEL_DIR}" ]; then echo "model_dir specified, use local models" diff --git a/benchmarks/run_text_to_image_sdxl_light_benchmark.sh b/benchmarks/run_text_to_image_sdxl_light_benchmark.sh index 2a4e4aa36..e46aa7237 100755 --- a/benchmarks/run_text_to_image_sdxl_light_benchmark.sh +++ b/benchmarks/run_text_to_image_sdxl_light_benchmark.sh @@ -24,7 +24,7 @@ while getopts 'm:w:c:o:h' opt; do o) OUTPUT_FILE=$OPTARG ;; - + ?|h) echo "Usage: $(basename $0) [-m model_dir] [-w warmups] [-c compiler] [-o output_file]" echo " -m model_dir: the directory of the models, if not set, use HF models" diff --git a/benchmarks/stable_cascade.py b/benchmarks/stable_cascade.py index 7043bbed4..18c8e8ad2 100644 --- a/benchmarks/stable_cascade.py +++ b/benchmarks/stable_cascade.py @@ -42,18 +42,19 @@ DECODER_CONTROL_IMAGE = None OUTPUT_IMAGE = None -import os +import argparse import importlib import inspect -import argparse -import time import json +import os +import time from contextlib import nullcontext + import torch -from PIL import Image, ImageDraw from diffusers.utils import load_image +from PIL import Image, ImageDraw -import oneflow as flow +import oneflow as flow # usort: skip from onediffx import compile_pipe @@ -131,7 +132,10 @@ def load_pipe( if controlnet is not None: from diffusers import ControlNetModel - controlnet = ControlNetModel.from_pretrained(controlnet, torch_dtype=dtype,) + controlnet = ControlNetModel.from_pretrained( + controlnet, + torch_dtype=dtype, + ) extra_kwargs["controlnet"] = controlnet if os.path.exists(os.path.join(model_name, "calibrate_info.txt")): from onediff.quantization import QuantPipeline diff --git a/benchmarks/text_to_image.py b/benchmarks/text_to_image.py index adf52144c..85ec6bb43 100644 --- a/benchmarks/text_to_image.py +++ b/benchmarks/text_to_image.py @@ -23,21 +23,25 @@ COMPILER_CONFIG = None QUANTIZE_CONFIG = None -import os +import argparse import importlib import inspect -import argparse -import time import json -import torch +import os +import time + import matplotlib.pyplot as plt import numpy as np -from PIL import Image, ImageDraw +import torch from diffusers.utils import load_image - -from onediffx import compile_pipe, quantize_pipe # quantize_pipe currently only supports the nexfort backend. from onediff.infer_compiler import oneflow_compile +from onediffx import ( # quantize_pipe currently only supports the nexfort backend. 
+ compile_pipe, + quantize_pipe, +) +from PIL import Image, ImageDraw + def parse_args(): parser = argparse.ArgumentParser() @@ -377,7 +381,7 @@ def get_kwarg_inputs(): if iter_per_sec is not None: print(f"Iterations per second: {iter_per_sec:.3f}") if args.compiler == "oneflow": - import oneflow as flow + import oneflow as flow # usort: skip cuda_mem_after_used = flow._oneflow_internal.GetCUDAMemoryUsed() / 1024 else: @@ -387,6 +391,7 @@ def get_kwarg_inputs(): if args.print_output: from onediff.utils.import_utils import is_nexfort_available + if is_nexfort_available(): from nexfort.utils.term_image import print_image diff --git a/benchmarks/text_to_image_sdxl_light.py b/benchmarks/text_to_image_sdxl_light.py index 819b529ca..da73341d9 100644 --- a/benchmarks/text_to_image_sdxl_light.py +++ b/benchmarks/text_to_image_sdxl_light.py @@ -14,32 +14,28 @@ OUTPUT_IMAGE = None EXTRA_CALL_KWARGS = None -import os +import argparse import importlib import inspect -import argparse -import time import json +import os +import time + import torch -from PIL import Image, ImageDraw from diffusers.utils import load_image +from PIL import Image, ImageDraw -import oneflow as flow -from onediffx import compile_pipe - +import oneflow as flow # usort: skip from huggingface_hub import hf_hub_download +from onediffx import compile_pipe from safetensors.torch import load_file def parse_args(): parser = argparse.ArgumentParser() parser.add_argument("--model", type=str, default=MODEL) - parser.add_argument( - "--repo", type=str, default=REPO - ) - parser.add_argument( - "--cpkt", type=str, default=CPKT - ) + parser.add_argument("--repo", type=str, default=REPO) + parser.add_argument("--cpkt", type=str, default=CPKT) parser.add_argument("--variant", type=str, default=VARIANT) parser.add_argument("--custom-pipeline", type=str, default=CUSTOM_PIPELINE) parser.add_argument("--controlnet", type=str, default=CONTROLNET) @@ -70,8 +66,7 @@ def load_and_compile_pipe( custom_pipeline=None, controlnet=None, ): - from diffusers import StableDiffusionXLPipeline - from diffusers import EulerDiscreteScheduler + from diffusers import EulerDiscreteScheduler, StableDiffusionXLPipeline extra_kwargs = {} if custom_pipeline is not None: @@ -82,11 +77,13 @@ def load_and_compile_pipe( from diffusers import ControlNetModel controlnet = ControlNetModel.from_pretrained( - controlnet, torch_dtype=torch.float16, + controlnet, + torch_dtype=torch.float16, ) extra_kwargs["controlnet"] = controlnet if os.path.exists(os.path.join(model_name, "calibrate_info.txt")): from onediff.quantization import QuantPipeline + raise TypeError("Quantizatble SDXL-LIGHT is not supported!") # pipe = QuantPipeline.from_quantized( # pipeline_cls, model_name, torch_dtype=torch.float16, **extra_kwargs @@ -99,9 +96,7 @@ def load_and_compile_pipe( if is_lora_cpkt: pipe = StableDiffusionXLPipeline.from_pretrained( - model_name, - torch_dtype=torch.float16, - **extra_kwargs + model_name, torch_dtype=torch.float16, **extra_kwargs ).to("cuda") if os.path.isfile(os.path.join(repo_name, cpkt_name)): pipe.load_lora_weights(os.path.join(repo_name, cpkt_name)) @@ -110,22 +105,26 @@ def load_and_compile_pipe( pipe.fuse_lora() else: from diffusers import UNet2DConditionModel - unet = UNet2DConditionModel.from_config(model_name, subfolder="unet").to("cuda", torch.float16) + + unet = UNet2DConditionModel.from_config(model_name, subfolder="unet").to( + "cuda", torch.float16 + ) if os.path.isfile(os.path.join(repo_name, cpkt_name)): - 
unet.load_state_dict(load_file(os.path.join(repo_name, cpkt_name), device="cuda")) + unet.load_state_dict( + load_file(os.path.join(repo_name, cpkt_name), device="cuda") + ) else: from huggingface_hub import hf_hub_download - unet.load_state_dict(load_file(hf_hub_download(repo_name, cpkt_name), device="cuda")) + + unet.load_state_dict( + load_file(hf_hub_download(repo_name, cpkt_name), device="cuda") + ) pipe = StableDiffusionXLPipeline.from_pretrained( - model_name, - unet=unet, - torch_dtype=torch.float16, - **extra_kwargs + model_name, unet=unet, torch_dtype=torch.float16, **extra_kwargs ).to("cuda") pipe.scheduler = EulerDiscreteScheduler.from_config( - pipe.scheduler.config, - timestep_spacing="trailing" + pipe.scheduler.config, timestep_spacing="trailing" ) pipe.safety_checker = None pipe.to(torch.device("cuda")) @@ -192,8 +191,6 @@ def main(): n_steps = int(args.cpkt[len("sdxl_lightning_") : len("sdxl_lightning_") + 1]) - - def get_kwarg_inputs(): kwarg_inputs = dict( prompt=args.prompt, diff --git a/benchmarks/text_to_video_latte.py b/benchmarks/text_to_video_latte.py index 6f7ea6328..fdf538d58 100644 --- a/benchmarks/text_to_video_latte.py +++ b/benchmarks/text_to_video_latte.py @@ -28,21 +28,22 @@ COMPILER_CONFIG = None -import os +import argparse import importlib import inspect -import argparse -import time import json +import os import random -from PIL import Image, ImageDraw +import time + +import imageio import torch -from onediffx import compile_pipe -from diffusers.schedulers import DDIMScheduler from diffusers.models import AutoencoderKL, AutoencoderKLTemporalDecoder +from diffusers.schedulers import DDIMScheduler +from onediffx import compile_pipe +from PIL import Image, ImageDraw from transformers import T5EncoderModel, T5Tokenizer -import imageio def parse_args(): @@ -87,7 +88,9 @@ def parse_args(): choices=["none", "nexfort", "compile"], ) parser.add_argument( - "--compiler-config", type=str, default=COMPILER_CONFIG, + "--compiler-config", + type=str, + default=COMPILER_CONFIG, ) parser.add_argument( "--attention-fp16-score-accum-max-m", @@ -137,7 +140,7 @@ def main(): device = "cuda" if torch.cuda.is_available() else "cpu" from models.latte_t2v import LatteT2V - from sample.pipeline_latte import LattePipeline + from sample.pipeline_latte import LattePipeline transformer_model = LatteT2V.from_pretrained( model_path, subfolder="transformer", video_length=args.video_length @@ -267,6 +270,7 @@ def get_kwarg_inputs(): ) # highest quality is 10, lowest is 0 except: print("Error when saving {}".format(args.prompt)) + else: print("Please set `--output-video` to save the output video") diff --git a/imgs/plot.py b/imgs/plot.py index d15a86669..99c414319 100644 --- a/imgs/plot.py +++ b/imgs/plot.py @@ -1,7 +1,8 @@ +import argparse + import matplotlib as mpl import matplotlib.pyplot as plt import yaml -import argparse def parse_args(): diff --git a/onediff_comfy_nodes/README.md b/onediff_comfy_nodes/README.md index 719f16278..ec010e703 100644 --- a/onediff_comfy_nodes/README.md +++ b/onediff_comfy_nodes/README.md @@ -6,13 +6,13 @@ --- -Performance of Community Edition +Performance of Community Edition Updated on January 23, 2024. Device: RTX 3090
- + **SDXL End2End Time** , Image Size 1024x1024 , Batch Size 1 , steps 20 @@ -51,7 +51,7 @@ Updated on January 23, 2024. Device: RTX 3090 ### Installation Guide This guide provides two methods to install ComfyUI and integrate it with the OneDiff library: via the Comfy CLI or directly from GitHub. -
+
Option 1: Installing via Comfy CLI 1. **Install Comfy CLI**: @@ -98,7 +98,7 @@ First, install and set up [ComfyUI](https://github.com/comfyanonymous/ComfyUI), ``` 4. **Install a Compiler Backend** - + For instructions on installing a compiler backend for OneDiff, please refer to the OneDiff GitHub repository [here](https://github.com/siliconflow/onediff?tab=readme-ov-file#install-a-compiler-backend). @@ -125,7 +125,7 @@ The "Load Checkpoint - OneDiff" node set `vae_speedup` : `enable` to enable VA -### Compiler Cache +### Compiler Cache **Avoid compilation time for online serving** ```shell diff --git a/onediff_comfy_nodes/__init__.py b/onediff_comfy_nodes/__init__.py index 702c7bf47..d32e96ffe 100644 --- a/onediff_comfy_nodes/__init__.py +++ b/onediff_comfy_nodes/__init__.py @@ -1,5 +1,6 @@ """OneDiff ComfyUI Speedup Module""" from onediff.utils.import_utils import is_nexfort_available, is_oneflow_available + from ._config import is_disable_oneflow_backend from ._nodes import ( ControlnetSpeedup, @@ -51,7 +52,9 @@ def lazy_load_extra_nodes(): update_node_mappings(nodes_nexfort_booster) from .extras_nodes import nodes_prompt_styler + update_node_mappings(nodes_prompt_styler) + # Lazy load all extra nodes when needed lazy_load_extra_nodes() diff --git a/onediff_comfy_nodes/_config.py b/onediff_comfy_nodes/_config.py index 35126c879..3ab5027ca 100644 --- a/onediff_comfy_nodes/_config.py +++ b/onediff_comfy_nodes/_config.py @@ -1,7 +1,8 @@ import os import sys -import torch + import folder_paths +import torch __all__ = [ "is_default_using_oneflow_backend", diff --git a/onediff_comfy_nodes/_nodes.py b/onediff_comfy_nodes/_nodes.py index dd1a86e57..74a864477 100644 --- a/onediff_comfy_nodes/_nodes.py +++ b/onediff_comfy_nodes/_nodes.py @@ -1,12 +1,16 @@ +import uuid from typing import Optional, Tuple + import folder_paths import torch -import uuid from nodes import CheckpointLoaderSimple, ControlNetLoader +from onediff.utils.import_utils import ( # type: ignore + is_nexfort_available, + is_oneflow_available, +) + from ._config import is_disable_oneflow_backend -from .modules import BoosterScheduler, BoosterExecutor, BoosterSettings -from onediff.utils.import_utils import is_nexfort_available # type: ignore -from onediff.utils.import_utils import is_oneflow_available +from .modules import BoosterExecutor, BoosterScheduler, BoosterSettings if is_oneflow_available() and not is_disable_oneflow_backend(): from .modules.oneflow import BasicOneFlowBoosterExecutor @@ -80,7 +84,9 @@ class ModelSpeedup(SpeedupMixin): @classmethod def INPUT_TYPES(s): return { - "required": {"model": ("MODEL",),}, + "required": { + "model": ("MODEL",), + }, "optional": { "custom_booster": ("CUSTOM_BOOSTER",), "inplace": ( @@ -117,7 +123,9 @@ class ControlnetSpeedup(SpeedupMixin): @classmethod def INPUT_TYPES(s): return { - "required": {"control_net": ("CONTROL_NET",),}, + "required": { + "control_net": ("CONTROL_NET",), + }, "optional": { "inplace": ( "BOOLEAN", @@ -183,7 +191,13 @@ class OneDiffControlNetLoader(ControlNetLoader): @classmethod def INPUT_TYPES(s): ret = super().INPUT_TYPES() - ret.update({"optional": {"custom_booster": ("CUSTOM_BOOSTER",),}}) + ret.update( + { + "optional": { + "custom_booster": ("CUSTOM_BOOSTER",), + } + } + ) return ret CATEGORY = "OneDiff/Loaders" @@ -207,7 +221,9 @@ def INPUT_TYPES(s): "ckpt_name": (folder_paths.get_filename_list("checkpoints"),), "vae_speedup": (["disable", "enable"],), }, - "optional": {"custom_booster": ("CUSTOM_BOOSTER",),}, + "optional": { + "custom_booster": 
("CUSTOM_BOOSTER",), + }, } CATEGORY = "OneDiff/Loaders" @@ -220,7 +236,10 @@ def __init__(self) -> None: @torch.inference_mode() def onediff_load_checkpoint( - self, ckpt_name, vae_speedup="disable", custom_booster: BoosterScheduler = None, + self, + ckpt_name, + vae_speedup="disable", + custom_booster: BoosterScheduler = None, ): modelpatcher, clip, vae = self.load_checkpoint(ckpt_name) modelpatcher = self.speedup( diff --git a/onediff_comfy_nodes/benchmarks/README.md b/onediff_comfy_nodes/benchmarks/README.md index d3435167e..d6ac6a03b 100644 --- a/onediff_comfy_nodes/benchmarks/README.md +++ b/onediff_comfy_nodes/benchmarks/README.md @@ -28,7 +28,7 @@ bash scripts/install_env.sh $COMFYUI_ROOT cd $COMFYUI_ROOT -python main.py --gpu-only --port 8188 --extra-model-paths-config path/to/onediff/tests/comfyui/extra_model_paths.yaml +python main.py --gpu-only --port 8188 --extra-model-paths-config path/to/onediff/tests/comfyui/extra_model_paths.yaml ``` ## Usage Example diff --git a/onediff_comfy_nodes/benchmarks/resources/prompts.txt b/onediff_comfy_nodes/benchmarks/resources/prompts.txt index 9ac5b1c5d..1a81e76d4 100644 --- a/onediff_comfy_nodes/benchmarks/resources/prompts.txt +++ b/onediff_comfy_nodes/benchmarks/resources/prompts.txt @@ -77,4 +77,4 @@ line art, line style, 1girl, solo, japanese clothes, hand fan, kimono, black hai A boy with a sword in his hand takes a fighting stance on the high street with a sword slashed at the camera, close-up, pov view, first person, blue whirlwind flame, glow, surrealism, ultra-futurism, cyberpunk, 3D art, rich detail, best quality, centered Anime girl, masterpiece, best quality, hatsune miku, white gown, angel, angel wings, golden halo, dark background, upper body, closed mouth, looking at viewer, arms behind back, blue theme, night, highres, 4k, 8k, intricate detail, cinematic lighting, amazing quality, amazing shading, soft lighting, detailed Illustration, anime style, wallpaper. 
Super close-up shot of a weathered male skull almost buried by sand,side view,a fresh plant with two green leaves growing from the skull,detailed texture,shot from the botton,epic,super photorealism,cinematic,scenery,sunset,wasteland,desert,dune wave,super-detailed,highly realistic,8k,artistic,contrast lighting,vibrant color,hdr,erode -a (side view close-up half-body:1.85) fashion photoshoot photo of darth vader wearing a pink and white diamond studded outfit, his chest has a (very big CRT screen showing a pacman game:1.7), his helmet is made of a hello kitty themed white plastic, his helmet has sticker decals on it \ No newline at end of file +a (side view close-up half-body:1.85) fashion photoshoot photo of darth vader wearing a pink and white diamond studded outfit, his chest has a (very big CRT screen showing a pacman game:1.7), his helmet is made of a hello kitty themed white plastic, his helmet has sticker decals on it diff --git a/onediff_comfy_nodes/benchmarks/resources/workflows/baseline/ComfyUI_IPAdapter_plus/ipadapter_advanced.json b/onediff_comfy_nodes/benchmarks/resources/workflows/baseline/ComfyUI_IPAdapter_plus/ipadapter_advanced.json index 5aa32dbeb..32b5d9f2a 100644 --- a/onediff_comfy_nodes/benchmarks/resources/workflows/baseline/ComfyUI_IPAdapter_plus/ipadapter_advanced.json +++ b/onediff_comfy_nodes/benchmarks/resources/workflows/baseline/ComfyUI_IPAdapter_plus/ipadapter_advanced.json @@ -177,4 +177,4 @@ "title": "Prep Image For ClipVision" } } -} \ No newline at end of file +} diff --git a/onediff_comfy_nodes/benchmarks/resources/workflows/baseline/ComfyUI_InstantID/instantid_posed.json b/onediff_comfy_nodes/benchmarks/resources/workflows/baseline/ComfyUI_InstantID/instantid_posed.json index 4e6d387a6..04872dad4 100644 --- a/onediff_comfy_nodes/benchmarks/resources/workflows/baseline/ComfyUI_InstantID/instantid_posed.json +++ b/onediff_comfy_nodes/benchmarks/resources/workflows/baseline/ComfyUI_InstantID/instantid_posed.json @@ -198,4 +198,4 @@ "title": "Load Checkpoint" } } -} \ No newline at end of file +} diff --git a/onediff_comfy_nodes/benchmarks/resources/workflows/baseline/PuLID_ComfyUI/PuLID_4-Step_lightning.json b/onediff_comfy_nodes/benchmarks/resources/workflows/baseline/PuLID_ComfyUI/PuLID_4-Step_lightning.json index 6cc6d093e..db0326d49 100644 --- a/onediff_comfy_nodes/benchmarks/resources/workflows/baseline/PuLID_ComfyUI/PuLID_4-Step_lightning.json +++ b/onediff_comfy_nodes/benchmarks/resources/workflows/baseline/PuLID_ComfyUI/PuLID_4-Step_lightning.json @@ -184,4 +184,4 @@ "title": "LoraLoaderModelOnly" } } -} \ No newline at end of file +} diff --git a/onediff_comfy_nodes/benchmarks/resources/workflows/baseline/PuLID_ComfyUI/PuLID_IPAdapter_style_transfer.json b/onediff_comfy_nodes/benchmarks/resources/workflows/baseline/PuLID_ComfyUI/PuLID_IPAdapter_style_transfer.json index 14c788c49..63af48303 100644 --- a/onediff_comfy_nodes/benchmarks/resources/workflows/baseline/PuLID_ComfyUI/PuLID_IPAdapter_style_transfer.json +++ b/onediff_comfy_nodes/benchmarks/resources/workflows/baseline/PuLID_ComfyUI/PuLID_IPAdapter_style_transfer.json @@ -219,4 +219,4 @@ "title": "Load Checkpoint" } } -} \ No newline at end of file +} diff --git a/onediff_comfy_nodes/benchmarks/resources/workflows/baseline/lora.json b/onediff_comfy_nodes/benchmarks/resources/workflows/baseline/lora.json index 16afa17cc..3cd2c3bad 100644 --- a/onediff_comfy_nodes/benchmarks/resources/workflows/baseline/lora.json +++ b/onediff_comfy_nodes/benchmarks/resources/workflows/baseline/lora.json @@ 
-123,4 +123,4 @@ "title": "Load LoRA" } } -} \ No newline at end of file +} diff --git a/onediff_comfy_nodes/benchmarks/resources/workflows/baseline/lora_multiple.json b/onediff_comfy_nodes/benchmarks/resources/workflows/baseline/lora_multiple.json index a42d187fe..0a9ef55bd 100644 --- a/onediff_comfy_nodes/benchmarks/resources/workflows/baseline/lora_multiple.json +++ b/onediff_comfy_nodes/benchmarks/resources/workflows/baseline/lora_multiple.json @@ -142,4 +142,4 @@ "title": "Load LoRA" } } -} \ No newline at end of file +} diff --git a/onediff_comfy_nodes/benchmarks/resources/workflows/baseline/sd3_basic.json b/onediff_comfy_nodes/benchmarks/resources/workflows/baseline/sd3_basic.json index 196f9bb1e..2e9648454 100644 --- a/onediff_comfy_nodes/benchmarks/resources/workflows/baseline/sd3_basic.json +++ b/onediff_comfy_nodes/benchmarks/resources/workflows/baseline/sd3_basic.json @@ -170,4 +170,4 @@ "title": "VAE Decode" } } -} \ No newline at end of file +} diff --git a/onediff_comfy_nodes/benchmarks/resources/workflows/baseline/txt2img.json b/onediff_comfy_nodes/benchmarks/resources/workflows/baseline/txt2img.json index 61f402d28..d915d1ef3 100644 --- a/onediff_comfy_nodes/benchmarks/resources/workflows/baseline/txt2img.json +++ b/onediff_comfy_nodes/benchmarks/resources/workflows/baseline/txt2img.json @@ -104,4 +104,4 @@ "title": "Load Checkpoint" } } -} \ No newline at end of file +} diff --git a/onediff_comfy_nodes/benchmarks/resources/workflows/example_workflow_api.json b/onediff_comfy_nodes/benchmarks/resources/workflows/example_workflow_api.json index a487d5cb1..0c9f919d2 100644 --- a/onediff_comfy_nodes/benchmarks/resources/workflows/example_workflow_api.json +++ b/onediff_comfy_nodes/benchmarks/resources/workflows/example_workflow_api.json @@ -83,4 +83,4 @@ ] } } -} \ No newline at end of file +} diff --git a/onediff_comfy_nodes/benchmarks/resources/workflows/nexfort/sd3_basic.json b/onediff_comfy_nodes/benchmarks/resources/workflows/nexfort/sd3_basic.json index 69a1a2873..1ce33c275 100644 --- a/onediff_comfy_nodes/benchmarks/resources/workflows/nexfort/sd3_basic.json +++ b/onediff_comfy_nodes/benchmarks/resources/workflows/nexfort/sd3_basic.json @@ -228,4 +228,4 @@ "title": "VAE Speedup" } } - } \ No newline at end of file + } diff --git a/onediff_comfy_nodes/benchmarks/resources/workflows/oneflow/ComfyUI_IPAdapter_plus/ipadapter_advanced.json b/onediff_comfy_nodes/benchmarks/resources/workflows/oneflow/ComfyUI_IPAdapter_plus/ipadapter_advanced.json index e4bf7fe71..fef5d9f4b 100644 --- a/onediff_comfy_nodes/benchmarks/resources/workflows/oneflow/ComfyUI_IPAdapter_plus/ipadapter_advanced.json +++ b/onediff_comfy_nodes/benchmarks/resources/workflows/oneflow/ComfyUI_IPAdapter_plus/ipadapter_advanced.json @@ -178,4 +178,4 @@ "title": "Load Checkpoint - OneDiff" } } -} \ No newline at end of file +} diff --git a/onediff_comfy_nodes/benchmarks/resources/workflows/oneflow/ComfyUI_InstantID/instantid_posed_speedup.json b/onediff_comfy_nodes/benchmarks/resources/workflows/oneflow/ComfyUI_InstantID/instantid_posed_speedup.json index f04381895..06f9cf9ca 100644 --- a/onediff_comfy_nodes/benchmarks/resources/workflows/oneflow/ComfyUI_InstantID/instantid_posed_speedup.json +++ b/onediff_comfy_nodes/benchmarks/resources/workflows/oneflow/ComfyUI_InstantID/instantid_posed_speedup.json @@ -199,4 +199,4 @@ "title": "Load Checkpoint - OneDiff" } } -} \ No newline at end of file +} diff --git a/onediff_comfy_nodes/benchmarks/resources/workflows/oneflow/PuLID_ComfyUI/PuLID_4-Step_lightning.json 
b/onediff_comfy_nodes/benchmarks/resources/workflows/oneflow/PuLID_ComfyUI/PuLID_4-Step_lightning.json index 3f501996a..28a9863a6 100644 --- a/onediff_comfy_nodes/benchmarks/resources/workflows/oneflow/PuLID_ComfyUI/PuLID_4-Step_lightning.json +++ b/onediff_comfy_nodes/benchmarks/resources/workflows/oneflow/PuLID_ComfyUI/PuLID_4-Step_lightning.json @@ -185,4 +185,4 @@ "title": "Load Checkpoint - OneDiff" } } -} \ No newline at end of file +} diff --git a/onediff_comfy_nodes/benchmarks/resources/workflows/oneflow/PuLID_ComfyUI/PuLID_IPAdapter_style_transfer.json b/onediff_comfy_nodes/benchmarks/resources/workflows/oneflow/PuLID_ComfyUI/PuLID_IPAdapter_style_transfer.json index 5ac93f6a6..fbeb81257 100644 --- a/onediff_comfy_nodes/benchmarks/resources/workflows/oneflow/PuLID_ComfyUI/PuLID_IPAdapter_style_transfer.json +++ b/onediff_comfy_nodes/benchmarks/resources/workflows/oneflow/PuLID_ComfyUI/PuLID_IPAdapter_style_transfer.json @@ -220,4 +220,4 @@ "title": "Load Checkpoint - OneDiff" } } -} \ No newline at end of file +} diff --git a/onediff_comfy_nodes/benchmarks/resources/workflows/oneflow/lora_multiple_speedup.json b/onediff_comfy_nodes/benchmarks/resources/workflows/oneflow/lora_multiple_speedup.json index 77002cb59..5d63c35a5 100644 --- a/onediff_comfy_nodes/benchmarks/resources/workflows/oneflow/lora_multiple_speedup.json +++ b/onediff_comfy_nodes/benchmarks/resources/workflows/oneflow/lora_multiple_speedup.json @@ -143,4 +143,4 @@ "title": "Load Checkpoint - OneDiff" } } -} \ No newline at end of file +} diff --git a/onediff_comfy_nodes/benchmarks/resources/workflows/oneflow/lora_speedup.json b/onediff_comfy_nodes/benchmarks/resources/workflows/oneflow/lora_speedup.json index 73f635ed3..7b0287008 100644 --- a/onediff_comfy_nodes/benchmarks/resources/workflows/oneflow/lora_speedup.json +++ b/onediff_comfy_nodes/benchmarks/resources/workflows/oneflow/lora_speedup.json @@ -124,4 +124,4 @@ "title": "Load LoRA" } } -} \ No newline at end of file +} diff --git a/onediff_comfy_nodes/benchmarks/resources/workflows/oneflow/sdxl-control-lora-speedup.json b/onediff_comfy_nodes/benchmarks/resources/workflows/oneflow/sdxl-control-lora-speedup.json index 5ca999356..874543e83 100644 --- a/onediff_comfy_nodes/benchmarks/resources/workflows/oneflow/sdxl-control-lora-speedup.json +++ b/onediff_comfy_nodes/benchmarks/resources/workflows/oneflow/sdxl-control-lora-speedup.json @@ -150,4 +150,4 @@ "title": "Load ControlNet Model - OneDiff" } } -} \ No newline at end of file +} diff --git a/onediff_comfy_nodes/benchmarks/resources/workflows/oneflow/txt2img.json b/onediff_comfy_nodes/benchmarks/resources/workflows/oneflow/txt2img.json index bd7ce14a4..0469c9dd4 100644 --- a/onediff_comfy_nodes/benchmarks/resources/workflows/oneflow/txt2img.json +++ b/onediff_comfy_nodes/benchmarks/resources/workflows/oneflow/txt2img.json @@ -105,4 +105,4 @@ "title": "Load Checkpoint - OneDiff" } } -} \ No newline at end of file +} diff --git a/onediff_comfy_nodes/benchmarks/scripts/run_nexfort_case_ci.sh b/onediff_comfy_nodes/benchmarks/scripts/run_nexfort_case_ci.sh index 8b1378917..e69de29bb 100644 --- a/onediff_comfy_nodes/benchmarks/scripts/run_nexfort_case_ci.sh +++ b/onediff_comfy_nodes/benchmarks/scripts/run_nexfort_case_ci.sh @@ -1 +0,0 @@ - diff --git a/onediff_comfy_nodes/benchmarks/scripts/run_nexfort_case_local.sh b/onediff_comfy_nodes/benchmarks/scripts/run_nexfort_case_local.sh index 0cd1af8b4..4ce5e7740 100644 --- a/onediff_comfy_nodes/benchmarks/scripts/run_nexfort_case_local.sh +++ 
b/onediff_comfy_nodes/benchmarks/scripts/run_nexfort_case_local.sh @@ -11,7 +11,7 @@ python3 scripts/text_to_image.py \ -w $WORKFLOW_BASIC/sd3_basic.json \ --output-dir results \ --exp-name sd3_basic_baseline \ - --output-images + --output-images # Run the SD3 nexfort workflow python3 scripts/text_to_image.py \ @@ -30,4 +30,4 @@ python3 scripts/text_to_image.py \ --output-dir results \ --exp-name sd3_basic_nexfort_infer \ --output-images \ - --baseline-dir results/sd3_basic_baseline \ No newline at end of file + --baseline-dir results/sd3_basic_baseline diff --git a/onediff_comfy_nodes/benchmarks/scripts/text_to_image.py b/onediff_comfy_nodes/benchmarks/scripts/text_to_image.py index 7c8796e01..35ef2a582 100644 --- a/onediff_comfy_nodes/benchmarks/scripts/text_to_image.py +++ b/onediff_comfy_nodes/benchmarks/scripts/text_to_image.py @@ -80,7 +80,10 @@ class ImageInfo(NamedTuple): class WorkflowProcessor: def __init__( - self, output_images: bool, output_dir: str, logger, + self, + output_images: bool, + output_dir: str, + logger, ): self.output_images = output_images self.output_dir = output_dir @@ -116,7 +119,9 @@ def run_workflow( logger.info(f"Result directory: {result_dir}") processor = WorkflowProcessor( - output_images, os.path.join(result_dir, "imgs"), logger, + output_images, + os.path.join(result_dir, "imgs"), + logger, ) result = {} @@ -172,7 +177,7 @@ def run_workflow( "basic_image_path": baseline_image_path, } ) - logger.info(f'SSIM: {ssim_value=}') + logger.info(f"SSIM: {ssim_value=}") assert ( ssim_value > ssim_threshold ), f"SSIM value {ssim_value} is not greater than the threshold {ssim_threshold}" diff --git a/onediff_comfy_nodes/benchmarks/src/core/service_client.py b/onediff_comfy_nodes/benchmarks/src/core/service_client.py index 494934805..d8e54ec5c 100644 --- a/onediff_comfy_nodes/benchmarks/src/core/service_client.py +++ b/onediff_comfy_nodes/benchmarks/src/core/service_client.py @@ -15,7 +15,9 @@ class ComfyGraph: def __init__( - self, graph: dict, sampler_nodes: list[str], + self, + graph: dict, + sampler_nodes: list[str], ): self.graph = graph self.sampler_nodes = sampler_nodes @@ -30,7 +32,8 @@ def set_prompt(self, prompt, negative_prompt=None): self.graph[negative_prompt_node]["inputs"]["text"] = negative_prompt def set_sampler_name( - self, sampler_name: str, + self, + sampler_name: str, ): # sets the sampler name for the sampler nodes (eg. 
base and refiner) for node in self.sampler_nodes: diff --git a/onediff_comfy_nodes/benchmarks/src/input_registration.py b/onediff_comfy_nodes/benchmarks/src/input_registration.py index d97307532..425b18c30 100644 --- a/onediff_comfy_nodes/benchmarks/src/input_registration.py +++ b/onediff_comfy_nodes/benchmarks/src/input_registration.py @@ -1,6 +1,7 @@ import json import os from typing import NamedTuple + from core.registry import create_generator_registry from core.service_client import ComfyGraph @@ -21,7 +22,6 @@ ] - class InputParams(NamedTuple): graph: ComfyGraph @@ -81,6 +81,8 @@ def _(workflow_path, *args, **kwargs): f"{WORKFLOW_DIR}/baseline/sd3_basic.json", f"{WORKFLOW_DIR}/nexfort/sd3_basic.json", ] + + @register_generator(SD3_WORKFLOWS) def _(workflow_path, *args, **kwargs): with open(workflow_path, "r") as fp: diff --git a/onediff_comfy_nodes/docs/ComfyUI_Online_Quantization.md b/onediff_comfy_nodes/docs/ComfyUI_Online_Quantization.md index 560d894e3..028dc02fa 100644 --- a/onediff_comfy_nodes/docs/ComfyUI_Online_Quantization.md +++ b/onediff_comfy_nodes/docs/ComfyUI_Online_Quantization.md @@ -134,7 +134,7 @@ Note that you can download all images in this page and then drag or load them on Model parameters can be referred to [Parameter Description](#parameter-description). -#### Download the required model files +#### Download the required model files diff --git a/onediff_comfy_nodes/docs/ControlNet/README.md b/onediff_comfy_nodes/docs/ControlNet/README.md index 38ef0e7e3..6582f03c3 100644 --- a/onediff_comfy_nodes/docs/ControlNet/README.md +++ b/onediff_comfy_nodes/docs/ControlNet/README.md @@ -7,7 +7,7 @@
- + **End2End Time** , Image Size 512x512 , Batch Size 4 , steps 20 @@ -31,14 +31,14 @@ Replace `"Load ControlNet Model"` with `"Load ControlNet Model - OneDiff"` in co ![ControlNet](./controlnet_onediff.png) #### Quantization ![ControlNet](./controlnet_onediff_quant.png) -#### Mixing ControlNet +#### Mixing ControlNet ![ControlNet](./mixing_controlnets.png) ## FAQ -- Q: RuntimeError: After graph built, the device of graph can't be modified, current device: cuda:0, target device: cpu +- Q: RuntimeError: After graph built, the device of graph can't be modified, current device: cuda:0, target device: cpu - Please use `--gpu-only` when launching comfyui, for example, `python main.py --gpu-only`. @@ -46,7 +46,7 @@ Replace `"Load ControlNet Model"` with `"Load ControlNet Model - OneDiff"` in co - To initiate a fresh run, delete the files within the `ComfyUI/input/graphs/` directory and then proceed with rerunning the process. - **Switching the strength parameter between 0 and > 0 in the "Apply ControlNet" node is not supported.** A strength of 0 implies not using ControlNet, while a strength greater than 0 activates ControlNet. This may lead to changes in the graph structure, resulting in errors. -- Q: Acceleration of ControlNet: Not very apparent +- Q: Acceleration of ControlNet: Not very apparent - ControlNet is a very small model, In prior tests, the iteration ratio between UNet and ControlNet was 2:1. - UNet compilation contributed to a 30% acceleration, while ControlNet contributed 15%, resulting in an overall acceleration of approximately 45%. - In the enterprise edition, UNet exhibited a more substantial acceleration, making the acceleration from ControlNet relatively smaller. diff --git a/onediff_comfy_nodes/docs/OnlineQuantization.md b/onediff_comfy_nodes/docs/OnlineQuantization.md index 5e237bfb6..38fc27b23 100644 --- a/onediff_comfy_nodes/docs/OnlineQuantization.md +++ b/onediff_comfy_nodes/docs/OnlineQuantization.md @@ -21,7 +21,7 @@ Notes: ## Performance Comparison ### [SDXL](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) -Updated on Mon 08 Apr 2024 +Updated on Mon 08 Apr 2024 ![quant_sdxl](https://github.com/siliconflow/onediff/assets/109639975/b8f8da75-944b-4553-aea3-69c19886af37) @@ -48,4 +48,4 @@ Updated on Mon 08 Apr 2024 - oneflow `python -m oneflow --doctor`: {git_commit: 4ed3138, version: 0.9.1.dev20240402+cu122, enterprise: True} - ComfyUI Tue Apr 9 commit: 4201181b35402e0a992b861f8d2f0e0b267f52fa - Start comfyui command: `python main.py --gpu-only` -- Python 3.10.13 \ No newline at end of file +- Python 3.10.13 diff --git a/onediff_comfy_nodes/docs/SVD/README.md b/onediff_comfy_nodes/docs/SVD/README.md index 659edf9c6..09e049c0b 100644 --- a/onediff_comfy_nodes/docs/SVD/README.md +++ b/onediff_comfy_nodes/docs/SVD/README.md @@ -28,7 +28,7 @@ Only supported on the Linux platform. Test(Updated on March 1st, 2023) - Python 3.10.13 - NVIDIA A800-SXM4-80GB -- ComfyUI Thu Feb 29 commit: cb7c3a2921cfc0805be0229b4634e1143d60e6fe +- ComfyUI Thu Feb 29 commit: cb7c3a2921cfc0805be0229b4634e1143d60e6fe ## Contact diff --git a/onediff_comfy_nodes/docs/lora.md b/onediff_comfy_nodes/docs/lora.md index dffe0d603..9c2fb1b90 100644 --- a/onediff_comfy_nodes/docs/lora.md +++ b/onediff_comfy_nodes/docs/lora.md @@ -15,4 +15,4 @@ Notes - Ensure all necessary dependencies are installed as per the guide. - Use the --gpu-only parameter to ensure the program runs on the GPU. 
-[wokflow file](../benchmarks/resources/workflows/oneflow/lora_multiple_speedup.json) \ No newline at end of file +[workflow file](../benchmarks/resources/workflows/oneflow/lora_multiple_speedup.json) diff --git a/onediff_comfy_nodes/docs/sd3/README.md b/onediff_comfy_nodes/docs/sd3/README.md index 9fafc4a8d..2d155b566 100644 --- a/onediff_comfy_nodes/docs/sd3/README.md +++ b/onediff_comfy_nodes/docs/sd3/README.md @@ -1,16 +1,16 @@ ## Accelerate SD3 by using onediff -huggingface: https://huggingface.co/stabilityai/stable-diffusion-3-medium +huggingface: https://huggingface.co/stabilityai/stable-diffusion-3-medium ## Environment setup ### Set up requirements ```shell -# python 3.10 +# python 3.10 COMFYUI_DIR=$pwd/ComfyUI # install ComfyUI git clone https://github.com/comfyanonymous/ComfyUI.git # install onediff & onediff_comfy_nodes -git clone https://github.com/siliconflow/onediff.git +git clone https://github.com/siliconflow/onediff.git cd onediff && pip install -r onediff_comfy_nodes/sd3/requirements.txt && pip install -e . ln -s $pwd/onediff/onediff_comfy_nodes $COMFYUI_DIR/custom_nodes ``` @@ -41,7 +41,7 @@ with torch.inference_mode(): options={"mode": "max-autotune:cudagraphs", "dynamic": True, "fullgraph": True}, ) print(compiled_mod(torch.randn(10, 100, device="cuda").half()).shape) - + print("Successfully installed~") ``` @@ -55,11 +55,11 @@ print("Successfully installed~") ```shell export ACCESS_TOKEN="User Access Tokens" wget --header="Authorization: Bearer $ACCESS_TOKEN" \ -https://huggingface.co/stabilityai/stable-diffusion-3-medium/resolve/main/sd3_medium.safetensors -O models/checkpoints/sd3_medium.safetensors +https://huggingface.co/stabilityai/stable-diffusion-3-medium/resolve/main/sd3_medium.safetensors -O models/checkpoints/sd3_medium.safetensors wget --header="Authorization: Bearer $ACCESS_TOKEN" \ https://huggingface.co/stabilityai/stable-diffusion-3-medium/resolve/main/text_encoders/clip_g.safetensors -O models/clip/clip_g.safetensors - + wget --header="Authorization: Bearer $ACCESS_TOKEN" \ https://huggingface.co/stabilityai/stable-diffusion-3-medium/resolve/main/text_encoders/clip_l.safetensors -O models/clip/clip_l.safetensors @@ -92,8 +92,8 @@ Here is a very basic example how to use it: ## Performance Comparison -- Testing on NVIDIA GeForce RTX 4090, with image size of 1024*1024, iterating 28 steps. 
+- OneDiff[Nexfort] Compile mode: `max-optimize:max-autotune:low-precision` diff --git a/onediff_comfy_nodes/extras_nodes/nodes_compare.py b/onediff_comfy_nodes/extras_nodes/nodes_compare.py index a06fa9edb..a1bec4d3b 100644 --- a/onediff_comfy_nodes/extras_nodes/nodes_compare.py +++ b/onediff_comfy_nodes/extras_nodes/nodes_compare.py @@ -4,8 +4,10 @@ import folder_paths import numpy as np -import oneflow as flow -from onediff.infer_compiler.backends.oneflow.transform.builtin_transform import torch2oflow +import oneflow as flow # usort: skip +from onediff.infer_compiler.backends.oneflow.transform.builtin_transform import ( + torch2oflow, +) from PIL import Image try: @@ -66,15 +68,19 @@ def compare(self, torch_model, oneflow_model, check): ) return {} - removeprefix = lambda ss, prefix: ss[len(prefix):] if ss.startswith(prefix) else ss - + removeprefix = ( + lambda ss, prefix: ss[len(prefix) :] if ss.startswith(prefix) else ss + ) + cnt = 0 for key, _ in oflow_unet.named_parameters(): key = removeprefix(key, "_deployable_module_model._torch_module.") torch_value = torch_unet.get_parameter(key).cuda() - oflow_value = oflow_unet._deployable_module_model._oneflow_module.get_parameter( - key - ).cuda() + oflow_value = ( + oflow_unet._deployable_module_model._oneflow_module.get_parameter( + key + ).cuda() + ) if not flow.allclose(torch2oflow(torch_value), oflow_value, 1e-4, 1e-4): print( @@ -144,7 +150,10 @@ def save_images( subfolder, filename_prefix, ) = folder_paths.get_save_image_path( - filename_prefix, self.output_dir, images1[0].shape[1], images1[0].shape[0], + filename_prefix, + self.output_dir, + images1[0].shape[1], + images1[0].shape[0], ) results = list() for image1, image2 in zip(images1, images2): diff --git a/onediff_comfy_nodes/extras_nodes/nodes_nexfort_booster.py b/onediff_comfy_nodes/extras_nodes/nodes_nexfort_booster.py index 10bc22e41..3688d526d 100644 --- a/onediff_comfy_nodes/extras_nodes/nodes_nexfort_booster.py +++ b/onediff_comfy_nodes/extras_nodes/nodes_nexfort_booster.py @@ -1,4 +1,5 @@ import collections + from ..modules.nexfort.booster_basic import BasicNexFortBoosterExecutor diff --git a/onediff_comfy_nodes/extras_nodes/nodes_oneflow_booster.py b/onediff_comfy_nodes/extras_nodes/nodes_oneflow_booster.py index b4a1b3297..90edbe588 100644 --- a/onediff_comfy_nodes/extras_nodes/nodes_oneflow_booster.py +++ b/onediff_comfy_nodes/extras_nodes/nodes_oneflow_booster.py @@ -6,12 +6,11 @@ import torch from comfy import model_management from comfy.cli_args import args - -from onediff.utils.import_utils import is_onediff_quant_available from onediff.infer_compiler.backends.oneflow.utils.version_util import ( is_community_version, ) +from onediff.utils.import_utils import is_onediff_quant_available from ..modules import BoosterScheduler from ..modules.oneflow import ( @@ -20,8 +19,7 @@ PatchBoosterExecutor, ) from ..modules.oneflow.config import ONEDIFF_QUANTIZED_OPTIMIZED_MODELS -from ..modules.oneflow.utils import OUTPUT_FOLDER, load_graph, save_graph -from ..modules import BoosterScheduler +from ..modules.oneflow.utils import load_graph, OUTPUT_FOLDER, save_graph if is_onediff_quant_available() and not is_community_version(): from ..modules.oneflow.booster_quantization import ( @@ -138,7 +136,12 @@ def INPUT_TYPES(s): ), "end_step": ( "INT", - {"default": 1000, "min": 0, "max": 1000, "step": 0.1,}, + { + "default": 1000, + "min": 0, + "max": 1000, + "step": 0.1, + }, ), }, } @@ -158,7 +161,9 @@ def deep_cache_convert( start_step, end_step, ): - print(f'Warning: 
{type(self).__name__} will be deleted. Please use it with caution.') + print( + f"Warning: {type(self).__name__} will be deleted. Please use it with caution." + ) booster = BoosterScheduler( DeepcacheBoosterExecutor( cache_interval=cache_interval, @@ -289,7 +294,12 @@ def INPUT_TYPES(s): ), "end_step": ( "INT", - {"default": 1000, "min": 0, "max": 1000, "step": 0.1,}, + { + "default": 1000, + "min": 0, + "max": 1000, + "step": 0.1, + }, ), } } @@ -309,7 +319,9 @@ def onediff_load_checkpoint( start_step=0, end_step=1000, ): - print(f'Warning: {type(self).__name__} will be deleted. Please use it with caution.') + print( + f"Warning: {type(self).__name__} will be deleted. Please use it with caution." + ) # CheckpointLoaderSimple.load_checkpoint modelpatcher, clip, vae = self.load_checkpoint(ckpt_name) booster = BoosterScheduler( @@ -337,7 +349,10 @@ class BatchSizePatcher: @classmethod def INPUT_TYPES(s): return { - "required": {"model": ("MODEL",), "latent_image": ("LATENT",),}, + "required": { + "model": ("MODEL",), + "latent_image": ("LATENT",), + }, } RETURN_TYPES = ("MODEL",) @@ -366,7 +381,9 @@ def INPUT_TYPES(s): }, ), }, - "optional": {"custom_booster": ("CUSTOM_BOOSTER",),}, + "optional": { + "custom_booster": ("CUSTOM_BOOSTER",), + }, } RETURN_TYPES = ("MODEL",) @@ -381,7 +398,9 @@ def speedup( cache_name="svd", custom_booster: BoosterScheduler = None, ): - print(f'Warning: {type(self).__name__} will be deleted. Please use it with caution.') + print( + f"Warning: {type(self).__name__} will be deleted. Please use it with caution." + ) if custom_booster: booster = custom_booster booster.inplace = inplace @@ -403,7 +422,10 @@ def INPUT_TYPES(s): if os.path.isfile(os.path.join(vae_folder, f)) and f.endswith(".graph") ] return { - "required": {"vae": ("VAE",), "graph": (sorted(graph_files),),}, + "required": { + "vae": ("VAE",), + "graph": (sorted(graph_files),), + }, } RETURN_TYPES = ("VAE",) @@ -411,7 +433,9 @@ def INPUT_TYPES(s): CATEGORY = "OneDiff" def load_graph(self, vae, graph): - print(f'Warning: {type(self).__name__} will be deleted. Please use it with caution.') + print( + f"Warning: {type(self).__name__} will be deleted. Please use it with caution." + ) vae_model = vae.first_stage_model device = model_management.vae_offload_device() load_graph(vae_model, graph, device, subfolder="vae") @@ -435,7 +459,9 @@ def INPUT_TYPES(s): OUTPUT_NODE = True def save_graph(self, images, vae, filename_prefix): - print(f'Warning: {type(self).__name__} will be deleted. Please use it with caution.') + print( + f"Warning: {type(self).__name__} will be deleted. Please use it with caution." + ) vae_model = vae.first_stage_model vae_device = model_management.vae_offload_device() save_graph(vae_model, filename_prefix, vae_device, subfolder="vae") @@ -453,7 +479,10 @@ def INPUT_TYPES(s): if os.path.isfile(os.path.join(unet_folder, f)) and f.endswith(".graph") ] return { - "required": {"model": ("MODEL",), "graph": (sorted(graph_files),),}, + "required": { + "model": ("MODEL",), + "graph": (sorted(graph_files),), + }, } RETURN_TYPES = ("MODEL",) @@ -461,7 +490,9 @@ def INPUT_TYPES(s): CATEGORY = "OneDiff" def load_graph(self, model, graph): - print(f'Warning: {type(self).__name__} will be deleted. Please use it with caution.') + print( + f"Warning: {type(self).__name__} will be deleted. Please use it with caution." 
+ ) diffusion_model = model.model.diffusion_model @@ -486,7 +517,9 @@ def INPUT_TYPES(s): OUTPUT_NODE = True def save_graph(self, samples, model, filename_prefix): - print(f'Warning: {type(self).__name__} will be deleted. Please use it with caution.') + print( + f"Warning: {type(self).__name__} will be deleted. Please use it with caution." + ) diffusion_model = model.model.diffusion_model save_graph(diffusion_model, filename_prefix, "cuda", subfolder="unet") return {} @@ -532,7 +565,11 @@ def INPUT_TYPES(cls): if "calibrate_info.txt" in files: paths.append(os.path.relpath(root, start=search_path)) - return {"required": {"model_path": (paths,),}} + return { + "required": { + "model_path": (paths,), + } + } RETURN_TYPES = ("MODEL",) FUNCTION = "load_unet_int8" @@ -540,7 +577,9 @@ def INPUT_TYPES(cls): CATEGORY = "OneDiff" def load_unet_int8(self, model_path): - print(f'Warning: {type(self).__name__} will be deleted. Please use it with caution.') + print( + f"Warning: {type(self).__name__} will be deleted. Please use it with caution." + ) from ..modules.oneflow.utils.onediff_quant_utils import ( replace_module_with_quantizable_module, ) @@ -579,7 +618,9 @@ def INPUT_TYPES(s): OUTPUT_NODE = True def quantize_model(self, model, output_dir, conv, linear): - print(f'Warning: {type(self).__name__} will be deleted. Please use it with caution.') + print( + f"Warning: {type(self).__name__} will be deleted. Please use it with caution." + ) from ..modules.oneflow.utils import quantize_and_save_model diffusion_model = model.model.diffusion_model @@ -608,10 +649,11 @@ def INPUT_TYPES(s): CATEGORY = "OneDiff/Loaders" FUNCTION = "onediff_load_checkpoint" - def onediff_load_checkpoint(self, ckpt_name, vae_speedup): modelpatcher, clip, vae = self.load_checkpoint(ckpt_name) - print(f'Warning: {type(self).__name__} will be deleted. Please use it with caution.') + print( + f"Warning: {type(self).__name__} will be deleted. Please use it with caution." + ) booster = BoosterScheduler( OnelineQuantizationBoosterExecutor( conv_percentage=100, @@ -659,10 +701,16 @@ def INPUT_TYPES(s): FUNCTION = "onediff_load_checkpoint" def onediff_load_checkpoint( - self, ckpt_name, model_path, compile, vae_speedup, + self, + ckpt_name, + model_path, + compile, + vae_speedup, ): need_compile = compile == "enable" - print(f'Warning: {type(self).__name__} will be deleted. Please use it with caution.') + print( + f"Warning: {type(self).__name__} will be deleted. Please use it with caution." + ) modelpatcher, clip, vae = self.load_checkpoint(ckpt_name) # TODO fix by op.compile @@ -715,7 +763,9 @@ def onediff_load_checkpoint( output_vae=True, output_clip=True, ): - print(f'Warning: {type(self).__name__} will be deleted. Please use it with caution.') + print( + f"Warning: {type(self).__name__} will be deleted. Please use it with caution." + ) modelpatcher, clip, vae = self.load_checkpoint( ckpt_name, output_vae, output_clip ) diff --git a/onediff_comfy_nodes/extras_nodes/nodes_prompt_styler.py b/onediff_comfy_nodes/extras_nodes/nodes_prompt_styler.py index fa4f4becf..14790d7fa 100644 --- a/onediff_comfy_nodes/extras_nodes/nodes_prompt_styler.py +++ b/onediff_comfy_nodes/extras_nodes/nodes_prompt_styler.py @@ -1,7 +1,9 @@ import json import os + # Prompt Styler, a custom node for ComfyUI + def read_json_file(file_path): """ Reads a JSON file's content and returns it. 
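
An aside on what `read_json_file` is validating (the hunk just below reflows that check): each style-template file is expected to be a JSON list of objects carrying `name`, `prompt`, and `negative_prompt` keys, with `{prompt}` as the substitution placeholder. A minimal sketch under those assumptions — the `"Cinematic"` entry and sample texts are invented for illustration, not part of this diff:

```python
import json

# Hypothetical style entry; the loader expects a JSON list of such objects,
# as the all(...) validation in the hunk below enforces.
style = {
    "name": "Cinematic",
    "prompt": "cinematic still of {prompt}, shallow depth of field, film grain",
    "negative_prompt": "blurry, low quality",
}

# Every template must carry all three keys.
assert all(k in style for k in ("name", "prompt", "negative_prompt"))

# A style_template/*.json file is simply a list of these entries.
print(json.dumps([style], indent=2))

# At encode time, CLIPTextEncodePromptStyle substitutes the user text for
# the {prompt} placeholder.
print(style["prompt"].replace("{prompt}", "a red fox in the snow"))
```

Running the sketch prints a valid one-entry template file followed by the substituted positive prompt, which is exactly the flow the reformatted `encode` method performs further down in this file.
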
@@ -12,17 +14,23 @@ def read_json_file(file_path): return None try: - with open(file_path, 'r', encoding='utf-8') as file: + with open(file_path, "r", encoding="utf-8") as file: content = json.load(file) # Check if the content matches the expected format. - if not all(['name' in item and 'prompt' in item and 'negative_prompt' in item for item in content]): + if not all( + [ + "name" in item and "prompt" in item and "negative_prompt" in item + for item in content + ] + ): print(f"Warning: Invalid content in file {file_path}") return None return content except Exception as e: print(f"An error occurred while reading {file_path}: {str(e)}") return None - + + def read_sdxl_styles(json_data): """ Returns style names from the provided JSON data. @@ -31,10 +39,18 @@ def read_sdxl_styles(json_data): print("Error: input data must be a list") return [] - return [item['name'] for item in json_data if isinstance(item, dict) and 'name' in item] + return [ + item["name"] for item in json_data if isinstance(item, dict) and "name" in item + ] + def get_all_json_files(directory): - return [os.path.join(directory, file) for file in os.listdir(directory) if file.endswith('.json') and os.path.isfile(os.path.join(directory, file))] + return [ + os.path.join(directory, file) + for file in os.listdir(directory) + if file.endswith(".json") and os.path.isfile(os.path.join(directory, file)) + ] + def load_styles_from_directory(directory): """ @@ -49,26 +65,31 @@ def load_styles_from_directory(directory): json_data = read_json_file(json_file) if json_data: for item in json_data: - original_style = item['name'] + original_style = item["name"] style = original_style suffix = 1 while style in seen: style = f"{original_style}_{suffix}" suffix += 1 - item['name'] = style + item["name"] = style seen.add(style) combined_data.append(item) - unique_style_names = [item['name'] for item in combined_data if isinstance(item, dict) and 'name' in item] - + unique_style_names = [ + item["name"] + for item in combined_data + if isinstance(item, dict) and "name" in item + ] + return combined_data, unique_style_names + def find_template_by_name(json_data, template_name): """ Returns a template from the JSON data by name or None if not found. 
""" for template in json_data: - if template['name'] == template_name: + if template["name"] == template_name: return template return None @@ -76,54 +97,94 @@ def find_template_by_name(json_data, template_name): class CLIPTextEncodePromptStyle: @classmethod def INPUT_TYPES(s): - style_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "style_template") + style_dir = os.path.join( + os.path.dirname(os.path.realpath(__file__)), "style_template" + ) s.json_data, styles = load_styles_from_directory(style_dir) - return {"required": {"text_positive": ("STRING", {"multiline": True, "dynamicPrompts": True}), - "text_negative": ("STRING", {"multiline": True, "dynamicPrompts": True}), - "style": ((styles), ), - "clip": ("CLIP", )}, - "optional":{ - "log_prompt": ("BOOLEAN", {"default": True, "label_on": "yes", "label_off": "no"}), - "style_positive": ("BOOLEAN", {"default": True, "label_on": "yes", "label_off": "no"}), - "style_negative": ("BOOLEAN", {"default": True, "label_on": "yes", "label_off": "no"}), - } - } - RETURN_TYPES = ("CONDITIONING","CONDITIONING", ) - RETURN_NAMES = ("positive", "negative", ) + return { + "required": { + "text_positive": ( + "STRING", + {"multiline": True, "dynamicPrompts": True}, + ), + "text_negative": ( + "STRING", + {"multiline": True, "dynamicPrompts": True}, + ), + "style": ((styles),), + "clip": ("CLIP",), + }, + "optional": { + "log_prompt": ( + "BOOLEAN", + {"default": True, "label_on": "yes", "label_off": "no"}, + ), + "style_positive": ( + "BOOLEAN", + {"default": True, "label_on": "yes", "label_off": "no"}, + ), + "style_negative": ( + "BOOLEAN", + {"default": True, "label_on": "yes", "label_off": "no"}, + ), + }, + } + + RETURN_TYPES = ( + "CONDITIONING", + "CONDITIONING", + ) + RETURN_NAMES = ( + "positive", + "negative", + ) FUNCTION = "encode" CATEGORY = "OneDiff/Conditioning" - def encode(self, clip, text_positive, text_negative, style, log_prompt, style_positive, style_negative): + def encode( + self, + clip, + text_positive, + text_negative, + style, + log_prompt, + style_positive, + style_negative, + ): template = find_template_by_name(self.json_data, style) if style_positive: - positive_prompt = template['prompt'].replace('{prompt}', text_positive) + positive_prompt = template["prompt"].replace("{prompt}", text_positive) else: positive_prompt = text_positive - - positive_tokens = clip.tokenize(positive_prompt) - positive_cond, positive_pooled = clip.encode_from_tokens(positive_tokens, return_pooled=True) + positive_tokens = clip.tokenize(positive_prompt) + positive_cond, positive_pooled = clip.encode_from_tokens( + positive_tokens, return_pooled=True + ) if style_negative: - negative_prompt_template = template.get('negative_prompt', "") - negative_prompt = f"{negative_prompt_template}, {text_negative}" if negative_prompt_template and text_negative else text_negative or negative_prompt_template + negative_prompt_template = template.get("negative_prompt", "") + negative_prompt = ( + f"{negative_prompt_template}, {text_negative}" + if negative_prompt_template and text_negative + else text_negative or negative_prompt_template + ) else: negative_prompt = text_negative negative_tokens = clip.tokenize(negative_prompt) - negative_cond, negative_pooled = clip.encode_from_tokens(negative_tokens, return_pooled=True) + negative_cond, negative_pooled = clip.encode_from_tokens( + negative_tokens, return_pooled=True + ) if log_prompt: - print(f'{positive_prompt=}\n{negative_prompt=}') - return ([[positive_cond, {"pooled_output": positive_pooled}]], 
[[negative_cond, {"pooled_output": negative_pooled}]]) - - - - - - + print(f"{positive_prompt=}\n{negative_prompt=}") + return ( + [[positive_cond, {"pooled_output": positive_pooled}]], + [[negative_cond, {"pooled_output": negative_pooled}]], + ) NODE_CLASS_MAPPINGS = { diff --git a/onediff_comfy_nodes/extras_nodes/nodes_torch_compile_booster.py b/onediff_comfy_nodes/extras_nodes/nodes_torch_compile_booster.py index 6a7b72558..8c467329e 100644 --- a/onediff_comfy_nodes/extras_nodes/nodes_torch_compile_booster.py +++ b/onediff_comfy_nodes/extras_nodes/nodes_torch_compile_booster.py @@ -1,4 +1,7 @@ -from ..modules.torch_compile.booster_basic import TorchCompileBoosterExecutor # type: ignore +from ..modules.torch_compile.booster_basic import ( # type: ignore + TorchCompileBoosterExecutor, +) + class OneDiffTorchCompileBooster: @classmethod diff --git a/onediff_comfy_nodes/extras_nodes/style_template/SDXL_InstantID_styles.json b/onediff_comfy_nodes/extras_nodes/style_template/SDXL_InstantID_styles.json index 270619e4e..28ef94ec9 100644 --- a/onediff_comfy_nodes/extras_nodes/style_template/SDXL_InstantID_styles.json +++ b/onediff_comfy_nodes/extras_nodes/style_template/SDXL_InstantID_styles.json @@ -424,4 +424,4 @@ "prompt": "Tilt-shift photo of {prompt} . Selective focus, miniature effect, blurred background, highly detailed, vibrant, perspective control", "negative_prompt": "blurry, noisy, deformed, flat, low contrast, unrealistic, oversaturated, underexposed" } -] \ No newline at end of file +] diff --git a/onediff_comfy_nodes/modules/booster_cache.py b/onediff_comfy_nodes/modules/booster_cache.py index db64a5917..f50704717 100644 --- a/onediff_comfy_nodes/modules/booster_cache.py +++ b/onediff_comfy_nodes/modules/booster_cache.py @@ -1,7 +1,8 @@ -import torch from collections import OrderedDict -from comfy.model_patcher import ModelPatcher from functools import singledispatch + +import torch +from comfy.model_patcher import ModelPatcher from comfy.sd import VAE from onediff.torch_utils.module_operations import get_sub_module diff --git a/onediff_comfy_nodes/modules/booster_interface.py b/onediff_comfy_nodes/modules/booster_interface.py index 2abacfb59..31abb1834 100644 --- a/onediff_comfy_nodes/modules/booster_interface.py +++ b/onediff_comfy_nodes/modules/booster_interface.py @@ -1,7 +1,7 @@ # import os +import dataclasses import uuid from abc import ABC, abstractmethod -import dataclasses # from functools import singledispatchmethod # from typing import Optional diff --git a/onediff_comfy_nodes/modules/booster_scheduler.py b/onediff_comfy_nodes/modules/booster_scheduler.py index 6b5c29260..2a1c33f50 100644 --- a/onediff_comfy_nodes/modules/booster_scheduler.py +++ b/onediff_comfy_nodes/modules/booster_scheduler.py @@ -1,10 +1,12 @@ import copy -import torch.nn as nn from functools import singledispatchmethod, wraps from typing import List + +import torch.nn as nn +from comfy import model_management from comfy.model_patcher import ModelPatcher from comfy.sd import VAE -from comfy import model_management + from .booster_cache import BoosterCacheService from .booster_interface import BoosterExecutor, BoosterSettings diff --git a/onediff_comfy_nodes/modules/nexfort/README.md b/onediff_comfy_nodes/modules/nexfort/README.md index 8e347758e..67550193c 100644 --- a/onediff_comfy_nodes/modules/nexfort/README.md +++ b/onediff_comfy_nodes/modules/nexfort/README.md @@ -1,4 +1,4 @@ Start comfyui command ```shell python main.py --gpu-only --disable-cuda-malloc -``` \ No newline at end of file +``` diff 
--git a/onediff_comfy_nodes/modules/nexfort/__init__.py b/onediff_comfy_nodes/modules/nexfort/__init__.py index dbb1deb21..728d93e08 100644 --- a/onediff_comfy_nodes/modules/nexfort/__init__.py +++ b/onediff_comfy_nodes/modules/nexfort/__init__.py @@ -1,8 +1,9 @@ import os -from .hijack_samplers import samplers_hijack + +from .hijack_comfyui_instantid import comfyui_instantid_hijacker from .hijack_ipadapter_plus import ipadapter_plus_hijacker from .hijack_pulid_comfyui import pulid_comfyui_hijacker -from .hijack_comfyui_instantid import comfyui_instantid_hijacker +from .hijack_samplers import samplers_hijack samplers_hijack.hijack(last=False) ipadapter_plus_hijacker.hijack(last=False) diff --git a/onediff_comfy_nodes/modules/nexfort/booster_basic.py b/onediff_comfy_nodes/modules/nexfort/booster_basic.py index dcef28fa0..3f390f9ab 100644 --- a/onediff_comfy_nodes/modules/nexfort/booster_basic.py +++ b/onediff_comfy_nodes/modules/nexfort/booster_basic.py @@ -1,15 +1,17 @@ -import torch from functools import partial, singledispatchmethod from typing import Optional +import torch + from comfy.controlnet import ControlLora, ControlNet from comfy.model_patcher import ModelPatcher from comfy.sd import VAE +from nexfort.utils.memory_format import apply_memory_format from onediff.infer_compiler import compile -from nexfort.utils.memory_format import apply_memory_format -from .onediff_controlnet import OneDiffControlLora + from ..booster_interface import BoosterExecutor +from .onediff_controlnet import OneDiffControlLora class BasicNexFortBoosterExecutor(BoosterExecutor): diff --git a/onediff_comfy_nodes/modules/nexfort/booster_utils.py b/onediff_comfy_nodes/modules/nexfort/booster_utils.py index db00ad99f..b7e4ab5b3 100644 --- a/onediff_comfy_nodes/modules/nexfort/booster_utils.py +++ b/onediff_comfy_nodes/modules/nexfort/booster_utils.py @@ -1,8 +1,8 @@ +from comfy.model_base import BaseModel +from comfy.model_patcher import ModelPatcher from onediff.infer_compiler.backends.nexfort.deployable_module import ( NexfortDeployableModule as DeployableModule, ) -from comfy.model_patcher import ModelPatcher -from comfy.model_base import BaseModel def clear_deployable_module_cache_and_unbind(*args, **kwargs): diff --git a/onediff_comfy_nodes/modules/nexfort/hijack_comfyui_instantid/InstantID.py b/onediff_comfy_nodes/modules/nexfort/hijack_comfyui_instantid/InstantID.py index bdfc9a64f..06a5a549c 100644 --- a/onediff_comfy_nodes/modules/nexfort/hijack_comfyui_instantid/InstantID.py +++ b/onediff_comfy_nodes/modules/nexfort/hijack_comfyui_instantid/InstantID.py @@ -1,6 +1,6 @@ from ..booster_utils import is_using_nexfort_backend -from ._config import comfyui_instantid_hijacker,comfyui_instantid from ..hijack_ipadapter_plus.set_model_patch_replace import set_model_patch_replace +from ._config import comfyui_instantid, comfyui_instantid_hijacker set_model_patch_replace_fn_pt = comfyui_instantid.InstantID._set_model_patch_replace diff --git a/onediff_comfy_nodes/modules/nexfort/hijack_comfyui_instantid/README.md b/onediff_comfy_nodes/modules/nexfort/hijack_comfyui_instantid/README.md index 6cc18c45e..0fe808800 100644 --- a/onediff_comfy_nodes/modules/nexfort/hijack_comfyui_instantid/README.md +++ b/onediff_comfy_nodes/modules/nexfort/hijack_comfyui_instantid/README.md @@ -14,7 +14,7 @@ git clone https://github.com/comfyanonymous/ComfyUI git reset --hard 2d4164271634476627aae31fbec251ca748a0ae0 ``` When you have completed these steps, follow the [instructions](https://github.com/comfyanonymous/ComfyUI) to 
install ComfyUI - + #### Install ComfyUI_InstantID ``` @@ -26,12 +26,12 @@ When you have completed these steps,follow the [instructions](https://github.com ### Quick Start -> Recommend running the official example of ComfyUI_InstantID now, and then trying OneDiff acceleration. +> Recommend running the official example of ComfyUI_InstantID now, and then trying OneDiff acceleration. > You can Load these images in ComfyUI to get the full workflow. Experiment (GeForce RTX 3090) Workflow for OneDiff Acceleration in ComfyUI_InstantID: -1. Replace the **`Load Checkpoint`** node with **`Load Checkpoint - OneDiff`** node. +1. Replace the **`Load Checkpoint`** node with **`Load Checkpoint - OneDiff`** node. 2. Add a **`Batch Size Patcher`** node before the **`Ksampler`** node (due to temporary lack of support for dynamic batch size). As follows: ![workflow (20)](https://github.com/siliconflow/onediff/assets/117806079/492a83a8-1a5b-4fb3-9e53-6d53e881a3f8) @@ -120,4 +120,4 @@ For users of OneDiff Community, please visit [GitHub Issues](https://github.com/ For users of OneDiff Enterprise, you can contact contact@siliconflow.com for commercial support. -Feel free to join our [Discord](https://discord.gg/RKJTjZMcPQ) community for discussions and to receive the latest updates. \ No newline at end of file +Feel free to join our [Discord](https://discord.gg/RKJTjZMcPQ) community for discussions and to receive the latest updates. diff --git a/onediff_comfy_nodes/modules/nexfort/hijack_comfyui_instantid/_config.py b/onediff_comfy_nodes/modules/nexfort/hijack_comfyui_instantid/_config.py index d6decc38b..62c8b4eb1 100644 --- a/onediff_comfy_nodes/modules/nexfort/hijack_comfyui_instantid/_config.py +++ b/onediff_comfy_nodes/modules/nexfort/hijack_comfyui_instantid/_config.py @@ -3,6 +3,7 @@ COMFYUI_ROOT = os.getenv("COMFYUI_ROOT") from onediff.utils.import_utils import DynamicModuleLoader + from ...sd_hijack_utils import Hijacker __all__ = ["comfyui_instantid"] diff --git a/onediff_comfy_nodes/modules/nexfort/hijack_ipadapter_plus/CrossAttentionPatch.py b/onediff_comfy_nodes/modules/nexfort/hijack_ipadapter_plus/CrossAttentionPatch.py index b3a16a2b8..e04f672e6 100644 --- a/onediff_comfy_nodes/modules/nexfort/hijack_ipadapter_plus/CrossAttentionPatch.py +++ b/onediff_comfy_nodes/modules/nexfort/hijack_ipadapter_plus/CrossAttentionPatch.py @@ -1,5 +1,6 @@ -import torch import math + +import torch import torch.nn.functional as F from comfy.ldm.modules.attention import optimized_attention diff --git a/onediff_comfy_nodes/modules/nexfort/hijack_ipadapter_plus/IPAdapterPlus.py b/onediff_comfy_nodes/modules/nexfort/hijack_ipadapter_plus/IPAdapterPlus.py index 49468218a..1ceb2151c 100644 --- a/onediff_comfy_nodes/modules/nexfort/hijack_ipadapter_plus/IPAdapterPlus.py +++ b/onediff_comfy_nodes/modules/nexfort/hijack_ipadapter_plus/IPAdapterPlus.py @@ -1,6 +1,6 @@ """hijack ComfyUI/custom_nodes/ComfyUI_IPAdapter_plus/IPAdapterPlus.py""" from ..booster_utils import is_using_nexfort_backend -from ._config import ipadapter_plus_hijacker, ipadapter_plus +from ._config import ipadapter_plus, ipadapter_plus_hijacker from .set_model_patch_replace import set_model_patch_replace set_model_patch_replace_fn = ipadapter_plus.IPAdapterPlus.set_model_patch_replace diff --git a/onediff_comfy_nodes/modules/nexfort/hijack_ipadapter_plus/_config.py b/onediff_comfy_nodes/modules/nexfort/hijack_ipadapter_plus/_config.py index fe1a737ca..35b58f57d 100644 --- a/onediff_comfy_nodes/modules/nexfort/hijack_ipadapter_plus/_config.py +++ 
b/onediff_comfy_nodes/modules/nexfort/hijack_ipadapter_plus/_config.py @@ -3,6 +3,7 @@ COMFYUI_ROOT = os.getenv("COMFYUI_ROOT") from onediff.utils.import_utils import DynamicModuleLoader + from ...sd_hijack_utils import Hijacker __all__ = ["ipadapter_plus", "ipadapter_plus_hijacker"] diff --git a/onediff_comfy_nodes/modules/nexfort/hijack_ipadapter_plus/set_model_patch_replace.py b/onediff_comfy_nodes/modules/nexfort/hijack_ipadapter_plus/set_model_patch_replace.py index 30e130169..c3c801ab7 100644 --- a/onediff_comfy_nodes/modules/nexfort/hijack_ipadapter_plus/set_model_patch_replace.py +++ b/onediff_comfy_nodes/modules/nexfort/hijack_ipadapter_plus/set_model_patch_replace.py @@ -1,9 +1,10 @@ import torch from comfy import model_management -from .CrossAttentionPatch import Attn2Replace, ipadapter_attention -from ..patch_management import create_patch_executor, PatchType from ..booster_utils import clear_deployable_module_cache_and_unbind +from ..patch_management import create_patch_executor, PatchType + +from .CrossAttentionPatch import Attn2Replace, ipadapter_attention def set_model_patch_replace( @@ -44,7 +45,7 @@ def split_patch_kwargs(patch_kwargs): split1dict[k] = v else: split2dict[k] = v - + # patch for weight # weight = split1dict["weight"] # if isinstance(weight, (int, float)): diff --git a/onediff_comfy_nodes/modules/nexfort/hijack_model_patcher.py b/onediff_comfy_nodes/modules/nexfort/hijack_model_patcher.py index 77b10e75c..4ca0987b4 100644 --- a/onediff_comfy_nodes/modules/nexfort/hijack_model_patcher.py +++ b/onediff_comfy_nodes/modules/nexfort/hijack_model_patcher.py @@ -1,8 +1,8 @@ from comfy.model_patcher import ModelPatcher from ..sd_hijack_utils import Hijacker -from .patch_management import PatchType, create_patch_executor from .booster_utils import is_using_nexfort_backend +from .patch_management import create_patch_executor, PatchType def clone_nexfort(org_fn, self, *args, **kwargs): diff --git a/onediff_comfy_nodes/modules/nexfort/hijack_pulid_comfyui/_config.py b/onediff_comfy_nodes/modules/nexfort/hijack_pulid_comfyui/_config.py index e48ca66e0..fc298345e 100644 --- a/onediff_comfy_nodes/modules/nexfort/hijack_pulid_comfyui/_config.py +++ b/onediff_comfy_nodes/modules/nexfort/hijack_pulid_comfyui/_config.py @@ -3,6 +3,7 @@ COMFYUI_ROOT = os.getenv("COMFYUI_ROOT") from onediff.utils.import_utils import DynamicModuleLoader + from ...sd_hijack_utils import Hijacker __all__ = ["pulid_comfyui", "pulid_comfyui_hijacker"] diff --git a/onediff_comfy_nodes/modules/nexfort/hijack_pulid_comfyui/pulid.py b/onediff_comfy_nodes/modules/nexfort/hijack_pulid_comfyui/pulid.py index 7dd7e11ca..d35328bbc 100644 --- a/onediff_comfy_nodes/modules/nexfort/hijack_pulid_comfyui/pulid.py +++ b/onediff_comfy_nodes/modules/nexfort/hijack_pulid_comfyui/pulid.py @@ -1,8 +1,9 @@ from functools import partial -from ._config import pulid_comfyui, pulid_comfyui_hijacker + from ..booster_utils import is_using_nexfort_backend from ..hijack_ipadapter_plus.set_model_patch_replace import set_model_patch_replace +from ._config import pulid_comfyui, pulid_comfyui_hijacker # ComfyUI/custom_nodes/PuLID_ComfyUI/pulid.py set_model_patch_replace_fn = pulid_comfyui.pulid.set_model_patch_replace diff --git a/onediff_comfy_nodes/modules/nexfort/hijack_samplers.py b/onediff_comfy_nodes/modules/nexfort/hijack_samplers.py index 5a2af0a56..3d11c4166 100644 --- a/onediff_comfy_nodes/modules/nexfort/hijack_samplers.py +++ b/onediff_comfy_nodes/modules/nexfort/hijack_samplers.py @@ -4,12 +4,13 @@ """ from typing 
import Dict + import torch from comfy.samplers import calc_cond_batch, can_concat_cond, cond_cat, get_area_and_mult from ..sd_hijack_utils import Hijacker -from .patch_management import PatchType, create_patch_executor from .booster_utils import is_using_nexfort_backend +from .patch_management import create_patch_executor, PatchType def calc_cond_batch_of(orig_func, model, conds, x_in, timestep, model_options): @@ -103,22 +104,31 @@ def calc_cond_batch_of(orig_func, model, conds, x_in, timestep, model_options): ): patch_executor = create_patch_executor(PatchType.UNetExtraInputOptions) extra_options = transformer_options - sigma = extra_options["sigmas"][0].item() if 'sigmas' in extra_options else 999999999.9 + sigma = ( + extra_options["sigmas"][0].item() + if "sigmas" in extra_options + else 999999999.9 + ) assert "_sigmas" not in extra_options extra_options["_sigmas"] = {} - attn2_patch_dict = extra_options['patches_replace']["attn2"] + attn2_patch_dict = extra_options["patches_replace"]["attn2"] for k, attn_m in attn2_patch_dict.items(): out_lst = [] for i, callback in enumerate(attn_m.callback): - if sigma <= attn_m.kwargs[i]["sigma_start"] and sigma >= attn_m.kwargs[i]["sigma_end"]: + if ( + sigma <= attn_m.kwargs[i]["sigma_start"] + and sigma >= attn_m.kwargs[i]["sigma_end"] + ): out_lst.append(1) else: out_lst.append(0) # extra inputs - transformer_options["_sigmas"][attn_m.forward_patch_key] = torch.randn(*out_lst) + transformer_options["_sigmas"][attn_m.forward_patch_key] = torch.randn( + *out_lst + ) # extra inputs - transformer_options["_attn2"] = patch_executor.get_patch(diff_model)[ + transformer_options["_attn2"] = patch_executor.get_patch(diff_model)[ "attn2" ] @@ -184,5 +194,7 @@ def cond_func(orig_func, model, *args, **kwargs): samplers_hijack = Hijacker() samplers_hijack.register( - orig_func=calc_cond_batch, sub_func=calc_cond_batch_of, cond_func=cond_func, + orig_func=calc_cond_batch, + sub_func=calc_cond_batch_of, + cond_func=cond_func, ) diff --git a/onediff_comfy_nodes/modules/nexfort/onediff_controlnet.py b/onediff_comfy_nodes/modules/nexfort/onediff_controlnet.py index 919f9008d..0fe93bba5 100644 --- a/onediff_comfy_nodes/modules/nexfort/onediff_controlnet.py +++ b/onediff_comfy_nodes/modules/nexfort/onediff_controlnet.py @@ -1,4 +1,5 @@ import inspect + import comfy from comfy.controlnet import ControlLora, ControlLoraOps, ControlNet diff --git a/onediff_comfy_nodes/modules/nexfort/patch_management/patch_executor.py b/onediff_comfy_nodes/modules/nexfort/patch_management/patch_executor.py index faa2c1ea8..d02336f94 100644 --- a/onediff_comfy_nodes/modules/nexfort/patch_management/patch_executor.py +++ b/onediff_comfy_nodes/modules/nexfort/patch_management/patch_executor.py @@ -1,9 +1,10 @@ from abc import ABC, abstractmethod from typing import Dict, List -from comfy.model_patcher import ModelPatcher from comfy.model_base import BaseModel +from comfy.model_patcher import ModelPatcher + class PatchExecutorBase(ABC): @abstractmethod @@ -108,9 +109,7 @@ def is_use_deep_cache_unet(self, module: BaseModel): class UNetExtraInputOptions(PatchExecutorBase): def __init__(self) -> None: - """UNetExtraInputOptions - - """ + """UNetExtraInputOptions""" super().__init__() self.patch_name = type(self).__name__ diff --git a/onediff_comfy_nodes/modules/nexfort/patch_management/patch_factory.py b/onediff_comfy_nodes/modules/nexfort/patch_management/patch_factory.py index 5f7e38c33..51fa29b00 100644 --- a/onediff_comfy_nodes/modules/nexfort/patch_management/patch_factory.py +++ 
b/onediff_comfy_nodes/modules/nexfort/patch_management/patch_factory.py @@ -1,10 +1,11 @@ from enum import Enum + from .patch_executor import ( CachedCrossAttentionPatch, DeepCacheUNetExecutorPatch, UiNodeWithIndexPatch, + UNetExtraInputOptions, ) -from .patch_executor import UNetExtraInputOptions class PatchType(Enum): diff --git a/onediff_comfy_nodes/modules/oneflow/__init__.py b/onediff_comfy_nodes/modules/oneflow/__init__.py index 019765f7f..ed52f453b 100644 --- a/onediff_comfy_nodes/modules/oneflow/__init__.py +++ b/onediff_comfy_nodes/modules/oneflow/__init__.py @@ -1,7 +1,7 @@ -from .config import _USE_UNET_INT8, ONEDIFF_QUANTIZED_OPTIMIZED_MODELS from .booster_basic import BasicOneFlowBoosterExecutor from .booster_deepcache import DeepcacheBoosterExecutor from .booster_patch import PatchBoosterExecutor +from .config import _USE_UNET_INT8, ONEDIFF_QUANTIZED_OPTIMIZED_MODELS from .patch_management.patch_for_oneflow import * from .hijack_animatediff import animatediff_hijacker @@ -10,9 +10,9 @@ from .hijack_model_management import model_management_hijacker from .hijack_model_patcher import model_patch_hijacker from .hijack_nodes import nodes_hijacker +from .hijack_pulid_comfyui import pulid_comfyui_hijacker from .hijack_samplers import samplers_hijack from .hijack_utils import comfy_utils_hijack -from .hijack_pulid_comfyui import pulid_comfyui_hijacker model_management_hijacker.hijack() # add flow.cuda.empty_cache() nodes_hijacker.hijack() @@ -22,4 +22,4 @@ comfyui_instantid_hijacker.hijack() model_patch_hijacker.hijack() comfy_utils_hijack.hijack() -pulid_comfyui_hijacker.hijack() \ No newline at end of file +pulid_comfyui_hijacker.hijack() diff --git a/onediff_comfy_nodes/modules/oneflow/booster_basic.py b/onediff_comfy_nodes/modules/oneflow/booster_basic.py index 78a3bc798..d80d6cdbf 100644 --- a/onediff_comfy_nodes/modules/oneflow/booster_basic.py +++ b/onediff_comfy_nodes/modules/oneflow/booster_basic.py @@ -13,13 +13,13 @@ from ..booster_interface import BoosterExecutor from .onediff_controlnet import OneDiffControlLora -from .utils.graph_path import generate_graph_path from .utils.booster_utils import ( get_model_type, is_fp16_model, set_compiled_options, set_environment_for_svd_img2vid, ) +from .utils.graph_path import generate_graph_path class BasicOneFlowBoosterExecutor(BoosterExecutor): diff --git a/onediff_comfy_nodes/modules/oneflow/booster_deepcache.py b/onediff_comfy_nodes/modules/oneflow/booster_deepcache.py index 87f8f8fae..6567e4cf8 100644 --- a/onediff_comfy_nodes/modules/oneflow/booster_deepcache.py +++ b/onediff_comfy_nodes/modules/oneflow/booster_deepcache.py @@ -5,9 +5,9 @@ from comfy.model_patcher import ModelPatcher from ..booster_interface import BoosterExecutor +from .utils.booster_utils import set_compiled_options from .utils.deep_cache_speedup import deep_cache_speedup from .utils.graph_path import generate_graph_path -from .utils.booster_utils import set_compiled_options @dataclass diff --git a/onediff_comfy_nodes/modules/oneflow/booster_patch.py b/onediff_comfy_nodes/modules/oneflow/booster_patch.py index 42f083d9e..2ecae4991 100644 --- a/onediff_comfy_nodes/modules/oneflow/booster_patch.py +++ b/onediff_comfy_nodes/modules/oneflow/booster_patch.py @@ -1,14 +1,14 @@ import os from functools import singledispatchmethod -from comfy.sd import VAE -from comfy.model_patcher import ModelPatcher from comfy.controlnet import ControlLora, ControlNet +from comfy.model_patcher import ModelPatcher + +from comfy.sd import VAE from 
onediff.infer_compiler.backends.oneflow import ( OneflowDeployableModule as DeployableModule, ) - from ..booster_interface import BoosterExecutor diff --git a/onediff_comfy_nodes/modules/oneflow/booster_quantization.py b/onediff_comfy_nodes/modules/oneflow/booster_quantization.py index f4b50d6e4..4b246013c 100644 --- a/onediff_comfy_nodes/modules/oneflow/booster_quantization.py +++ b/onediff_comfy_nodes/modules/oneflow/booster_quantization.py @@ -8,23 +8,26 @@ from comfy.controlnet import ControlNet from comfy.model_patcher import ModelPatcher from onediff.infer_compiler import oneflow_compile -from onediff.infer_compiler.backends.oneflow import OneflowDeployableModule as DeployableModule +from onediff.infer_compiler.backends.oneflow import ( + OneflowDeployableModule as DeployableModule, +) +from onediff.optimization import quant_optimizer from onediff_quant.quantization import QuantizationConfig from onediff_quant.quantization.module_operations import get_sub_module from onediff_quant.quantization.quantize_calibrators import ( QuantizationMetricsCalculator, ) from onediff_quant.quantization.quantize_config import Metric -from onediff.optimization import quant_optimizer from .booster_basic import BoosterExecutor -from .utils.graph_path import generate_graph_path +from .patch_management import create_patch_executor, PatchType from .utils.booster_utils import ( is_fp16_model, set_compiled_options, set_environment_for_svd_img2vid, ) -from .patch_management import PatchType, create_patch_executor +from .utils.graph_path import generate_graph_path + class SubQuantizationPercentileCalculator(QuantizationMetricsCalculator): def __init__( @@ -206,7 +209,14 @@ def set_optimized_model(model, quant_config): def _set_optimized_model_for_deepcace(self, model: ModelPatcher): # TODO - print("Warning: DeepCache + OnelineQuantization only support default configurations:") - model.fast_deep_cache_unet.quantize = partial(quant_optimizer.quantize_model, model.fast_deep_cache_unet, quantize_conv=False) - model.deep_cache_unet.quantize = partial(quant_optimizer.quantize_model, model.deep_cache_unet, quantize_conv=False) - + print( + "Warning: DeepCache + OnelineQuantization only support default configurations:" + ) + model.fast_deep_cache_unet.quantize = partial( + quant_optimizer.quantize_model, + model.fast_deep_cache_unet, + quantize_conv=False, + ) + model.deep_cache_unet.quantize = partial( + quant_optimizer.quantize_model, model.deep_cache_unet, quantize_conv=False + ) diff --git a/onediff_comfy_nodes/modules/oneflow/hijack_animatediff/README.md b/onediff_comfy_nodes/modules/oneflow/hijack_animatediff/README.md index bff9f1b1b..42fc2267b 100644 --- a/onediff_comfy_nodes/modules/oneflow/hijack_animatediff/README.md +++ b/onediff_comfy_nodes/modules/oneflow/hijack_animatediff/README.md @@ -8,7 +8,7 @@ Please Refer to the Readme in the Respective Repositories for Installation Instr - ComfyUI: - github: https://github.com/comfyanonymous/ComfyUI - commit: `5d875d77fe6e31a4b0bc6dc36f0441eba3b6afe1` - - Date: `Wed Mar 20 20:48:54 2024 -0400` + - Date: `Wed Mar 20 20:48:54 2024 -0400` - ComfyUI-AnimateDiff-Evolved: - github: https://github.com/Kosinkadink/ComfyUI-AnimateDiff-Evolved @@ -16,16 +16,16 @@ Please Refer to the Readme in the Respective Repositories for Installation Instr - Date: `Wed Mar 20 15:50:08 2024 -0500` - OneDiff: - - github: https://github.com/siliconflow/onediff + - github: https://github.com/siliconflow/onediff ### Quick Start -> Recommend running the official example of ComfyUI AnimateDiff 
Evolved now, and then trying OneDiff acceleration. +> Recommend running the official example of ComfyUI AnimateDiff Evolved now, and then trying OneDiff acceleration. Experiment (NVIDIA A100-PCIE-40GB) Workflow for OneDiff Acceleration in ComfyUI-AnimateDiff-Evolved: -1. Replace the **`Load Checkpoint`** node with **`Load Checkpoint - OneDiff`** node. +1. Replace the **`Load Checkpoint`** node with **`Load Checkpoint - OneDiff`** node. 2. Add a **`Batch Size Patcher`** node before the **`Ksampler`** node (due to temporary lack of support for dynamic batch size). As follows: @@ -62,4 +62,3 @@ https://github.com/Kosinkadink/ComfyUI-AnimateDiff-Evolved?tab=readme-ov-file#sa [**`Community and Support`**](https://github.com/siliconflow/onediff?tab=readme-ov-file#community-and-support) - diff --git a/onediff_comfy_nodes/modules/oneflow/hijack_animatediff/__init__.py b/onediff_comfy_nodes/modules/oneflow/hijack_animatediff/__init__.py index 625c35bfa..01af7bc27 100644 --- a/onediff_comfy_nodes/modules/oneflow/hijack_animatediff/__init__.py +++ b/onediff_comfy_nodes/modules/oneflow/hijack_animatediff/__init__.py @@ -4,4 +4,3 @@ from .motion_module_ad import * from .sampling import * from .utils_motion import * - diff --git a/onediff_comfy_nodes/modules/oneflow/hijack_animatediff/_config.py b/onediff_comfy_nodes/modules/oneflow/hijack_animatediff/_config.py index d6340640f..273ae7b3e 100644 --- a/onediff_comfy_nodes/modules/oneflow/hijack_animatediff/_config.py +++ b/onediff_comfy_nodes/modules/oneflow/hijack_animatediff/_config.py @@ -1,6 +1,6 @@ """ github: https://github.com/Kosinkadink/ComfyUI-AnimateDiff-Evolved -commit: 5d875d77fe6e31a4b0bc6dc36f0441eba3b6afe1 +commit: 5d875d77fe6e31a4b0bc6dc36f0441eba3b6afe1 """ import os @@ -21,7 +21,7 @@ else: load_animatediff_package = False except Exception as e: - print(f"Warning: Failed to load {pkg_name} from {animatediff_root} due to {e}") - load_animatediff_package = False + print(f"Warning: Failed to load {pkg_name} from {animatediff_root} due to {e}") + load_animatediff_package = False animatediff_hijacker = Hijacker() diff --git a/onediff_comfy_nodes/modules/oneflow/hijack_animatediff/motion_module_ad.py b/onediff_comfy_nodes/modules/oneflow/hijack_animatediff/motion_module_ad.py index e2cfd49ec..be3f2d62c 100644 --- a/onediff_comfy_nodes/modules/oneflow/hijack_animatediff/motion_module_ad.py +++ b/onediff_comfy_nodes/modules/oneflow/hijack_animatediff/motion_module_ad.py @@ -20,7 +20,7 @@ class TemporalTransformer3DModel_OF(TemporalTransformer3DModel_OF_CLS): - def get_cameractrl_effect(self, hidden_states: torch.Tensor) : + def get_cameractrl_effect(self, hidden_states: torch.Tensor): # if no raw camera_Ctrl, return None if self.raw_cameractrl_effect is None: return 1.0 @@ -41,7 +41,9 @@ def get_cameractrl_effect(self, hidden_states: torch.Tensor) : self.temp_cameractrl_effect = None # otherwise, calculate temp_cameractrl self.prev_cameractrl_hidden_states_batch = batch - mask = prepare_mask_batch(self.raw_scale_mask, shape=(self.full_length, 1, height, width)) + mask = prepare_mask_batch( + self.raw_scale_mask, shape=(self.full_length, 1, height, width) + ) mask = repeat_to_batch_size(mask, self.full_length) # if mask not the same amount length as full length, make it match if self.full_length != mask.shape[0]: @@ -50,7 +52,7 @@ def get_cameractrl_effect(self, hidden_states: torch.Tensor) : batch, channel, height, width = mask.shape # first, perform same operations as on hidden_states, # turning (b, c, h, w) -> (b, h*w, c) - mask = 
mask.permute(0, 2, 3, 1).reshape(batch, height*width, channel) + mask = mask.permute(0, 2, 3, 1).reshape(batch, height * width, channel) # then, make it the same shape as attention's k, (h*w, b, c) mask = mask.permute(1, 0, 2) # make masks match the expected length of h*w @@ -60,14 +62,22 @@ def get_cameractrl_effect(self, hidden_states: torch.Tensor) : # cache mask and set to proper device self.temp_cameractrl_effect = mask # move temp_cameractrl to proper dtype + device - self.temp_cameractrl_effect = self.temp_cameractrl_effect.to(dtype=hidden_states.dtype, device=hidden_states.device) + self.temp_cameractrl_effect = self.temp_cameractrl_effect.to( + dtype=hidden_states.dtype, device=hidden_states.device + ) # return subset of masks, if needed if self.sub_idxs is not None: return self.temp_cameractrl_effect[:, self.sub_idxs, :] return self.temp_cameractrl_effect - - def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, view_options=None, mm_kwargs: dict[str]=None): + def forward( + self, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + view_options=None, + mm_kwargs: dict[str] = None, + ): batch, channel, height, width = hidden_states.shape residual = hidden_states cameractrl_effect = self.get_cameractrl_effect(hidden_states) @@ -92,7 +102,7 @@ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None scale_mask=scale_mask, cameractrl_effect=cameractrl_effect, view_options=view_options, - mm_kwargs=mm_kwargs + mm_kwargs=mm_kwargs, ) # output @@ -118,8 +128,8 @@ def forward( attention_mask=None, video_length=None, scale_mask=None, - cameractrl_effect= 1.0, - mm_kwargs: dict[str]={}, + cameractrl_effect=1.0, + mm_kwargs: dict[str] = {}, ): if self.attention_mode != "Temporal": raise NotImplementedError @@ -142,9 +152,16 @@ def forward( if encoder_hidden_states is not None else encoder_hidden_states ) - if self.camera_feature_enabled and self.qkv_merge is not None and mm_kwargs is not None and "camera_feature" in mm_kwargs: + if ( + self.camera_feature_enabled + and self.qkv_merge is not None + and mm_kwargs is not None + and "camera_feature" in mm_kwargs + ): camera_feature: torch.Tensor = mm_kwargs["camera_feature"] - hidden_states = (self.qkv_merge(hidden_states + camera_feature) + hidden_states) * cameractrl_effect + hidden_states * (1. 
- cameractrl_effect) + hidden_states = ( + self.qkv_merge(hidden_states + camera_feature) + hidden_states + ) * cameractrl_effect + hidden_states * (1.0 - cameractrl_effect) # hidden_states = super().forward( # hidden_states, @@ -184,7 +201,7 @@ def forward( # @torch2oflow.register(TemporalTransformer3DModel_PT_CLS) # def _(mod, verbose=False): -# of_mod = torch2oflow.dispatch(torch_pt.nn.Module)(mod, verbose) +# of_mod = torch2oflow.dispatch(torch_pt.nn.Module)(mod, verbose) # of_mod.video_length = torch.tensor(mod.video_length) # return of_mod diff --git a/onediff_comfy_nodes/modules/oneflow/hijack_animatediff/sampling.py b/onediff_comfy_nodes/modules/oneflow/hijack_animatediff/sampling.py index ea201069b..190439976 100644 --- a/onediff_comfy_nodes/modules/oneflow/hijack_animatediff/sampling.py +++ b/onediff_comfy_nodes/modules/oneflow/hijack_animatediff/sampling.py @@ -1,5 +1,5 @@ # /ComfyUI/custom_nodes/ComfyUI-AnimateDiff-Evolved/animatediff/sampling.py -import oneflow as flow +import oneflow as flow # usort: skip from einops import rearrange from onediff.infer_compiler import DeployableModule from onediff.infer_compiler.backends.oneflow.transform import register @@ -118,9 +118,13 @@ def cond_func(orig_func, self, model, *args, **kwargs): animatediff_hijacker.register( - FunctionInjectionHolder.inject_functions, inject_functions, cond_func, + FunctionInjectionHolder.inject_functions, + inject_functions, + cond_func, ) animatediff_hijacker.register( - FunctionInjectionHolder.restore_functions, restore_functions, cond_func, + FunctionInjectionHolder.restore_functions, + restore_functions, + cond_func, ) diff --git a/onediff_comfy_nodes/modules/oneflow/hijack_comfyui_instantid/InstantID.py b/onediff_comfy_nodes/modules/oneflow/hijack_comfyui_instantid/InstantID.py index f19728494..05a72b509 100644 --- a/onediff_comfy_nodes/modules/oneflow/hijack_comfyui_instantid/InstantID.py +++ b/onediff_comfy_nodes/modules/oneflow/hijack_comfyui_instantid/InstantID.py @@ -1,9 +1,11 @@ import functools + from onediff.utils.log_utils import logger +from ..hijack_ipadapter_plus.set_model_patch_replace import set_model_patch_replace_v2 + from ..utils.booster_utils import is_using_oneflow_backend from ._config import comfyui_instantid_hijacker, comfyui_instantid_pt -from ..hijack_ipadapter_plus.set_model_patch_replace import set_model_patch_replace_v2 set_model_patch_replace_fn_pt = comfyui_instantid_pt.InstantID._set_model_patch_replace apply_instantid = comfyui_instantid_pt.InstantID.ApplyInstantID.apply_instantid diff --git a/onediff_comfy_nodes/modules/oneflow/hijack_comfyui_instantid/README.md b/onediff_comfy_nodes/modules/oneflow/hijack_comfyui_instantid/README.md index 6cc18c45e..0fe808800 100644 --- a/onediff_comfy_nodes/modules/oneflow/hijack_comfyui_instantid/README.md +++ b/onediff_comfy_nodes/modules/oneflow/hijack_comfyui_instantid/README.md @@ -14,7 +14,7 @@ git clone https://github.com/comfyanonymous/ComfyUI git reset --hard 2d4164271634476627aae31fbec251ca748a0ae0 ``` When you have completed these steps, follow the [instructions](https://github.com/comfyanonymous/ComfyUI) to install ComfyUI - + #### Install ComfyUI_InstantID ``` @@ -26,12 +26,12 @@ When you have completed these steps,follow the [instructions](https://github.com ### Quick Start -> Recommend running the official example of ComfyUI_InstantID now, and then trying OneDiff acceleration. +> Recommend running the official example of ComfyUI_InstantID now, and then trying OneDiff acceleration. 
> You can Load these images in ComfyUI to get the full workflow. Experiment (GeForce RTX 3090) Workflow for OneDiff Acceleration in ComfyUI_InstantID: -1. Replace the **`Load Checkpoint`** node with **`Load Checkpoint - OneDiff`** node. +1. Replace the **`Load Checkpoint`** node with **`Load Checkpoint - OneDiff`** node. 2. Add a **`Batch Size Patcher`** node before the **`Ksampler`** node (due to temporary lack of support for dynamic batch size). As follows: ![workflow (20)](https://github.com/siliconflow/onediff/assets/117806079/492a83a8-1a5b-4fb3-9e53-6d53e881a3f8) @@ -120,4 +120,4 @@ For users of OneDiff Community, please visit [GitHub Issues](https://github.com/ For users of OneDiff Enterprise, you can contact contact@siliconflow.com for commercial support. -Feel free to join our [Discord](https://discord.gg/RKJTjZMcPQ) community for discussions and to receive the latest updates. \ No newline at end of file +Feel free to join our [Discord](https://discord.gg/RKJTjZMcPQ) community for discussions and to receive the latest updates. diff --git a/onediff_comfy_nodes/modules/oneflow/hijack_ipadapter_plus/IPAdapterPlus.py b/onediff_comfy_nodes/modules/oneflow/hijack_ipadapter_plus/IPAdapterPlus.py index 49f7e87b0..85efb1d09 100644 --- a/onediff_comfy_nodes/modules/oneflow/hijack_ipadapter_plus/IPAdapterPlus.py +++ b/onediff_comfy_nodes/modules/oneflow/hijack_ipadapter_plus/IPAdapterPlus.py @@ -3,7 +3,9 @@ hijack ComfyUI/custom_nodes/ComfyUI_IPAdapter_plus/IPAdapterPlus.py""" import functools + from onediff.utils.log_utils import logger + from ..utils.booster_utils import is_using_oneflow_backend from ._config import ipadapter_plus_hijacker, ipadapter_plus_pt from .set_model_patch_replace import set_model_patch_replace_v2 diff --git a/onediff_comfy_nodes/modules/oneflow/hijack_ipadapter_plus/README.md b/onediff_comfy_nodes/modules/oneflow/hijack_ipadapter_plus/README.md index 0e7e8edf3..bb883d889 100644 --- a/onediff_comfy_nodes/modules/oneflow/hijack_ipadapter_plus/README.md +++ b/onediff_comfy_nodes/modules/oneflow/hijack_ipadapter_plus/README.md @@ -2,11 +2,11 @@ ### Quick Start -> Recommend running the official example of ComfyUI_IPAdapter_plus now, and then trying OneDiff acceleration. +> Recommend running the official example of ComfyUI_IPAdapter_plus now, and then trying OneDiff acceleration. Experiment (GeForce RTX 3090) Workflow for OneDiff Acceleration in ComfyUI_IPAdapter_plus: -1. Replace the **`Load Checkpoint`** node with **`Load Checkpoint - OneDiff`** node. +1. Replace the **`Load Checkpoint`** node with **`Load Checkpoint - OneDiff`** node. 2. Add a **`Batch Size Patcher`** node before the **`Ksampler`** node (due to temporary lack of support for dynamic batch size). As follows: ![workflow (19)](https://github.com/siliconflow/onediff/assets/117806079/07b153fd-a236-4c8d-a220-9b5823a79c17) @@ -78,4 +78,4 @@ For users of OneDiff Community, please visit [GitHub Issues](https://github.com/ For users of OneDiff Enterprise, you can contact contact@siliconflow.com for commercial support. -Feel free to join our [Discord](https://discord.gg/RKJTjZMcPQ) community for discussions and to receive the latest updates. \ No newline at end of file +Feel free to join our [Discord](https://discord.gg/RKJTjZMcPQ) community for discussions and to receive the latest updates. 
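
A note on the recurring structure in the hunks above and below: nearly every `hijack_*` module builds a `Hijacker`, registers an `(orig_func, sub_func, cond_func)` triple (e.g. `samplers_hijack.register(orig_func=calc_cond_batch, sub_func=calc_cond_batch_of, cond_func=cond_func)`), and activates it with `hijack()` or `hijack(last=False)`; the import-order reformatting here does not change that behavior. As a reading aid only, a minimal sketch of the dispatch idea — `make_hijack`, `greet`, and `loud` are invented names, not the actual `sd_hijack_utils.Hijacker`:

```python
from functools import wraps


def make_hijack(orig_func, sub_func, cond_func):
    """Condensed stand-in for Hijacker.register: run sub_func only when
    cond_func approves, always handing it the original for delegation."""

    @wraps(orig_func)
    def wrapper(*args, **kwargs):
        if cond_func(orig_func, *args, **kwargs):
            return sub_func(orig_func, *args, **kwargs)  # patched path
        return orig_func(*args, **kwargs)  # untouched fallback

    return wrapper


# Toy usage: only "world" satisfies cond_func and takes the patched path.
def greet(name):
    return f"hello {name}"


def loud(orig, name):
    return orig(name).upper()


greet = make_hijack(greet, loud, lambda orig, name: name == "world")
assert greet("world") == "HELLO WORLD"
assert greet("dev") == "hello dev"
```

This shape keeps the original callable reachable inside the substitute, which is why every `sub_func` and `cond_func` in these files takes `orig_func` (or `org_fn`) as its first parameter.
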
diff --git a/onediff_comfy_nodes/modules/oneflow/hijack_ipadapter_plus/set_model_patch_replace.py b/onediff_comfy_nodes/modules/oneflow/hijack_ipadapter_plus/set_model_patch_replace.py index 2310df837..305ec1549 100644 --- a/onediff_comfy_nodes/modules/oneflow/hijack_ipadapter_plus/set_model_patch_replace.py +++ b/onediff_comfy_nodes/modules/oneflow/hijack_ipadapter_plus/set_model_patch_replace.py @@ -1,10 +1,11 @@ import torch -from register_comfy.CrossAttentionPatch import Attn2Replace, ipadapter_attention from comfy import model_management from onediff.infer_compiler.backends.oneflow.transform import torch2oflow +from register_comfy.CrossAttentionPatch import Attn2Replace, ipadapter_attention + +from ..patch_management import create_patch_executor, PatchType from ..utils.booster_utils import clear_deployable_module_cache_and_unbind -from ..patch_management import PatchType, create_patch_executor def set_model_patch_replace_v2(org_fn, model, patch_kwargs, key): diff --git a/onediff_comfy_nodes/modules/oneflow/hijack_model_management.py b/onediff_comfy_nodes/modules/oneflow/hijack_model_management.py index 7cdc99c1b..0f7a282e8 100644 --- a/onediff_comfy_nodes/modules/oneflow/hijack_model_management.py +++ b/onediff_comfy_nodes/modules/oneflow/hijack_model_management.py @@ -1,5 +1,5 @@ # ComfyUI/comfy/hijack_model_management.py -import oneflow as flow +import oneflow as flow # usort: skip from comfy.model_management import soft_empty_cache from ..sd_hijack_utils import Hijacker diff --git a/onediff_comfy_nodes/modules/oneflow/hijack_model_patcher.py b/onediff_comfy_nodes/modules/oneflow/hijack_model_patcher.py index ae61b7a25..a718d304e 100644 --- a/onediff_comfy_nodes/modules/oneflow/hijack_model_patcher.py +++ b/onediff_comfy_nodes/modules/oneflow/hijack_model_patcher.py @@ -1,11 +1,11 @@ - from comfy.model_patcher import ModelPatcher from ..sd_hijack_utils import Hijacker -from .patch_management import PatchType, create_patch_executor +from .patch_management import create_patch_executor, PatchType from .utils.booster_utils import is_using_oneflow_backend -def clone_oneflow(org_fn, self, *args, **kwargs): + +def clone_oneflow(org_fn, self, *args, **kwargs): n = org_fn(self, *args, **kwargs) create_patch_executor(PatchType.UiNodeWithIndexPatch).copy_to(self, n) dc_patch_executor = create_patch_executor(PatchType.DCUNetExecutorPatch) @@ -13,10 +13,11 @@ def clone_oneflow(org_fn, self, *args, **kwargs): dc_patch_executor.copy_to(self, n) return n + def cond_func(org_fn, self): return is_using_oneflow_backend(self) - + + model_patch_hijacker = Hijacker() model_patch_hijacker.register(ModelPatcher.clone, clone_oneflow, cond_func) - diff --git a/onediff_comfy_nodes/modules/oneflow/hijack_pulid_comfyui/__init__.py b/onediff_comfy_nodes/modules/oneflow/hijack_pulid_comfyui/__init__.py index 66ee9c9b6..8da720289 100644 --- a/onediff_comfy_nodes/modules/oneflow/hijack_pulid_comfyui/__init__.py +++ b/onediff_comfy_nodes/modules/oneflow/hijack_pulid_comfyui/__init__.py @@ -1,4 +1,4 @@ -from ._config import pulid_comfyui_hijacker, is_load_pulid_comfyui_pkg +from ._config import is_load_pulid_comfyui_pkg, pulid_comfyui_hijacker if is_load_pulid_comfyui_pkg: from .pulid import * diff --git a/onediff_comfy_nodes/modules/oneflow/hijack_pulid_comfyui/pulid.py b/onediff_comfy_nodes/modules/oneflow/hijack_pulid_comfyui/pulid.py index 6d03534a4..b14404d0f 100644 --- a/onediff_comfy_nodes/modules/oneflow/hijack_pulid_comfyui/pulid.py +++ b/onediff_comfy_nodes/modules/oneflow/hijack_pulid_comfyui/pulid.py @@ 
-1,14 +1,15 @@ -import torch import comfy -import torchvision.transforms as T import comfy.utils +import torch +import torchvision.transforms as T from facexlib.parsing import init_parsing_model from facexlib.utils.face_restoration_helper import FaceRestoreHelper -from ..utils.booster_utils import is_using_oneflow_backend -from ._config import pulid_comfyui_pt, pulid_comfyui_hijacker -from ..hijack_ipadapter_plus.set_model_patch_replace import apply_patch from register_comfy.CrossAttentionPatch import pulid_attention +from ..hijack_ipadapter_plus.set_model_patch_replace import apply_patch +from ..utils.booster_utils import is_using_oneflow_backend +from ._config import pulid_comfyui_hijacker, pulid_comfyui_pt + pulid_pkg = pulid_comfyui_pt.pulid PulidModel = pulid_pkg.PulidModel tensor_to_image = pulid_pkg.tensor_to_image diff --git a/onediff_comfy_nodes/modules/oneflow/hijack_samplers.py b/onediff_comfy_nodes/modules/oneflow/hijack_samplers.py index cdd2eb57a..14fc440b0 100644 --- a/onediff_comfy_nodes/modules/oneflow/hijack_samplers.py +++ b/onediff_comfy_nodes/modules/oneflow/hijack_samplers.py @@ -7,7 +7,7 @@ from comfy.samplers import calc_cond_batch, can_concat_cond, cond_cat, get_area_and_mult from ..sd_hijack_utils import Hijacker -from .patch_management import PatchType, create_patch_executor +from .patch_management import create_patch_executor, PatchType from .utils.booster_utils import is_using_oneflow_backend @@ -149,5 +149,7 @@ def cond_func(orig_func, model, *args, **kwargs): samplers_hijack = Hijacker() samplers_hijack.register( - orig_func=calc_cond_batch, sub_func=calc_cond_batch_of, cond_func=cond_func, + orig_func=calc_cond_batch, + sub_func=calc_cond_batch_of, + cond_func=cond_func, ) diff --git a/onediff_comfy_nodes/modules/oneflow/hijack_utils.py b/onediff_comfy_nodes/modules/oneflow/hijack_utils.py index b2e8034e5..d8af4c2d4 100644 --- a/onediff_comfy_nodes/modules/oneflow/hijack_utils.py +++ b/onediff_comfy_nodes/modules/oneflow/hijack_utils.py @@ -4,6 +4,7 @@ from onediff.infer_compiler.backends.oneflow.param_utils import ( update_graph_related_tensor, ) + from ..sd_hijack_utils import Hijacker diff --git a/onediff_comfy_nodes/modules/oneflow/infer_compiler_registry/register_comfy/CrossAttentionPatch.py b/onediff_comfy_nodes/modules/oneflow/infer_compiler_registry/register_comfy/CrossAttentionPatch.py index 1cf397050..48bad18cc 100644 --- a/onediff_comfy_nodes/modules/oneflow/infer_compiler_registry/register_comfy/CrossAttentionPatch.py +++ b/onediff_comfy_nodes/modules/oneflow/infer_compiler_registry/register_comfy/CrossAttentionPatch.py @@ -3,11 +3,13 @@ https://github.com/cubiq/ComfyUI_IPAdapter_plus/blob/main/CrossAttentionPatch.py """ -import torch import math + +import torch import torch.nn.functional as F from comfy.ldm.modules.attention import attention_pytorch as optimized_attention + def tensor_to_size(source, dest_size): if source.dim() == 0: print("x is a scalar (no dimensions)") @@ -17,29 +19,28 @@ def tensor_to_size(source, dest_size): source_size = source.shape[0] if source_size < dest_size: - shape = [dest_size - source_size] + [1]*(source.dim()-1) + shape = [dest_size - source_size] + [1] * (source.dim() - 1) source = torch.cat((source, source[-1:].repeat(shape)), dim=0) elif source_size > dest_size: source = source[:dest_size] return source -class Attn2Replace: +class Attn2Replace: def __init__(self, callback=None, **kwargs): self.callback = [callback] self.kwargs = [kwargs] self.forward_patch_key = id(self) self._use_crossAttention_patch = True - 
self.cache_map = {} # {ui_index, index} + self.cache_map = {} # {ui_index, index} self._bind_model = None self.optimized_attention = optimized_attention - def get_bind_model(self): return self._bind_model - - def add(self, callback, **kwargs): + + def add(self, callback, **kwargs): self.callback.append(callback) self.kwargs.append(kwargs) @@ -49,52 +50,87 @@ def add(self, callback, **kwargs): def __call__(self, q, k, v, extra_options): dtype = q.dtype out = self.optimized_attention(q, k, v, extra_options["n_heads"]) - sigma = extra_options["sigmas"] if 'sigmas' in extra_options else 999999999.9 + sigma = extra_options["sigmas"] if "sigmas" in extra_options else 999999999.9 patch_kwargs = extra_options["_attn2"].get(self.forward_patch_key) for i, callback in enumerate(self.callback): - if sigma <= self.kwargs[i]["sigma_start"] and sigma >= self.kwargs[i]["sigma_end"]: - out = out + callback(out, q, k, v, extra_options, optimized_attention=self.optimized_attention, **self.kwargs[i], **patch_kwargs[i]) - + if ( + sigma <= self.kwargs[i]["sigma_start"] + and sigma >= self.kwargs[i]["sigma_end"] + ): + out = out + callback( + out, + q, + k, + v, + extra_options, + optimized_attention=self.optimized_attention, + **self.kwargs[i], + **patch_kwargs[i] + ) + return out.to(dtype=dtype) - + def __deepcopy__(self, memo): # print("Warning: CrossAttentionPatch is not deepcopiable.", '-'*20) return self -def ipadapter_attention(out, q, k, v, extra_options, module_key='', ipadapter=None, weight=1.0, cond=None, cond_alt=None, uncond=None, weight_type="linear", mask=None, sigma_start=0.0, sigma_end=1.0, unfold_batch=False, embeds_scaling='V only', optimized_attention=None, **kwargs): + +def ipadapter_attention( + out, + q, + k, + v, + extra_options, + module_key="", + ipadapter=None, + weight=1.0, + cond=None, + cond_alt=None, + uncond=None, + weight_type="linear", + mask=None, + sigma_start=0.0, + sigma_end=1.0, + unfold_batch=False, + embeds_scaling="V only", + optimized_attention=None, + **kwargs +): dtype = q.dtype cond_or_uncond = extra_options["cond_or_uncond"] block_type = extra_options["block"][0] - #block_id = extra_options["block"][1] + # block_id = extra_options["block"][1] t_idx = extra_options["transformer_index"] - layers = 11 if '101_to_k_ip' in ipadapter.ip_layers.to_kvs else 16 + layers = 11 if "101_to_k_ip" in ipadapter.ip_layers.to_kvs else 16 k_key = module_key + "_to_k_ip" v_key = module_key + "_to_v_ip" # extra options for AnimateDiff - ad_params = extra_options['ad_params'] if "ad_params" in extra_options else None + ad_params = extra_options["ad_params"] if "ad_params" in extra_options else None b = q.shape[0] seq_len = q.shape[1] batch_prompt = b // len(cond_or_uncond) _, _, oh, ow = extra_options["original_shape"] - if weight_type == 'ease in': + if weight_type == "ease in": weight = weight * (0.05 + 0.95 * (1 - t_idx / layers)) - elif weight_type == 'ease out': + elif weight_type == "ease out": weight = weight * (0.05 + 0.95 * (t_idx / layers)) - elif weight_type == 'ease in-out': - weight = weight * (0.05 + 0.95 * (1 - abs(t_idx - (layers/2)) / (layers/2))) - elif weight_type == 'reverse in-out': - weight = weight * (0.05 + 0.95 * (abs(t_idx - (layers/2)) / (layers/2))) - elif weight_type == 'weak input' and block_type == 'input': + elif weight_type == "ease in-out": + weight = weight * (0.05 + 0.95 * (1 - abs(t_idx - (layers / 2)) / (layers / 2))) + elif weight_type == "reverse in-out": + weight = weight * (0.05 + 0.95 * (abs(t_idx - (layers / 2)) / (layers / 2))) + elif 
weight_type == "weak input" and block_type == "input": weight = weight * 0.2 - elif weight_type == 'weak middle' and block_type == 'middle': + elif weight_type == "weak middle" and block_type == "middle": weight = weight * 0.2 - elif weight_type == 'weak output' and block_type == 'output': + elif weight_type == "weak output" and block_type == "output": weight = weight * 0.2 - elif weight_type == 'strong middle' and (block_type == 'input' or block_type == 'output'): + elif weight_type == "strong middle" and ( + block_type == "input" or block_type == "output" + ): weight = weight * 0.2 elif isinstance(weight, dict): if t_idx not in weight: @@ -130,7 +166,9 @@ def ipadapter_attention(out, q, k, v, extra_options, module_key='', ipadapter=No weight = torch.Tensor(weight[ad_params["sub_idxs"]]) # if torch.all(weight == 0): # return 0 - weight = weight.repeat(len(cond_or_uncond), 1, 1) # repeat for cond and uncond + weight = weight.repeat( + len(cond_or_uncond), 1, 1 + ) # repeat for cond and uncond # elif weight == 0: # return 0 @@ -149,7 +187,9 @@ def ipadapter_attention(out, q, k, v, extra_options, module_key='', ipadapter=No weight = tensor_to_size(weight, batch_prompt) # if torch.all(weight == 0): # return 0 - weight = weight.repeat(len(cond_or_uncond), 1, 1) # repeat for cond and uncond + weight = weight.repeat( + len(cond_or_uncond), 1, 1 + ) # repeat for cond and uncond # elif weight == 0: # return 0 @@ -166,7 +206,9 @@ def ipadapter_attention(out, q, k, v, extra_options, module_key='', ipadapter=No weight = tensor_to_size(weight, batch_prompt) # if torch.all(weight == 0): # return 0 - weight = weight.repeat(len(cond_or_uncond), 1, 1) # repeat for cond and uncond + weight = weight.repeat( + len(cond_or_uncond), 1, 1 + ) # repeat for cond and uncond # elif weight == 0: # return 0 @@ -175,14 +217,14 @@ def ipadapter_attention(out, q, k, v, extra_options, module_key='', ipadapter=No v_cond = ipadapter.ip_layers.to_kvs[v_key](cond).repeat(batch_prompt, 1, 1) v_uncond = ipadapter.ip_layers.to_kvs[v_key](uncond).repeat(batch_prompt, 1, 1) - if len(cond_or_uncond) == 3: # TODO: conxl, I need to check this + if len(cond_or_uncond) == 3: # TODO: conxl, I need to check this ip_k = torch.cat([(k_cond, k_uncond, k_cond)[i] for i in cond_or_uncond], dim=0) ip_v = torch.cat([(v_cond, v_uncond, v_cond)[i] for i in cond_or_uncond], dim=0) else: ip_k = torch.cat([(k_cond, k_uncond)[i] for i in cond_or_uncond], dim=0) ip_v = torch.cat([(v_cond, v_uncond)[i] for i in cond_or_uncond], dim=0) - if embeds_scaling == 'K+mean(V) w/ C penalty': + if embeds_scaling == "K+mean(V) w/ C penalty": scaling = float(ip_k.shape[2]) / 1280.0 weight = weight * scaling ip_k = ip_k * weight @@ -190,20 +232,20 @@ def ipadapter_attention(out, q, k, v, extra_options, module_key='', ipadapter=No ip_v = (ip_v - ip_v_mean) + ip_v_mean * weight out_ip = optimized_attention(q, ip_k, ip_v, extra_options["n_heads"]) del ip_v_mean - elif embeds_scaling == 'K+V w/ C penalty': + elif embeds_scaling == "K+V w/ C penalty": scaling = float(ip_k.shape[2]) / 1280.0 weight = weight * scaling ip_k = ip_k * weight ip_v = ip_v * weight out_ip = optimized_attention(q, ip_k, ip_v, extra_options["n_heads"]) - elif embeds_scaling == 'K+V': + elif embeds_scaling == "K+V": ip_k = ip_k * weight ip_v = ip_v * weight out_ip = optimized_attention(q, ip_k, ip_v, extra_options["n_heads"]) else: - #ip_v = ip_v * weight + # ip_v = ip_v * weight out_ip = optimized_attention(q, ip_k, ip_v, extra_options["n_heads"]) - out_ip = out_ip * weight # I'm doing this to 
get the same results as before + out_ip = out_ip * weight # I'm doing this to get the same results as before if mask is not None: mask_h = oh / math.sqrt(oh * ow / seq_len) @@ -211,17 +253,27 @@ def ipadapter_attention(out, q, k, v, extra_options, module_key='', ipadapter=No mask_w = seq_len // mask_h # check if using AnimateDiff and sliding context window - if (mask.shape[0] > 1 and ad_params is not None and ad_params["sub_idxs"] is not None): + if ( + mask.shape[0] > 1 + and ad_params is not None + and ad_params["sub_idxs"] is not None + ): # if mask length matches or exceeds full_length, get sub_idx masks if mask.shape[0] >= ad_params["full_length"]: mask = torch.Tensor(mask[ad_params["sub_idxs"]]) - mask = F.interpolate(mask.unsqueeze(1), size=(mask_h, mask_w), mode="bilinear").squeeze(1) + mask = F.interpolate( + mask.unsqueeze(1), size=(mask_h, mask_w), mode="bilinear" + ).squeeze(1) else: - mask = F.interpolate(mask.unsqueeze(1), size=(mask_h, mask_w), mode="bilinear").squeeze(1) + mask = F.interpolate( + mask.unsqueeze(1), size=(mask_h, mask_w), mode="bilinear" + ).squeeze(1) mask = tensor_to_size(mask, ad_params["full_length"]) mask = mask[ad_params["sub_idxs"]] else: - mask = F.interpolate(mask.unsqueeze(1), size=(mask_h, mask_w), mode="bilinear").squeeze(1) + mask = F.interpolate( + mask.unsqueeze(1), size=(mask_h, mask_w), mode="bilinear" + ).squeeze(1) mask = tensor_to_size(mask, batch_prompt) mask = mask.repeat(len(cond_or_uncond), 1, 1) @@ -236,11 +288,11 @@ def ipadapter_attention(out, q, k, v, extra_options, module_key='', ipadapter=No mask = F.pad(mask, (0, 0, pad1, pad2), value=0.0) elif mask_len > seq_len: crop_start = (mask_len - seq_len) // 2 - mask = mask[:, crop_start:crop_start+seq_len, :] + mask = mask[:, crop_start : crop_start + seq_len, :] out_ip = out_ip * mask - #out = out + out_ip + # out = out + out_ip return out_ip.to(dtype=dtype) @@ -249,7 +301,23 @@ def is_crossAttention_patch(module) -> bool: return getattr(module, "_use_crossAttention_patch", False) -def pulid_attention(out, q, k, v, extra_options, module_key='', pulid=None, cond=None, uncond=None, weight=1.0, ortho=False, ortho_v2=False, mask=None, optimized_attention=None, **kwargs): +def pulid_attention( + out, + q, + k, + v, + extra_options, + module_key="", + pulid=None, + cond=None, + uncond=None, + weight=1.0, + ortho=False, + ortho_v2=False, + mask=None, + optimized_attention=None, + **kwargs +): k_key = module_key + "_to_k_ip" v_key = module_key + "_to_v_ip" @@ -260,11 +328,11 @@ def pulid_attention(out, q, k, v, extra_options, module_key='', pulid=None, cond batch_prompt = b // len(cond_or_uncond) _, _, oh, ow = extra_options["original_shape"] - #conds = torch.cat([uncond.repeat(batch_prompt, 1, 1), cond.repeat(batch_prompt, 1, 1)], dim=0) - #zero_tensor = torch.zeros((conds.size(0), num_zero, conds.size(-1)), dtype=conds.dtype, device=conds.device) - #conds = torch.cat([conds, zero_tensor], dim=1) - #ip_k = pulid.ip_layers.to_kvs[k_key](conds) - #ip_v = pulid.ip_layers.to_kvs[v_key](conds) + # conds = torch.cat([uncond.repeat(batch_prompt, 1, 1), cond.repeat(batch_prompt, 1, 1)], dim=0) + # zero_tensor = torch.zeros((conds.size(0), num_zero, conds.size(-1)), dtype=conds.dtype, device=conds.device) + # conds = torch.cat([conds, zero_tensor], dim=1) + # ip_k = pulid.ip_layers.to_kvs[k_key](conds) + # ip_v = pulid.ip_layers.to_kvs[v_key](conds) k_cond = pulid.ip_layers.to_kvs[k_key](cond).repeat(batch_prompt, 1, 1) k_uncond = pulid.ip_layers.to_kvs[k_key](uncond).repeat(batch_prompt, 1, 1) @@ 
-278,7 +346,11 @@ def pulid_attention(out, q, k, v, extra_options, module_key='', pulid=None, cond if ortho: out = out.to(dtype=torch.float32) out_ip = out_ip.to(dtype=torch.float32) - projection = (torch.sum((out * out_ip), dim=-2, keepdim=True) / torch.sum((out * out), dim=-2, keepdim=True) * out) + projection = ( + torch.sum((out * out_ip), dim=-2, keepdim=True) + / torch.sum((out * out), dim=-2, keepdim=True) + * out + ) orthogonal = out_ip - projection out_ip = weight * orthogonal elif ortho_v2: @@ -287,7 +359,11 @@ def pulid_attention(out, q, k, v, extra_options, module_key='', pulid=None, cond attn_map = q @ ip_k.transpose(-2, -1) attn_mean = attn_map.softmax(dim=-1).mean(dim=1, keepdim=True) attn_mean = attn_mean[:, :, :5].sum(dim=-1, keepdim=True) - projection = (torch.sum((out * out_ip), dim=-2, keepdim=True) / torch.sum((out * out), dim=-2, keepdim=True) * out) + projection = ( + torch.sum((out * out_ip), dim=-2, keepdim=True) + / torch.sum((out * out), dim=-2, keepdim=True) + * out + ) orthogonal = out_ip + (attn_mean - 1) * projection out_ip = weight * orthogonal else: @@ -298,7 +374,9 @@ def pulid_attention(out, q, k, v, extra_options, module_key='', pulid=None, cond mask_h = int(mask_h) + int((seq_len % int(mask_h)) != 0) mask_w = seq_len // mask_h - mask = F.interpolate(mask.unsqueeze(1), size=(mask_h, mask_w), mode="bilinear").squeeze(1) + mask = F.interpolate( + mask.unsqueeze(1), size=(mask_h, mask_w), mode="bilinear" + ).squeeze(1) mask = tensor_to_size(mask, batch_prompt) mask = mask.repeat(len(cond_or_uncond), 1, 1) @@ -313,7 +391,7 @@ def pulid_attention(out, q, k, v, extra_options, module_key='', pulid=None, cond mask = F.pad(mask, (0, 0, pad1, pad2), value=0.0) elif mask_len > seq_len: crop_start = (mask_len - seq_len) // 2 - mask = mask[:, crop_start:crop_start+seq_len, :] + mask = mask[:, crop_start : crop_start + seq_len, :] out_ip = out_ip * mask diff --git a/onediff_comfy_nodes/modules/oneflow/infer_compiler_registry/register_comfy/__init__.py b/onediff_comfy_nodes/modules/oneflow/infer_compiler_registry/register_comfy/__init__.py index 32b668121..e55ea1488 100644 --- a/onediff_comfy_nodes/modules/oneflow/infer_compiler_registry/register_comfy/__init__.py +++ b/onediff_comfy_nodes/modules/oneflow/infer_compiler_registry/register_comfy/__init__.py @@ -1,18 +1,23 @@ from pathlib import Path import comfy +from comfy.ldm.modules.attention import attention_pytorch from comfy.ldm.modules.diffusionmodules.model import AttnBlock -from comfy.ldm.modules.attention import attention_pytorch + from .attention import attention_pytorch_oneflow from nodes import * # must imported before import comfy from onediff.infer_compiler.backends.oneflow.transform import register -from onediff.infer_compiler.backends.oneflow.utils.version_util import is_community_version +from onediff.infer_compiler.backends.oneflow.utils.version_util import ( + is_community_version, +) -from .attention import CrossAttention as CrossAttention1f -from .attention import SpatialTransformer as SpatialTransformer1f -from .attention import SpatialVideoTransformer as SpatialVideoTransformer1f +from .attention import ( + CrossAttention as CrossAttention1f, + SpatialTransformer as SpatialTransformer1f, + SpatialVideoTransformer as SpatialVideoTransformer1f, +) from .deep_cache_unet import DeepCacheUNet, FastDeepCacheUNet from .linear import Linear as Linear1f from .util import AlphaBlender as AlphaBlender1f @@ -36,9 +41,11 @@ AttnBlock: AttnBlock1f, } -from .openaimodel import UNetModel as UNetModel1f -from 
.openaimodel import Upsample as Upsample1f -from .openaimodel import VideoResBlock as VideoResBlock1f +from .openaimodel import ( + UNetModel as UNetModel1f, + Upsample as Upsample1f, + VideoResBlock as VideoResBlock1f, +) torch2of_class_map.update( { diff --git a/onediff_comfy_nodes/modules/oneflow/infer_compiler_registry/register_comfy/attention.py b/onediff_comfy_nodes/modules/oneflow/infer_compiler_registry/register_comfy/attention.py index 27bf9165a..895ad4540 100644 --- a/onediff_comfy_nodes/modules/oneflow/infer_compiler_registry/register_comfy/attention.py +++ b/onediff_comfy_nodes/modules/oneflow/infer_compiler_registry/register_comfy/attention.py @@ -392,7 +392,9 @@ def forward( ): transformer_options["block_index"] = it_ x = block( - x, context=spatial_context, transformer_options=transformer_options, + x, + context=spatial_context, + transformer_options=transformer_options, ) x_mix = x @@ -425,7 +427,6 @@ def forward( return out - def attention_pytorch_oneflow(q, k, v, heads, mask=None, attn_precision=None): b, _, dim_head = q.shape dim_head //= heads @@ -442,4 +443,3 @@ def attention_pytorch_oneflow(q, k, v, heads, mask=None, attn_precision=None): causal=False, ) return out - diff --git a/onediff_comfy_nodes/modules/oneflow/infer_compiler_registry/register_comfy/deep_cache_unet.py b/onediff_comfy_nodes/modules/oneflow/infer_compiler_registry/register_comfy/deep_cache_unet.py index d4c78b865..518361580 100644 --- a/onediff_comfy_nodes/modules/oneflow/infer_compiler_registry/register_comfy/deep_cache_unet.py +++ b/onediff_comfy_nodes/modules/oneflow/infer_compiler_registry/register_comfy/deep_cache_unet.py @@ -1,6 +1,8 @@ import torch from comfy.ldm.modules.diffusionmodules.openaimodel import ( - apply_control, forward_timestep_embed) + apply_control, + forward_timestep_embed, +) from comfy.ldm.modules.diffusionmodules.util import timestep_embedding from torch.nn import Module @@ -41,7 +43,7 @@ def forward( num_video_frames = c_dict.get( "num_video_frames", self.unet_module.default_num_video_frames ) - + default_image_only_indicator = getattr( self.unet_module, "default_image_only_indicator", None ) diff --git a/onediff_comfy_nodes/modules/oneflow/infer_compiler_registry/register_comfy/util.py b/onediff_comfy_nodes/modules/oneflow/infer_compiler_registry/register_comfy/util.py index 345fbaa24..cd80ca1e8 100644 --- a/onediff_comfy_nodes/modules/oneflow/infer_compiler_registry/register_comfy/util.py +++ b/onediff_comfy_nodes/modules/oneflow/infer_compiler_registry/register_comfy/util.py @@ -4,6 +4,7 @@ import oneflow.nn.functional as F from einops import rearrange + class AlphaBlender(nn.Module): strategies = ["learned", "fixed", "learned_with_images"] @@ -35,7 +36,9 @@ def get_alpha(self, image_only_indicator: torch.Tensor, device) -> torch.Tensor: alpha = torch.where( image_only_indicator.bool(), torch.ones(1, 1, device=image_only_indicator.device), - torch.sigmoid(self.mix_factor.to(image_only_indicator.device)).unsqueeze(-1), + torch.sigmoid( + self.mix_factor.to(image_only_indicator.device) + ).unsqueeze(-1), ) # alpha = rearrange(alpha, self.rearrange_pattern) # Rewrite for onediff SVD dynamic shape, only VideoResBlock, rearrange_pattern="b t -> b 1 t 1 1", @@ -52,7 +55,10 @@ def get_alpha(self, image_only_indicator: torch.Tensor, device) -> torch.Tensor: return alpha def forward( - self, x_spatial, x_temporal, image_only_indicator=None, + self, + x_spatial, + x_temporal, + image_only_indicator=None, ) -> torch.Tensor: alpha = self.get_alpha(image_only_indicator, 
x_spatial.device) x = ( diff --git a/onediff_comfy_nodes/modules/oneflow/infer_compiler_registry/register_comfy/vae_patch.py b/onediff_comfy_nodes/modules/oneflow/infer_compiler_registry/register_comfy/vae_patch.py index 14f1a26d8..0fd72df29 100644 --- a/onediff_comfy_nodes/modules/oneflow/infer_compiler_registry/register_comfy/vae_patch.py +++ b/onediff_comfy_nodes/modules/oneflow/infer_compiler_registry/register_comfy/vae_patch.py @@ -19,7 +19,8 @@ def forward(self, x): B, C, _, _ = x.shape # compute attention q, k, v = map( - lambda t: t.reshape(B, 1, C, -1).transpose(2, 3).contiguous(), (q, k, v), + lambda t: t.reshape(B, 1, C, -1).transpose(2, 3).contiguous(), + (q, k, v), ) _, _, _, head_dim = q.shape diff --git a/onediff_comfy_nodes/modules/oneflow/infer_compiler_registry/register_onediff_quant.py b/onediff_comfy_nodes/modules/oneflow/infer_compiler_registry/register_onediff_quant.py index 48f0de2a6..dc3d0ff7f 100644 --- a/onediff_comfy_nodes/modules/oneflow/infer_compiler_registry/register_onediff_quant.py +++ b/onediff_comfy_nodes/modules/oneflow/infer_compiler_registry/register_onediff_quant.py @@ -1,5 +1,5 @@ import onediff_quant -import oneflow as flow +import oneflow as flow # usort: skip from onediff.infer_compiler.backends.oneflow.transform import register torch2oflow_class_map = { diff --git a/onediff_comfy_nodes/modules/oneflow/onediff_controlnet.py b/onediff_comfy_nodes/modules/oneflow/onediff_controlnet.py index 583f035b1..25bfb4b36 100644 --- a/onediff_comfy_nodes/modules/oneflow/onediff_controlnet.py +++ b/onediff_comfy_nodes/modules/oneflow/onediff_controlnet.py @@ -1,5 +1,5 @@ import comfy -import oneflow as flow +import oneflow as flow # usort: skip import torch from comfy.controlnet import ControlLora, ControlLoraOps, ControlNet from onediff.infer_compiler import oneflow_compile @@ -35,7 +35,7 @@ def _set_attr_of(obj, attr, value): class OneDiffControlLora(ControlLora): @classmethod def from_controllora( - cls, controlnet: ControlLora, *, gen_compile_options: callable = None + cls, controlnet: ControlLora, *, gen_compile_options: callable = None ): c = cls( controlnet.control_weights, @@ -83,13 +83,12 @@ class control_lora_ops(ControlLoraOps, comfy.ops.manual_cast): if self.gen_compile_options is not None else {} ) - self._oneflow_model = oneflow_compile( - self.control_model - ) + self._oneflow_model = oneflow_compile(self.control_model) compiled_options = self._oneflow_model._deployable_module_options compiled_options.graph_file = file_device_dict.get("graph_file", None) - compiled_options.graph_file_device = file_device_dict.get("graph_file_device", None) - + compiled_options.graph_file_device = file_device_dict.get( + "graph_file_device", None + ) self.control_model = self._oneflow_model diff --git a/onediff_comfy_nodes/modules/oneflow/patch_management/patch_executor.py b/onediff_comfy_nodes/modules/oneflow/patch_management/patch_executor.py index 2e10b43ab..d02336f94 100644 --- a/onediff_comfy_nodes/modules/oneflow/patch_management/patch_executor.py +++ b/onediff_comfy_nodes/modules/oneflow/patch_management/patch_executor.py @@ -1,12 +1,12 @@ from abc import ABC, abstractmethod from typing import Dict, List -from comfy.model_patcher import ModelPatcher from comfy.model_base import BaseModel +from comfy.model_patcher import ModelPatcher -class PatchExecutorBase(ABC): +class PatchExecutorBase(ABC): @abstractmethod def check_patch(self): pass @@ -19,22 +19,23 @@ def set_patch(self): def get_patch(self): pass + class UiNodeWithIndexPatch(PatchExecutorBase): 
DEFAULT_VALUE = -1 INCREMENT_VALUE = 1 - + def __init__(self) -> None: self.patch_name = type(self).__name__ - - def check_patch(self, module: ModelPatcher)->bool: + + def check_patch(self, module: ModelPatcher) -> bool: return hasattr(module, self.patch_name) - + def set_patch(self, module: ModelPatcher, value: int): setattr(module, self.patch_name, value) - - def get_patch(self, module: ModelPatcher)->int: - return getattr(module, self.patch_name, self.DEFAULT_VALUE) - + + def get_patch(self, module: ModelPatcher) -> int: + return getattr(module, self.patch_name, self.DEFAULT_VALUE) + def copy_to(self, old_model: ModelPatcher, new_model: ModelPatcher): value = self.get_patch(old_model) self.set_patch(new_model, value + self.INCREMENT_VALUE) @@ -46,7 +47,7 @@ def __init__(self) -> None: def check_patch(self, module): return hasattr(module, self.patch_name) - + def set_patch(self, module, value: dict): setattr(module, self.patch_name, value) @@ -54,14 +55,13 @@ def get_patch(self, module) -> Dict[str, any]: if not self.check_patch(module): self.set_patch(module, {}) return getattr(module, self.patch_name) - + def clear_patch(self, module): if self.check_patch(module): self.get_patch(module).clear() class CrossAttentionForwardMasksPatch(PatchExecutorBase): - def __init__(self) -> None: """Will be abandoned""" self.patch_name = "forward_masks" @@ -71,7 +71,7 @@ def check_patch(self, module): def set_patch(self, module, value): raise NotImplementedError() - + def get_patch(self, module) -> Dict: if not self.check_patch(module): setattr(module, self.patch_name, {}) @@ -103,15 +103,13 @@ def copy_to(self, old_model: ModelPatcher, new_model: ModelPatcher): self.set_patch(new_model, values) new_model.model.use_deep_cache_unet = True - def is_use_deep_cache_unet(self, module: BaseModel): return getattr(module, "use_deep_cache_unet", False) + class UNetExtraInputOptions(PatchExecutorBase): def __init__(self) -> None: - """UNetExtraInputOptions - - """ + """UNetExtraInputOptions""" super().__init__() self.patch_name = type(self).__name__ @@ -136,4 +134,4 @@ def get_patch(self, module) -> Dict: def clear_patch(self, module): if self.check_patch(module): - self.get_patch(module).clear() \ No newline at end of file + self.get_patch(module).clear() diff --git a/onediff_comfy_nodes/modules/oneflow/patch_management/patch_factory.py b/onediff_comfy_nodes/modules/oneflow/patch_management/patch_factory.py index 59339a7e7..fbf62a489 100644 --- a/onediff_comfy_nodes/modules/oneflow/patch_management/patch_factory.py +++ b/onediff_comfy_nodes/modules/oneflow/patch_management/patch_factory.py @@ -1,8 +1,15 @@ from enum import Enum -from .patch_executor import CachedCrossAttentionPatch, DeepCacheUNetExecutorPatch, UiNodeWithIndexPatch -from .patch_executor import CrossAttentionForwardMasksPatch, UNetExtraInputOptions + +from .patch_executor import ( + CachedCrossAttentionPatch, + CrossAttentionForwardMasksPatch, + DeepCacheUNetExecutorPatch, + UiNodeWithIndexPatch, + UNetExtraInputOptions, +) from .quantized_input_patch import QuantizedInputPatch + class PatchType(Enum): CachedCrossAttentionPatch = CachedCrossAttentionPatch UNetExtraInputOptions = UNetExtraInputOptions diff --git a/onediff_comfy_nodes/modules/oneflow/patch_management/patch_for_oneflow.py b/onediff_comfy_nodes/modules/oneflow/patch_management/patch_for_oneflow.py index 304d905f7..7d3c50812 100644 --- a/onediff_comfy_nodes/modules/oneflow/patch_management/patch_for_oneflow.py +++ 
b/onediff_comfy_nodes/modules/oneflow/patch_management/patch_for_oneflow.py @@ -2,7 +2,7 @@ fix: TypeError: can only concatenate str (not "tuple") to str TODO: fix in oneflow """ -import oneflow as flow +import oneflow as flow # usort: skip from oneflow.framework.args_tree import NamedArg diff --git a/onediff_comfy_nodes/modules/oneflow/patch_management/quantized_input_patch.py b/onediff_comfy_nodes/modules/oneflow/patch_management/quantized_input_patch.py index 80a242f2d..98091f2ff 100644 --- a/onediff_comfy_nodes/modules/oneflow/patch_management/quantized_input_patch.py +++ b/onediff_comfy_nodes/modules/oneflow/patch_management/quantized_input_patch.py @@ -1,6 +1,5 @@ -from register_comfy.CrossAttentionPatch import is_crossAttention_patch - from onediff.infer_compiler.backends.oneflow import online_quantization_utils +from register_comfy.CrossAttentionPatch import is_crossAttention_patch from .patch_executor import PatchExecutorBase diff --git a/onediff_comfy_nodes/modules/oneflow/utils/__init__.py b/onediff_comfy_nodes/modules/oneflow/utils/__init__.py index 528694ab9..ee5052a37 100644 --- a/onediff_comfy_nodes/modules/oneflow/utils/__init__.py +++ b/onediff_comfy_nodes/modules/oneflow/utils/__init__.py @@ -45,7 +45,9 @@ def save_graph(deploy_module, prefix: str, device: str, subfolder: str): module_class_name = match.group(1) graph_filename = os.path.join( - OUTPUT_FOLDER, subfolder, f"{prefix}-{device}-{module_class_name}.graph", + OUTPUT_FOLDER, + subfolder, + f"{prefix}-{device}-{module_class_name}.graph", ) if isinstance(deploy_module, DeployableModule): diff --git a/onediff_comfy_nodes/modules/oneflow/utils/booster_utils.py b/onediff_comfy_nodes/modules/oneflow/utils/booster_utils.py index a70246405..3b9f43a56 100644 --- a/onediff_comfy_nodes/modules/oneflow/utils/booster_utils.py +++ b/onediff_comfy_nodes/modules/oneflow/utils/booster_utils.py @@ -5,10 +5,12 @@ from comfy.model_base import BaseModel, SVD_img2vid from comfy.model_patcher import ModelPatcher -from onediff.infer_compiler.backends.oneflow import OneflowDeployableModule as DeployableModule +from onediff.infer_compiler.backends.oneflow import ( + OneflowDeployableModule as DeployableModule, +) from onediff.utils import set_boolean_env_var -from ..patch_management import PatchType, create_patch_executor +from ..patch_management import create_patch_executor, PatchType def set_compiled_options(module: DeployableModule, graph_file="unet"): @@ -101,9 +103,7 @@ def clear_deployable_module_cache_and_unbind( create_patch_executor(PatchType.CachedCrossAttentionPatch).clear_patch( diff_model ) - create_patch_executor(PatchType.UNetExtraInputOptions).clear_patch( - diff_model - ) + create_patch_executor(PatchType.UNetExtraInputOptions).clear_patch(diff_model) elif isinstance(module, DeployableModule): diff_model = module diff_model._clear_old_graph() diff --git a/onediff_comfy_nodes/modules/oneflow/utils/deep_cache_speedup.py b/onediff_comfy_nodes/modules/oneflow/utils/deep_cache_speedup.py index ebd9b4c1d..097b642c4 100644 --- a/onediff_comfy_nodes/modules/oneflow/utils/deep_cache_speedup.py +++ b/onediff_comfy_nodes/modules/oneflow/utils/deep_cache_speedup.py @@ -4,9 +4,10 @@ from onediff.infer_compiler import oneflow_compile from register_comfy import DeepCacheUNet, FastDeepCacheUNet -from .model_patcher import OneFlowDeepCacheSpeedUpModelPatcher from .booster_utils import set_environment_for_svd_img2vid +from .model_patcher import OneFlowDeepCacheSpeedUpModelPatcher + def deep_cache_speedup( model, @@ -49,12 +50,12 @@ def 
deep_cache_speedup( current_step = -1 cache_h = None -        _first_run = True +        _first_run = True  def apply_model(model_function, kwargs): if isinstance(model_patcher.model, SVD_img2vid): set_environment_for_svd_img2vid(model_patcher) -        nonlocal current_t, current_step, cache_h , _first_run +        nonlocal current_t, current_step, cache_h, _first_run  if _first_run: @@ -64,7 +65,6 @@ def apply_model(model_function, kwargs): model_patcher.fast_deep_cache_unet.quantize() _first_run = False -  xa = kwargs["input"] t = kwargs["timestep"] c_concat = kwargs["c"].get("c_concat", None) @@ -125,7 +125,13 @@ def apply_model(model_function, kwargs): if is_slow_step: cache_h = None model_output, cache_h = model_patcher.deep_cache_unet( -            x, timesteps, context, y, control, transformer_options, **extra_conds, +            x, +            timesteps, +            context, +            y, +            control, +            transformer_options, +            **extra_conds, ) else: model_output, cache_h = model_patcher.fast_deep_cache_unet( diff --git a/onediff_comfy_nodes/modules/oneflow/utils/loader_sample_tools.py b/onediff_comfy_nodes/modules/oneflow/utils/loader_sample_tools.py index 34acfe3b0..28e23f0e3 100644 --- a/onediff_comfy_nodes/modules/oneflow/utils/loader_sample_tools.py +++ b/onediff_comfy_nodes/modules/oneflow/utils/loader_sample_tools.py @@ -3,8 +3,9 @@ # ComfyUI from comfy import model_management from folder_paths import get_input_directory + # onediff -from onediff.infer_compiler import OneflowCompileOptions, oneflow_compile +from onediff.infer_compiler import oneflow_compile, OneflowCompileOptions from onediff.infer_compiler.backends.oneflow.transform import torch2oflow from onediff.optimization.quant_optimizer import quantize_model diff --git a/onediff_comfy_nodes/modules/oneflow/utils/model_patcher.py b/onediff_comfy_nodes/modules/oneflow/utils/model_patcher.py index 6441673d6..4feb7660c 100644 --- a/onediff_comfy_nodes/modules/oneflow/utils/model_patcher.py +++ b/onediff_comfy_nodes/modules/oneflow/utils/model_patcher.py @@ -2,7 +2,8 @@ import comfy import torch -from register_comfy import DeepCacheUNet, FastDeepCacheUNet + +from ..infer_compiler_registry.register_comfy import DeepCacheUNet, FastDeepCacheUNet def state_dict_hook(module, state_dict, prefix, local_metadata): @@ -33,9 +34,9 @@ def __init__( graph_device=None, ): from onediff.infer_compiler import ( -        OneflowCompileOptions, -        oneflow_compile, DeployableModule, +        oneflow_compile, +        OneflowCompileOptions, ) self.weight_inplace_update = weight_inplace_update @@ -506,9 +507,9 @@ def __init__( gen_compile_options=None, ): from onediff.infer_compiler import ( -        OneflowCompileOptions, -        oneflow_compile, DeployableModule, +        oneflow_compile, +        OneflowCompileOptions, ) self.weight_inplace_update = weight_inplace_update @@ -525,16 +526,20 @@ def __init__( self.model.diffusion_model, cache_layer_id, cache_block_id ) if use_graph: -            gen_compile_options = gen_compile_options or (lambda x: OneflowCompileOptions()) +            gen_compile_options = gen_compile_options or ( +                lambda x: OneflowCompileOptions() +            ) compile_options = gen_compile_options(self.deep_cache_unet) compile_options.use_graph = use_graph self.deep_cache_unet = oneflow_compile( -            self.deep_cache_unet, options=compile_options, +            self.deep_cache_unet, +            options=compile_options, ) compile_options = gen_compile_options(self.fast_deep_cache_unet) compile_options.use_graph = use_graph self.fast_deep_cache_unet = oneflow_compile( -            self.fast_deep_cache_unet, +
options=compile_options, ) self.model._register_state_dict_hook(state_dict_hook) diff --git a/onediff_comfy_nodes/modules/oneflow/utils/onediff_load_utils.py b/onediff_comfy_nodes/modules/oneflow/utils/onediff_load_utils.py index 2c702e383..248a96edc 100644 --- a/onediff_comfy_nodes/modules/oneflow/utils/onediff_load_utils.py +++ b/onediff_comfy_nodes/modules/oneflow/utils/onediff_load_utils.py @@ -3,7 +3,7 @@ import folder_paths import torch from comfy import model_management -from onediff.infer_compiler import OneflowCompileOptions, oneflow_compile +from onediff.infer_compiler import oneflow_compile, OneflowCompileOptions from ..config import _USE_UNET_INT8, ONEDIFF_QUANTIZED_OPTIMIZED_MODELS from .graph_path import generate_graph_path @@ -29,7 +29,9 @@ def onediff_load_quant_checkpoint_advanced( load_device = model_management.get_torch_device() diffusion_model = modelpatcher.model.diffusion_model.to(load_device) quant_unet = quantize_unet( - diffusion_model=diffusion_model, inplace=True, calibrate_info=calibrate_info, + diffusion_model=diffusion_model, + inplace=True, + calibrate_info=calibrate_info, ) modelpatcher.model.diffusion_model = quant_unet diff --git a/onediff_comfy_nodes/modules/oneflow/utils/onediff_quant_utils.py b/onediff_comfy_nodes/modules/oneflow/utils/onediff_quant_utils.py index a0c6b3ac2..5cbfc1b7a 100644 --- a/onediff_comfy_nodes/modules/oneflow/utils/onediff_quant_utils.py +++ b/onediff_comfy_nodes/modules/oneflow/utils/onediff_quant_utils.py @@ -127,8 +127,7 @@ def _can_use_flash_attn(attn): def _rewrite_attention(attn): - from onediff_quant.models import (DynamicQuantLinearModule, - StaticQuantLinearModule) + from onediff_quant.models import DynamicQuantLinearModule, StaticQuantLinearModule dim_head = attn.to_q.out_features // attn.heads has_bias = attn.to_q.bias is not None @@ -185,7 +184,7 @@ def _rewrite_attention(attn): old_env = os.getenv("ONEFLOW_FUSE_QUANT_TO_MATMUL") os.environ["ONEFLOW_FUSE_QUANT_TO_MATMUL"] = "0" attn.to_qkv = cls(attn.to_qkv, attn.to_q.nbits, calibrate, attn.to_q.name) - attn.scale = dim_head ** -0.5 + attn.scale = dim_head**-0.5 os.environ["ONEFLOW_FUSE_QUANT_TO_MATMUL"] = old_env diff --git a/onediff_comfy_nodes/modules/oneflow/utils/quant_ksampler_tools.py b/onediff_comfy_nodes/modules/oneflow/utils/quant_ksampler_tools.py index a14b15603..5f9469472 100644 --- a/onediff_comfy_nodes/modules/oneflow/utils/quant_ksampler_tools.py +++ b/onediff_comfy_nodes/modules/oneflow/utils/quant_ksampler_tools.py @@ -9,12 +9,18 @@ import torch.nn as nn from nodes import KSampler, VAEDecode from onediff.infer_compiler import oneflow_compile + # onediff -from onediff.torch_utils.module_operations import (get_sub_module, modify_sub_module) +from onediff.torch_utils.module_operations import get_sub_module, modify_sub_module from onediff_quant import Quantizer + # onediff_quant -from onediff_quant.utils import (find_quantizable_modules, get_quantize_module, - metric_quantize_costs, symm_quantize) +from onediff_quant.utils import ( + find_quantizable_modules, + get_quantize_module, + metric_quantize_costs, + symm_quantize, +) # onediff_comfy_nodes from .model_patcher import OneFlowDeepCacheSpeedUpModelPatcher diff --git a/onediff_comfy_nodes/modules/sd_hijack_utils.py b/onediff_comfy_nodes/modules/sd_hijack_utils.py index bf2eb4ebd..cbdf64184 100644 --- a/onediff_comfy_nodes/modules/sd_hijack_utils.py +++ b/onediff_comfy_nodes/modules/sd_hijack_utils.py @@ -1,9 +1,9 @@ """Hijack utils for stable-diffusion.""" import importlib import inspect +from 
collections import deque from types import FunctionType from typing import Callable, List, Union -from collections import deque __all__ = ["Hijacker", "hijack_func"] @@ -58,7 +58,9 @@ def hijacked_method(*args, **kwargs): return self(*args, **kwargs) setattr( - resolved_obj, func_path[-1], hijacked_method, + resolved_obj, + func_path[-1], + hijacked_method, ) def unhijack_func(): diff --git a/onediff_comfy_nodes/utils/function_selector.py b/onediff_comfy_nodes/utils/function_selector.py index 39cec2068..fe67edf8b 100644 --- a/onediff_comfy_nodes/utils/function_selector.py +++ b/onediff_comfy_nodes/utils/function_selector.py @@ -49,8 +49,10 @@ def __call__(self, iterable, default_func=None): for pair in iterable: commit_hash, func = pair other_commit_date = self._get_commit_date(commit_hash) - if cur_date > other_commit_date and (sel_date and other_commit_date > sel_date): + if cur_date > other_commit_date and ( + sel_date and other_commit_date > sel_date + ): sel_date = other_commit_date sel_func = func - return sel_func + return sel_func diff --git a/onediff_diffusers_extensions/README.md b/onediff_diffusers_extensions/README.md index a13a8075b..922c2ab2c 100644 --- a/onediff_diffusers_extensions/README.md +++ b/onediff_diffusers_extensions/README.md @@ -23,7 +23,7 @@ OneDiffX is a OneDiff Extension for HF diffusers. It provides some acceleration ## Install and setup -1. Follow the steps [here](https://github.com/siliconflow/onediff?tab=readme-ov-file#install-from-source) to install onediff. +1. Follow the steps [here](https://github.com/siliconflow/onediff?tab=readme-ov-file#install-from-source) to install onediff. 2. Install onediffx by following these steps @@ -145,13 +145,13 @@ prompt = "A photo of a cat. Focus light and create sharp, defined edges." 
# Warmup for i in range(1): deepcache_output = pipe( - prompt, + prompt, cache_interval=3, cache_layer_id=0, cache_block_id=0, output_type='pil' ).images[0] deepcache_output = pipe( - prompt, + prompt, cache_interval=3, cache_layer_id=0, cache_block_id=0, output_type='pil' ).images[0] @@ -179,13 +179,13 @@ prompt = "a photo of an astronaut on a moon" # Warmup for i in range(1): deepcache_output = pipe( - prompt, + prompt, cache_interval=3, cache_layer_id=0, cache_block_id=0, output_type='pil' ).images[0] deepcache_output = pipe( - prompt, + prompt, cache_interval=3, cache_layer_id=0, cache_block_id=0, output_type='pil' ).images[0] @@ -217,13 +217,13 @@ input_image = input_image.resize((1024, 576)) # Warmup for i in range(1): deepcache_output = pipe( - input_image, + input_image, decode_chunk_size=5, cache_interval=3, cache_branch=0, ).frames[0] deepcache_output = pipe( - input_image, + input_image, decode_chunk_size=5, cache_interval=3, cache_branch=0, ).frames[0] diff --git a/onediff_diffusers_extensions/examples/image_to_image.py b/onediff_diffusers_extensions/examples/image_to_image.py index 909ca7db8..4e88673aa 100644 --- a/onediff_diffusers_extensions/examples/image_to_image.py +++ b/onediff_diffusers_extensions/examples/image_to_image.py @@ -1,11 +1,11 @@ import argparse -from PIL import Image import torch -import oneflow as flow +from PIL import Image +import oneflow as flow # usort: skip -from onediff.infer_compiler import oneflow_compile from diffusers import StableDiffusionImg2ImgPipeline +from onediff.infer_compiler import oneflow_compile prompt = "sea,beach,the waves crashed on the sand,blue sky whit white cloud" @@ -14,7 +14,9 @@ def parse_args(): parser = argparse.ArgumentParser(description="Simple demo of image generation.") parser.add_argument( - "--model_id", type=str, default="stabilityai/stable-diffusion-2-1", + "--model_id", + type=str, + default="stabilityai/stable-diffusion-2-1", ) cmd_args = parser.parse_args() return cmd_args @@ -24,7 +26,10 @@ def parse_args(): pipe = StableDiffusionImg2ImgPipeline.from_pretrained( - args.model_id, use_auth_token=True, revision="fp16", torch_dtype=torch.float16, + args.model_id, + use_auth_token=True, + revision="fp16", + torch_dtype=torch.float16, ) pipe = pipe.to("cuda") @@ -36,7 +41,11 @@ def parse_args(): with flow.autocast("cuda"): images = pipe( - prompt, image=img, guidance_scale=10, num_inference_steps=100, output_type="np", + prompt, + image=img, + guidance_scale=10, + num_inference_steps=100, + output_type="np", ).images for i, image in enumerate(images): pipe.numpy_to_pil(image)[0].save(f"{prompt}-of-{i}.png") diff --git a/onediff_diffusers_extensions/examples/image_to_image_controlnet.py b/onediff_diffusers_extensions/examples/image_to_image_controlnet.py index 3253bf973..8c8693953 100644 --- a/onediff_diffusers_extensions/examples/image_to_image_controlnet.py +++ b/onediff_diffusers_extensions/examples/image_to_image_controlnet.py @@ -1,16 +1,16 @@ import argparse +import cv2 +import numpy as np +import torch + from diffusers import ( - StableDiffusionControlNetImg2ImgPipeline, ControlNetModel, + StableDiffusionControlNetImg2ImgPipeline, UniPCMultistepScheduler, ) from diffusers.utils import load_image -import numpy as np -import torch - -import cv2 from PIL import Image parser = argparse.ArgumentParser() @@ -24,7 +24,9 @@ default="https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png", ) parser.add_argument( - "--prompt", type=str, default="chinese painting style 
women", + "--prompt", + type=str, + default="chinese painting style women", ) parser.add_argument("--height", type=int, default=512) parser.add_argument("--width", type=int, default=512) @@ -116,9 +118,10 @@ ).images print("Run") -from tqdm import tqdm import time +from tqdm import tqdm + for i in tqdm(range(args.run), desc="Pipe processing", unit="i"): start_t = time.time() image = pipe( diff --git a/onediff_diffusers_extensions/examples/image_to_image_graph_load.py b/onediff_diffusers_extensions/examples/image_to_image_graph_load.py index ebdc5de2f..c5287c846 100644 --- a/onediff_diffusers_extensions/examples/image_to_image_graph_load.py +++ b/onediff_diffusers_extensions/examples/image_to_image_graph_load.py @@ -1,23 +1,22 @@ """ image to image graph load ,a old example""" -import time -import os import gc +import os import shutil -import unittest import tempfile -from PIL import Image +import time +import unittest import numpy as np -import oneflow as flow +from PIL import Image +import oneflow as flow # usort: skip import oneflow as torch +from diffusers import EulerDiscreteScheduler, utils + from onediff import ( OneFlowStableDiffusionImg2ImgPipeline as StableDiffusionImg2ImgPipeline, ) -from diffusers import EulerDiscreteScheduler -from diffusers import utils - from onediff.infer_compiler.backends.oneflow.utils.cost_util import cost_cnt diff --git a/onediff_diffusers_extensions/examples/latte/README.md b/onediff_diffusers_extensions/examples/latte/README.md index f1c46d5ef..1353cd844 100644 --- a/onediff_diffusers_extensions/examples/latte/README.md +++ b/onediff_diffusers_extensions/examples/latte/README.md @@ -75,7 +75,7 @@ python3 ./benchmarks/text_to_video_latte.py \ 1 OneDiff Warmup with Compilation time is tested on Intel(R) Xeon(R) Gold 6348 CPU @ 2.60GHz. Note this is just for reference, and it varies a lot on different CPU. 
#### nexfort compile config and warmup cost -- compiler-config +- compiler-config - setting `--compiler-config '{"mode": "max-optimize:max-autotune:freezing:benchmark:low-precision", "memory_format": "channels_last", "options": {"inductor.optimize_linear_epilogue": false, "triton.fuse_attention_allow_fp16_reduction": false}}'` will help achieve the best performance, but the compilation time is about 572 seconds - setting `--compiler-config '{"mode": "max-autotune", "memory_format": "channels_last", "options": {"inductor.optimize_linear_epilogue": false, "triton.fuse_attention_allow_fp16_reduction": false}}'` will reduce compilation time to about 236 seconds and just slightly reduce the performance. - fuse_qkv_projections: True diff --git a/onediff_diffusers_extensions/examples/notebooks/text_to_image.ipynb b/onediff_diffusers_extensions/examples/notebooks/text_to_image.ipynb index abd1d9780..63856eebc 100644 --- a/onediff_diffusers_extensions/examples/notebooks/text_to_image.ipynb +++ b/onediff_diffusers_extensions/examples/notebooks/text_to_image.ipynb @@ -10,7 +10,7 @@ "from onediff.infer_compiler import oneflow_compile\n", "from onediff.schedulers import EulerDiscreteScheduler\n", "from diffusers import StableDiffusionPipeline\n", - "import oneflow as flow\n", + "import oneflow as flow # usort: skip\n", "import torch" ] }, diff --git a/onediff_diffusers_extensions/examples/pipe_compile_save_load.py b/onediff_diffusers_extensions/examples/pipe_compile_save_load.py index 28ea0b223..dfc9e0463 100644 --- a/onediff_diffusers_extensions/examples/pipe_compile_save_load.py +++ b/onediff_diffusers_extensions/examples/pipe_compile_save_load.py @@ -4,7 +4,7 @@ import torch from diffusers import StableDiffusionXLPipeline -from onediffx import compile_pipe, save_pipe, load_pipe +from onediffx import compile_pipe, load_pipe, save_pipe parser = argparse.ArgumentParser() parser.add_argument( @@ -18,7 +18,7 @@ "/share_nfs/hf_models/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16, variant="fp16", - use_safetensors=True + use_safetensors=True, ) pipe.to("cuda") diff --git a/onediff_diffusers_extensions/examples/pixart/README.md b/onediff_diffusers_extensions/examples/pixart/README.md index 25e61f1d0..2756414f0 100644 --- a/onediff_diffusers_extensions/examples/pixart/README.md +++ b/onediff_diffusers_extensions/examples/pixart/README.md @@ -25,7 +25,7 @@ https://github.com/siliconflow/onediff/tree/main/src/onediff/infer_compiler/back ### Set up PixArt -HF model: +HF model: - PixArt-alpha: https://huggingface.co/PixArt-alpha/PixArt-XL-2-1024-MS - PixArt-sigma: https://huggingface.co/PixArt-alpha/PixArt-Sigma-XL-2-1024-MS @@ -133,7 +133,7 @@ python3 ./benchmarks/text_to_image.py \ #### The nexfort backend compile config and warmup cost -- compiler-config +- compiler-config - default is `{"mode": "max-optimize:max-autotune:low-precision", "memory_format": "channels_last"}` in `/benchmarks/text_to_image.py`. This mode supports dynamic shapes. - setting `--compiler-config '{"mode": "max-autotune", "memory_format": "channels_last"}'` will reduce compilation time and just slightly reduce the performance. - setting `--compiler-config '{"mode": "max-optimize:max-autotune:freezing:benchmark:low-precision:cudagraphs", "memory_format": "channels_last"}'` will help achieve the best performance, but it increases the compilation time and affects stability.
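Both READMEs above tune the same knob: `--compiler-config` takes a JSON string whose `mode` trades compilation time against run-time performance. A minimal sketch of how such a string is typically consumed — assuming the `compile_pipe` nexfort backend used by the repo's benchmark scripts, with an illustrative model id and prompt:
```
import json

import torch
from diffusers import DiffusionPipeline
from onediffx import compile_pipe

# Editorial sketch, not part of this diff: model id, prompt, and step count are
# illustrative; the compile_pipe call mirrors the benchmark scripts named above.
compiler_config = '{"mode": "max-autotune", "memory_format": "channels_last"}'

pipe = DiffusionPipeline.from_pretrained(
    "PixArt-alpha/PixArt-XL-2-1024-MS", torch_dtype=torch.float16
).to("cuda")

# The JSON string becomes the options dict for the nexfort backend; the first
# generation after compile_pipe pays the warmup cost the READMEs discuss.
pipe = compile_pipe(pipe, backend="nexfort", options=json.loads(compiler_config))

image = pipe("a small cactus with a happy face", num_inference_steps=20).images[0]
image.save("pixart.png")
```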
diff --git a/onediff_diffusers_extensions/examples/reuse_compiled_pipeline_components.py b/onediff_diffusers_extensions/examples/reuse_compiled_pipeline_components.py index 8acd4b000..c4d09deb6 100644 --- a/onediff_diffusers_extensions/examples/reuse_compiled_pipeline_components.py +++ b/onediff_diffusers_extensions/examples/reuse_compiled_pipeline_components.py @@ -2,17 +2,18 @@ This example shows how to reuse the compiled components of a pipeline to create new pipelines. Usage: - $ python reuse_compiled_pipeline_components.py --model_id + $ python reuse_compiled_pipeline_components.py --model_id """ -import PIL import argparse +from io import BytesIO + +import PIL import requests import torch -from io import BytesIO from diffusers import ( - StableDiffusionPipeline, StableDiffusionImg2ImgPipeline, StableDiffusionInpaintPipeline, + StableDiffusionPipeline, ) from onediffx import compile_pipe @@ -20,7 +21,9 @@ def get_args(): parser = argparse.ArgumentParser() parser.add_argument( - "--model_id", type=str, default="runwayml/stable-diffusion-v1-5", + "--model_id", + type=str, + default="runwayml/stable-diffusion-v1-5", ) return parser.parse_args() diff --git a/onediff_diffusers_extensions/examples/sd3/text_to_image_sd3.py b/onediff_diffusers_extensions/examples/sd3/text_to_image_sd3.py index ed5867835..c571bd212 100644 --- a/onediff_diffusers_extensions/examples/sd3/text_to_image_sd3.py +++ b/onediff_diffusers_extensions/examples/sd3/text_to_image_sd3.py @@ -42,7 +42,10 @@ def parse_args(): "--width", type=int, default=1024, help="Width of the generated image." ) parser.add_argument( - "--guidance_scale", type=float, default=4.5, help="The scale factor for the guidance." + "--guidance_scale", + type=float, + default=4.5, + help="The scale factor for the guidance.", ) parser.add_argument( "--num-inference-steps", type=int, default=28, help="Number of inference steps." 
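The `reuse_compiled_pipeline_components.py` example whose diff appears above is built on diffusers' shared-components pattern. A minimal sketch of that pattern, assuming `onediffx.compile_pipe` and an illustrative SD 1.5 model id:
```
import torch
from diffusers import StableDiffusionImg2ImgPipeline, StableDiffusionPipeline
from onediffx import compile_pipe

# Editorial sketch, not part of this diff: compile a text-to-image pipeline
# once, then build an img2img pipeline from the same components so the
# already-compiled UNet is reused instead of being compiled a second time.
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")
pipe = compile_pipe(pipe)

# pipe.components exposes the shared modules (unet, vae, text_encoder, ...);
# a new pipeline constructed from them shares, rather than copies, the modules.
img2img = StableDiffusionImg2ImgPipeline(**pipe.components)
```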
diff --git a/onediff_diffusers_extensions/examples/text_to_image.py b/onediff_diffusers_extensions/examples/text_to_image.py index ffa2ba642..a26ff983f 100644 --- a/onediff_diffusers_extensions/examples/text_to_image.py +++ b/onediff_diffusers_extensions/examples/text_to_image.py @@ -2,12 +2,13 @@ example: python examples/text_to_image.py --height 512 --width 512 --warmup 10 --model_id xx """ import argparse + import torch -import oneflow as flow +import oneflow as flow # usort: skip +from diffusers import StableDiffusionPipeline from onediff.infer_compiler import oneflow_compile from onediff.schedulers import EulerDiscreteScheduler -from diffusers import StableDiffusionPipeline def parse_args(): @@ -16,7 +17,9 @@ def parse_args(): "--prompt", type=str, default="a photo of an astronaut riding a horse on mars" ) parser.add_argument( - "--model_id", type=str, default="runwayml/stable-diffusion-v1-5", + "--model_id", + type=str, + default="runwayml/stable-diffusion-v1-5", ) parser.add_argument("--height", type=int, default=512) parser.add_argument("--width", type=int, default=512) diff --git a/onediff_diffusers_extensions/examples/text_to_image_controlnet.py b/onediff_diffusers_extensions/examples/text_to_image_controlnet.py index 34629e870..d95851e9a 100644 --- a/onediff_diffusers_extensions/examples/text_to_image_controlnet.py +++ b/onediff_diffusers_extensions/examples/text_to_image_controlnet.py @@ -1,15 +1,15 @@ import argparse +import cv2 +import numpy as np +import torch + from diffusers import ( - StableDiffusionControlNetPipeline, ControlNetModel, + StableDiffusionControlNetPipeline, UniPCMultistepScheduler, ) from diffusers.utils import load_image -import numpy as np -import torch - -import cv2 from PIL import Image parser = argparse.ArgumentParser() @@ -21,7 +21,9 @@ default="https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png", ) parser.add_argument( - "--prompt", type=str, default="chinese painting style women", + "--prompt", + type=str, + default="chinese painting style women", ) parser.add_argument("--height", type=int, default=512) parser.add_argument("--width", type=int, default=512) diff --git a/onediff_diffusers_extensions/examples/text_to_image_deep_cache_sd.py b/onediff_diffusers_extensions/examples/text_to_image_deep_cache_sd.py index 283329bb9..616753725 100644 --- a/onediff_diffusers_extensions/examples/text_to_image_deep_cache_sd.py +++ b/onediff_diffusers_extensions/examples/text_to_image_deep_cache_sd.py @@ -2,8 +2,8 @@ Torch run example: python examples/text_to_image_deep_cache_sd.py --compile 0 Compile to oneflow graph example: python examples/text_to_image_deep_cache_sd.py """ -import os import argparse +import os import torch @@ -27,7 +27,9 @@ parser.add_argument("--warmup", type=int, default=1) parser.add_argument("--seed", type=int, default=1) parser.add_argument( - "--compile", type=(lambda x: str(x).lower() in ["true", "1", "yes"]), default=True, + "--compile", + type=(lambda x: str(x).lower() in ["true", "1", "yes"]), + default=True, ) parser.add_argument( "--use_multiple_resolutions", @@ -60,7 +62,10 @@ # Define multiple resolutions for warmup resolutions = ( - [(512, 512), (256, 256),] + [ + (512, 512), + (256, 256), + ] if args.use_multiple_resolutions else [(args.height, args.width)] ) diff --git a/onediff_diffusers_extensions/examples/text_to_image_deep_cache_sd_sdxl_enterprise.py b/onediff_diffusers_extensions/examples/text_to_image_deep_cache_sd_sdxl_enterprise.py index 2081fbaa6..45e7d2ea1 100644 
--- a/onediff_diffusers_extensions/examples/text_to_image_deep_cache_sd_sdxl_enterprise.py +++ b/onediff_diffusers_extensions/examples/text_to_image_deep_cache_sd_sdxl_enterprise.py @@ -97,17 +97,28 @@ def parse_args(): if args.model_type == "sdxl": pipe = StableDiffusionXLPipeline.from_pretrained( - args.model, torch_dtype=torch.float16, use_safetensors=True, variant="fp16", + args.model, + torch_dtype=torch.float16, + use_safetensors=True, + variant="fp16", ) else: pipe = StableDiffusionPipeline.from_pretrained( - args.model, revision="fp16", variant="fp16", torch_dtype=torch.float16, + args.model, + revision="fp16", + variant="fp16", + torch_dtype=torch.float16, ) pipe.to("cuda") for sub_module_name, sub_calibrate_info in calibrate_info.items(): replace_sub_module_with_quantizable_module( - pipe.unet, sub_module_name, sub_calibrate_info, False, False, args.bits, + pipe.unet, + sub_module_name, + sub_calibrate_info, + False, + False, + args.bits, ) compile_options = OneflowCompileOptions() diff --git a/onediff_diffusers_extensions/examples/text_to_image_deep_cache_sdxl.py b/onediff_diffusers_extensions/examples/text_to_image_deep_cache_sdxl.py index 7839be91f..c19dfae8f 100644 --- a/onediff_diffusers_extensions/examples/text_to_image_deep_cache_sdxl.py +++ b/onediff_diffusers_extensions/examples/text_to_image_deep_cache_sdxl.py @@ -2,8 +2,8 @@ Torch run example: python examples/text_to_image_deep_cache_sdxl.py --compile 0 Compile to oneflow graph example: python examples/text_to_image_deep_cache_sdxl.py """ -import os import argparse +import os import torch @@ -27,7 +27,9 @@ parser.add_argument("--warmup", type=int, default=1) parser.add_argument("--seed", type=int, default=1) parser.add_argument( - "--compile", type=(lambda x: str(x).lower() in ["true", "1", "yes"]), default=True, + "--compile", + type=(lambda x: str(x).lower() in ["true", "1", "yes"]), + default=True, ) parser.add_argument( "--run_multiple_resolutions", @@ -41,7 +43,10 @@ # SDXL base: StableDiffusionXLPipeline base = StableDiffusionXLPipeline.from_pretrained( - args.base, torch_dtype=torch.float16, variant=args.variant, use_safetensors=True, + args.base, + torch_dtype=torch.float16, + variant=args.variant, + use_safetensors=True, ) base.to("cuda") diff --git a/onediff_diffusers_extensions/examples/text_to_image_lcm.py b/onediff_diffusers_extensions/examples/text_to_image_lcm.py index 42eeb516e..ee7b1d81d 100644 --- a/onediff_diffusers_extensions/examples/text_to_image_lcm.py +++ b/onediff_diffusers_extensions/examples/text_to_image_lcm.py @@ -1,8 +1,8 @@ import argparse -from packaging import version import importlib.metadata from diffusers import DiffusionPipeline +from packaging import version def check_diffusers_version(): diff --git a/onediff_diffusers_extensions/examples/text_to_image_lcm_lora_sdxl.py b/onediff_diffusers_extensions/examples/text_to_image_lcm_lora_sdxl.py index 30a120f08..b8cdc76e9 100644 --- a/onediff_diffusers_extensions/examples/text_to_image_lcm_lora_sdxl.py +++ b/onediff_diffusers_extensions/examples/text_to_image_lcm_lora_sdxl.py @@ -1,8 +1,8 @@ +import argparse +import importlib.metadata import os -import argparse from packaging import version -import importlib.metadata def parse_args(): @@ -51,7 +51,7 @@ def parse_args(): return args -from diffusers import LCMScheduler, AutoPipelineForText2Image +from diffusers import AutoPipelineForText2Image, LCMScheduler args = parse_args() diff --git a/onediff_diffusers_extensions/examples/text_to_image_online_quant.py 
b/onediff_diffusers_extensions/examples/text_to_image_online_quant.py index 7125be56d..32f5079dd 100644 --- a/onediff_diffusers_extensions/examples/text_to_image_online_quant.py +++ b/onediff_diffusers_extensions/examples/text_to_image_online_quant.py @@ -2,7 +2,7 @@ ## Performance Comparison -Updated on Mon 08 Apr 2024 +Updated on Mon 08 Apr 2024 Timings for 30 steps at 1024x1024 | Accelerator | Baseline (non-optimized) | OneDiff(optimized) | OneDiff Quant(optimized) | @@ -58,18 +58,24 @@ 2. The log *.pt file is cached. Quantization result information can be found in `cache_dir`/quantization_stats.json. """ -import argparse +import argparse import time -import torch + +import torch from diffusers import AutoPipelineForText2Image -from onediffx import compile_pipe from onediff_quant.quantization import QuantizationConfig +from onediffx import compile_pipe + def parse_args(): parser = argparse.ArgumentParser() parser.add_argument("--model_id", default="runwayml/stable-diffusion-v1-5") - parser.add_argument("--prompt", default="a photo of an astronaut riding a horse on mars") - parser.add_argument("--output_file", default="astronaut_rides_horse_onediff_quant.png") + parser.add_argument( + "--prompt", default="a photo of an astronaut riding a horse on mars" + ) + parser.add_argument( + "--output_file", default="astronaut_rides_horse_onediff_quant.png" + ) parser.add_argument("--seed", type=int, default=1) parser.add_argument("--backend", default="onediff", choices=["onediff", "torch"]) parser.add_argument("--quantize", action="store_true") @@ -83,31 +89,45 @@ def parse_args(): parser.add_argument("--linear_compute_density_threshold", type=int, default=300) return parser.parse_args() + def load_model(model_id): - pipe = AutoPipelineForText2Image.from_pretrained(model_id, torch_dtype=torch.float16, variant="fp16") + pipe = AutoPipelineForText2Image.from_pretrained( + model_id, torch_dtype=torch.float16, variant="fp16" + ) pipe.to(f"cuda") return pipe + def compile_and_quantize_model(pipe, cache_dir, quantize, quant_params): pipe = compile_pipe(pipe) if quantize: - config = QuantizationConfig.from_settings(**quant_params, cache_dir=cache_dir, plot_calibrate_info=True) + config = QuantizationConfig.from_settings( + **quant_params, cache_dir=cache_dir, plot_calibrate_info=True + ) pipe.unet.apply_online_quant(quant_config=config) return pipe + def save_image(image, output_file): image.save(output_file) print(f"Image saved to: {output_file}") + def main(): args = parse_args() pipe = load_model(args.model_id) - if args.backend == "onediff": - compile_and_quantize_model(pipe, args.cache_dir, args.quantize, - {"conv_mae_threshold": args.conv_mae_threshold, - "linear_mae_threshold": args.linear_mae_threshold, - "conv_compute_density_threshold": args.conv_compute_density_threshold, - "linear_compute_density_threshold": args.linear_compute_density_threshold}) + if args.backend == "onediff": + compile_and_quantize_model( + pipe, + args.cache_dir, + args.quantize, + { + "conv_mae_threshold": args.conv_mae_threshold, + "linear_mae_threshold": args.linear_mae_threshold, + "conv_compute_density_threshold": args.conv_compute_density_threshold, + "linear_compute_density_threshold": args.linear_compute_density_threshold, + }, + ) torch.manual_seed(args.seed) # Warm-up pipe(prompt=args.prompt, num_inference_steps=1) @@ -116,11 +136,17 @@ def main(): for _ in range(5): start_time = time.time() torch.manual_seed(args.seed) - image = pipe(prompt=args.prompt, height=args.height, width=args.width, 
num_inference_steps=args.num_inference_steps).images[0] + image = pipe( + prompt=args.prompt, + height=args.height, + width=args.width, + num_inference_steps=args.num_inference_steps, + ).images[0] end_time = time.time() print(f"Inference time: {end_time - start_time:.2f} seconds") - + save_image(image, args.output_file) + if __name__ == "__main__": main() diff --git a/onediff_diffusers_extensions/examples/text_to_image_sd_enterprise.py b/onediff_diffusers_extensions/examples/text_to_image_sd_enterprise.py index e42b47071..aa15f0d4a 100644 --- a/onediff_diffusers_extensions/examples/text_to_image_sd_enterprise.py +++ b/onediff_diffusers_extensions/examples/text_to_image_sd_enterprise.py @@ -1,12 +1,12 @@ +import argparse import os import time -import argparse - -from onediff.infer_compiler import oneflow_compile, OneflowCompileOptions import torch import torch.nn as nn +from onediff.infer_compiler import oneflow_compile, OneflowCompileOptions + def parse_args(): parser = argparse.ArgumentParser() @@ -16,7 +16,9 @@ def parse_args(): parser.add_argument("--save_graph", action="store_true") parser.add_argument("--load_graph", action="store_true") parser.add_argument( - "--prompt", type=str, default="a photo of an astronaut riding a horse on mars", + "--prompt", + type=str, + default="a photo of an astronaut riding a horse on mars", ) parser.add_argument("--height", type=int, default=512) parser.add_argument("--width", type=int, default=512) @@ -54,8 +56,8 @@ def parse_args(): os.path.join(args.model, "calibrate_info.txt") ), f"calibrate_info.txt is required in args.model ({args.model})" -from diffusers import StableDiffusionPipeline import onediff_quant +from diffusers import StableDiffusionPipeline from onediff_quant.utils import replace_sub_module_with_quantizable_module onediff_quant.enable_load_quantized_model() @@ -89,7 +91,12 @@ def parse_args(): for sub_module_name, sub_calibrate_info in calibrate_info.items(): replace_sub_module_with_quantizable_module( - pipe.unet, sub_module_name, sub_calibrate_info, False, False, args.bits, + pipe.unet, + sub_module_name, + sub_calibrate_info, + False, + False, + args.bits, ) compile_options = OneflowCompileOptions() diff --git a/onediff_diffusers_extensions/examples/text_to_image_sdxl.py b/onediff_diffusers_extensions/examples/text_to_image_sdxl.py index 2ac606c94..ebbc28cd7 100644 --- a/onediff_diffusers_extensions/examples/text_to_image_sdxl.py +++ b/onediff_diffusers_extensions/examples/text_to_image_sdxl.py @@ -5,18 +5,19 @@ Test dynamic shape: Add --run_multiple_resolutions 1 and --run_rare_resolutions 1 """ -import os +import argparse import json +import os import time -import argparse import torch -import oneflow as flow +import oneflow as flow # usort: skip + +from diffusers import StableDiffusionXLPipeline # from onediff.infer_compiler import oneflow_compile from onediff.schedulers import EulerDiscreteScheduler from onediffx import compile_pipe -from diffusers import StableDiffusionXLPipeline parser = argparse.ArgumentParser() parser.add_argument( diff --git a/onediff_diffusers_extensions/examples/text_to_image_sdxl_enterprise.py b/onediff_diffusers_extensions/examples/text_to_image_sdxl_enterprise.py index 5a164d239..2c53aa3eb 100644 --- a/onediff_diffusers_extensions/examples/text_to_image_sdxl_enterprise.py +++ b/onediff_diffusers_extensions/examples/text_to_image_sdxl_enterprise.py @@ -57,8 +57,8 @@ def parse_args(): os.path.join(args.model, "calibrate_info.txt") ), f"calibrate_info.txt is required in args.model ({args.model})" -from 
diffusers import StableDiffusionXLPipeline import onediff_quant +from diffusers import StableDiffusionXLPipeline from onediff_quant.utils import replace_sub_module_with_quantizable_module onediff_quant.enable_load_quantized_model() @@ -80,14 +80,22 @@ def parse_args(): ] pipe = StableDiffusionXLPipeline.from_pretrained( - args.model, torch_dtype=torch.float16, use_safetensors=True, variant="fp16", + args.model, + torch_dtype=torch.float16, + use_safetensors=True, + variant="fp16", ) pipe.to("cuda") for sub_module_name, sub_calibrate_info in calibrate_info.items(): replace_sub_module_with_quantizable_module( - pipe.unet, sub_module_name, sub_calibrate_info, False, False, args.bits, + pipe.unet, + sub_module_name, + sub_calibrate_info, + False, + False, + args.bits, ) compile_options = OneflowCompileOptions() diff --git a/onediff_diffusers_extensions/examples/text_to_image_sdxl_light.py b/onediff_diffusers_extensions/examples/text_to_image_sdxl_light.py index 5f2ffa313..d88b1f074 100644 --- a/onediff_diffusers_extensions/examples/text_to_image_sdxl_light.py +++ b/onediff_diffusers_extensions/examples/text_to_image_sdxl_light.py @@ -1,12 +1,12 @@ -import os import argparse +import os import time import torch -from safetensors.torch import load_file from diffusers import StableDiffusionXLPipeline -from onediffx import compile_pipe, save_pipe, load_pipe from huggingface_hub import hf_hub_download +from onediffx import compile_pipe, load_pipe, save_pipe +from safetensors.torch import load_file try: USE_PEFT_BACKEND = diffusers.utils.USE_PEFT_BACKEND @@ -37,7 +37,9 @@ ) parser.add_argument("--seed", type=int, default=1) parser.add_argument( - "--compile", type=(lambda x: str(x).lower() in ["true", "1", "yes"]), default=True, + "--compile", + type=(lambda x: str(x).lower() in ["true", "1", "yes"]), + default=True, ) @@ -93,7 +95,9 @@ # Compile the pipeline if args.compile: - pipe = compile_pipe(pipe,) + pipe = compile_pipe( + pipe, + ) if args.load_graph: print("Loading graphs...") load_pipe(pipe, args.load_graph_dir) diff --git a/onediff_diffusers_extensions/examples/text_to_image_sdxl_lora.py b/onediff_diffusers_extensions/examples/text_to_image_sdxl_lora.py index 06d16c81f..84b5bab4c 100644 --- a/onediff_diffusers_extensions/examples/text_to_image_sdxl_lora.py +++ b/onediff_diffusers_extensions/examples/text_to_image_sdxl_lora.py @@ -1,11 +1,16 @@ -import torch from pathlib import Path + +import torch from diffusers import DiffusionPipeline from onediff.infer_compiler import oneflow_compile from onediff.torch_utils import TensorInplaceAssign try: - from onediffx.lora import load_and_fuse_lora, unfuse_lora, update_graph_with_constant_folding_info + from onediffx.lora import ( + load_and_fuse_lora, + unfuse_lora, + update_graph_with_constant_folding_info, + ) except ImportError: raise RuntimeError( "OneDiff onediffx is not installed. Please check onediff_diffusers_extensions/README.md to install onediffx." @@ -19,7 +24,15 @@ LORA_FILENAME = "sd_xl_offset_example-lora_1.0.safetensors" pipe.unet = oneflow_compile(pipe.unet) -latents = torch.randn(1, 4, 128, 128, generator=torch.cuda.manual_seed(0), dtype=torch.float16, device="cuda") +latents = torch.randn( + 1, + 4, + 128, + 128, + generator=torch.cuda.manual_seed(0), + dtype=torch.float16, + device="cuda", +) # There are three methods to load LoRA into OneDiff compiled model # 1. 
pipe.load_lora_weights (Low Performance) diff --git a/onediff_diffusers_extensions/examples/text_to_image_sdxl_mp_load.py b/onediff_diffusers_extensions/examples/text_to_image_sdxl_mp_load.py index 0d4bbc369..a61e71454 100644 --- a/onediff_diffusers_extensions/examples/text_to_image_sdxl_mp_load.py +++ b/onediff_diffusers_extensions/examples/text_to_image_sdxl_mp_load.py @@ -1,11 +1,11 @@ # Compile and save to oneflow graph example: python examples/text_to_image_sdxl_mp_load.py --save # Compile and load to new device example: python examples/text_to_image_sdxl_mp_load.py --load -import os import argparse +import os import torch -import oneflow as flow +import oneflow as flow # usort: skip parser = argparse.ArgumentParser() parser.add_argument( diff --git a/onediff_diffusers_extensions/examples/text_to_image_sdxl_reuse_pipe.py b/onediff_diffusers_extensions/examples/text_to_image_sdxl_reuse_pipe.py index 7a2498b0a..a50bd2d1c 100644 --- a/onediff_diffusers_extensions/examples/text_to_image_sdxl_reuse_pipe.py +++ b/onediff_diffusers_extensions/examples/text_to_image_sdxl_reuse_pipe.py @@ -1,10 +1,10 @@ -import os import argparse +import os import torch +from diffusers import StableDiffusionXLPipeline from onediff.infer_compiler import oneflow_compile -from diffusers import StableDiffusionXLPipeline # import diffusers # diffusers.logging.set_verbosity_info() @@ -14,7 +14,9 @@ "--base", type=str, default="stabilityai/stable-diffusion-xl-base-1.0" ) parser.add_argument( - "--new_base", type=str, default="dataautogpt3/OpenDalleV1.1", + "--new_base", + type=str, + default="dataautogpt3/OpenDalleV1.1", ) parser.add_argument("--variant", type=str, default="fp16") parser.add_argument( @@ -50,7 +52,10 @@ # SDXL base: StableDiffusionXLPipeline base = StableDiffusionXLPipeline.from_pretrained( - args.base, torch_dtype=torch.float16, variant=args.variant, use_safetensors=True, + args.base, + torch_dtype=torch.float16, + variant=args.variant, + use_safetensors=True, ) base.to("cuda") @@ -142,8 +147,8 @@ image[0].save(f"new_base_reuse_graph_h{args.height}-w{args.width}-{args.saved_image}") image_graph = image[0] -from skimage.metrics import structural_similarity import numpy as np +from skimage.metrics import structural_similarity ssim = structural_similarity( np.array(image_eager), np.array(image_graph), channel_axis=-1, data_range=255 diff --git a/onediff_diffusers_extensions/examples/text_to_image_sdxl_save_load.py b/onediff_diffusers_extensions/examples/text_to_image_sdxl_save_load.py index 0da27858f..1513ded67 100644 --- a/onediff_diffusers_extensions/examples/text_to_image_sdxl_save_load.py +++ b/onediff_diffusers_extensions/examples/text_to_image_sdxl_save_load.py @@ -1,14 +1,14 @@ # Compile and save to oneflow graph example: python examples/text_to_image_sdxl_save_load.py --save # Compile and load to oneflow graph example: python examples/text_to_image_sdxl_save_load.py --load -import os import argparse +import os import torch -import oneflow as flow +import oneflow as flow # usort: skip -from onediff.infer_compiler import oneflow_compile, OneflowCompileOptions from diffusers import DiffusionPipeline +from onediff.infer_compiler import oneflow_compile, OneflowCompileOptions parser = argparse.ArgumentParser() parser.add_argument( diff --git a/onediff_diffusers_extensions/examples/text_to_image_sdxl_turbo.py b/onediff_diffusers_extensions/examples/text_to_image_sdxl_turbo.py index 5afe8520d..88d825179 100644 --- a/onediff_diffusers_extensions/examples/text_to_image_sdxl_turbo.py +++ 
b/onediff_diffusers_extensions/examples/text_to_image_sdxl_turbo.py @@ -2,15 +2,15 @@ Torch run example: python examples/text_to_image_sdxl_turbo.py Compile to oneflow graph example: python examples/text_to_image_sdxl_turbo.py --compile """ +import argparse import os import time -import argparse import torch -import oneflow as flow +import oneflow as flow # usort: skip -from onediff.infer_compiler import oneflow_compile from diffusers import AutoPipelineForText2Image +from onediff.infer_compiler import oneflow_compile parser = argparse.ArgumentParser() parser.add_argument("--base", type=str, default="stabilityai/sdxl-turbo") @@ -36,7 +36,10 @@ # SDXL turbo base: AutoPipelineForText2Image base = AutoPipelineForText2Image.from_pretrained( - args.base, torch_dtype=torch.float16, variant=args.variant, use_safetensors=True, + args.base, + torch_dtype=torch.float16, + variant=args.variant, + use_safetensors=True, ) base.to("cuda") diff --git a/onediff_diffusers_extensions/examples/unet_torch_interplay.py b/onediff_diffusers_extensions/examples/unet_torch_interplay.py index 1816845ee..b7fcee927 100644 --- a/onediff_diffusers_extensions/examples/unet_torch_interplay.py +++ b/onediff_diffusers_extensions/examples/unet_torch_interplay.py @@ -3,18 +3,20 @@ save graph compiled example: python3 examples/unet_torch_interplay.py --save --model_id xx load graph compiled example: python3 examples/unet_torch_interplay.py --load """ -import os import importlib.metadata -from packaging import version +import os import random + import click import torch -import oneflow as flow +from packaging import version +import oneflow as flow # usort: skip -from tqdm import tqdm from dataclasses import dataclass, fields + from onediff.infer_compiler import oneflow_compile +from tqdm import tqdm @dataclass diff --git a/onediff_diffusers_extensions/onediffx/__init__.py b/onediff_diffusers_extensions/onediffx/__init__.py index a74eef9b4..a60d7ec56 100644 --- a/onediff_diffusers_extensions/onediffx/__init__.py +++ b/onediff_diffusers_extensions/onediffx/__init__.py @@ -1,11 +1,12 @@ __version__ = "1.2.0.dev1" +from onediff.infer_compiler import OneflowCompileOptions + from .compilers.diffusion_pipeline_compiler import ( compile_pipe, - save_pipe, load_pipe, quantize_pipe, + save_pipe, ) -from onediff.infer_compiler import OneflowCompileOptions __all__ = [ "compile_pipe", diff --git a/onediff_diffusers_extensions/onediffx/compilers/diffusion_pipeline_compiler.py b/onediff_diffusers_extensions/onediffx/compilers/diffusion_pipeline_compiler.py index 192820c69..e7907f37f 100644 --- a/onediff_diffusers_extensions/onediffx/compilers/diffusion_pipeline_compiler.py +++ b/onediff_diffusers_extensions/onediffx/compilers/diffusion_pipeline_compiler.py @@ -1,6 +1,6 @@ -import os -import json import functools +import json +import os import torch from onediff.infer_compiler import compile, DeployableModule @@ -115,9 +115,10 @@ def fuse_qkv_projections_in_pipe(pipe): def convert_pipe_to_memory_format( pipe, *, ignores=(), memory_format=torch.preserve_format ): + import functools + from nexfort.utils.attributes import multi_recursive_apply from nexfort.utils.memory_format import apply_memory_format - import functools if memory_format == torch.preserve_format: return pipe diff --git a/onediff_diffusers_extensions/onediffx/deep_cache/README.md b/onediff_diffusers_extensions/onediffx/deep_cache/README.md index 5ff380730..ef379175a 100644 --- a/onediff_diffusers_extensions/onediffx/deep_cache/README.md +++ 
b/onediff_diffusers_extensions/onediffx/deep_cache/README.md @@ -10,4 +10,4 @@ DeepCache originality is here https://github.com/horseee/DeepCache journal={arXiv preprint arXiv:2312.00858}, year={2023} } -``` \ No newline at end of file +``` diff --git a/onediff_diffusers_extensions/onediffx/deep_cache/__init__.py b/onediff_diffusers_extensions/onediffx/deep_cache/__init__.py index e5c398abf..01406d744 100644 --- a/onediff_diffusers_extensions/onediffx/deep_cache/__init__.py +++ b/onediff_diffusers_extensions/onediffx/deep_cache/__init__.py @@ -1,7 +1,8 @@ -from packaging import version import importlib import importlib.metadata +from packaging import version + diffusers_0193_v = version.parse("0.19.3") diffusers_0240_v = version.parse("0.24.0") diffusers_version = version.parse(importlib.metadata.version("diffusers")) @@ -12,9 +13,9 @@ ) from .models.pipeline_utils import disable_deep_cache_pipeline +from .pipeline_stable_diffusion import StableDiffusionPipeline from .pipeline_stable_diffusion_xl import StableDiffusionXLPipeline -from .pipeline_stable_diffusion import StableDiffusionPipeline if diffusers_version >= diffusers_0240_v: from .pipeline_stable_video_diffusion import StableVideoDiffusionPipeline diff --git a/onediff_diffusers_extensions/onediffx/deep_cache/models/fast_unet_2d_condition.py b/onediff_diffusers_extensions/onediffx/deep_cache/models/fast_unet_2d_condition.py index 9b2eda306..d7adad6ba 100644 --- a/onediff_diffusers_extensions/onediffx/deep_cache/models/fast_unet_2d_condition.py +++ b/onediff_diffusers_extensions/onediffx/deep_cache/models/fast_unet_2d_condition.py @@ -1,10 +1,11 @@ +import importlib.metadata +from typing import Any, Dict, List, Optional, Tuple, Union + import torch import torch.nn as nn -from typing import Union, Optional, Dict, Any, Tuple, List +from oneflow.nn.graph.proxy import ProxyModule from packaging import version -import importlib.metadata -from oneflow.nn.graph.proxy import ProxyModule diffusers_0210_v = version.parse("0.21.0") diffusers_0270_v = version.parse("0.27.0") @@ -12,8 +13,7 @@ from diffusers.utils import BaseOutput, logging -from .unet_2d_condition import UNet2DConditionModel -from .unet_2d_condition import UNet2DConditionOutput +from .unet_2d_condition import UNet2DConditionModel, UNet2DConditionOutput try: USE_PEFT_BACKEND = diffusers.utils.USE_PEFT_BACKEND @@ -81,7 +81,7 @@ def forward( # The overall upsampling factor is equal to 2 ** (# num of upsampling layers). # However, the upsampling interpolation output size can be forced to fit any upsampling size # on the fly if necessary. - default_overall_up_factor = 2 ** self.unet_module.num_upsamplers + default_overall_up_factor = 2**self.unet_module.num_upsamplers # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` # forward_upsample_size = False @@ -238,7 +238,9 @@ def forward( sample = torch.cat([sample, hint], dim=1) else: aug_emb = self.unet_module.get_aug_embed( - emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs + emb=emb, + encoder_hidden_states=encoder_hidden_states, + added_cond_kwargs=added_cond_kwargs, ) if self.unet_module.config.addition_embed_type == "image_hint": aug_emb, hint = aug_emb @@ -299,7 +301,8 @@ def forward( ) else: encoder_hidden_states = self.unet_module.process_encoder_hidden_states( - encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs + encoder_hidden_states=encoder_hidden_states, + added_cond_kwargs=added_cond_kwargs, ) # 2. 
pre-process sample = self.unet_module.conv_in(sample) @@ -378,18 +381,25 @@ def forward( **additional_residuals, ) else: - if diffusers_version < diffusers_0210_v or diffusers_version >= diffusers_0270_v: + if ( + diffusers_version < diffusers_0210_v + or diffusers_version >= diffusers_0270_v + ): sample, res_samples = downsample_block( hidden_states=sample, temb=emb, - exist_block_number=cache_block_id if i == cache_layer_id else None, + exist_block_number=cache_block_id + if i == cache_layer_id + else None, ) else: sample, res_samples = downsample_block( hidden_states=sample, temb=emb, scale=lora_scale, - exist_block_number=cache_block_id if i == cache_layer_id else None, + exist_block_number=cache_block_id + if i == cache_layer_id + else None, ) if is_adapter and len(down_intrablock_additional_residuals) > 0: sample += down_intrablock_additional_residuals.pop(0) @@ -410,7 +420,7 @@ def forward( ) down_block_res_samples = new_down_block_res_samples - + # No Middle # Up # print("down_block_res_samples:", [res_sample.shape for res_sample in down_block_res_samples]) @@ -470,7 +480,10 @@ def forward( else None, ) else: - if diffusers_version < diffusers_0210_v or diffusers_version >= diffusers_0270_v: + if ( + diffusers_version < diffusers_0210_v + or diffusers_version >= diffusers_0270_v + ): sample, _ = upsample_block( hidden_states=sample, temb=emb, diff --git a/onediff_diffusers_extensions/onediffx/deep_cache/models/fast_unet_spatio_temporal_condition.py b/onediff_diffusers_extensions/onediffx/deep_cache/models/fast_unet_spatio_temporal_condition.py index cdae09b8c..dd7a6cf69 100644 --- a/onediff_diffusers_extensions/onediffx/deep_cache/models/fast_unet_spatio_temporal_condition.py +++ b/onediff_diffusers_extensions/onediffx/deep_cache/models/fast_unet_spatio_temporal_condition.py @@ -1,6 +1,7 @@ +from typing import Optional, Tuple, Union + import torch import torch.nn as nn -from typing import Union, Optional, Tuple from diffusers.utils import BaseOutput, logging from oneflow.nn.graph.proxy import ProxyModule diff --git a/onediff_diffusers_extensions/onediffx/deep_cache/models/pipeline_utils.py b/onediff_diffusers_extensions/onediffx/deep_cache/models/pipeline_utils.py index 2ff43592e..c2886fd32 100644 --- a/onediff_diffusers_extensions/onediffx/deep_cache/models/pipeline_utils.py +++ b/onediff_diffusers_extensions/onediffx/deep_cache/models/pipeline_utils.py @@ -1,8 +1,8 @@ +import importlib +import importlib.metadata import os from packaging import version -import importlib -import importlib.metadata diffusers_0220_v = version.parse("0.22.0") diffusers_0240_v = version.parse("0.24.0") @@ -28,9 +28,9 @@ def get_class_obj_and_candidates( library_name = "onediffx.deep_cache.models.unet_2d_condition" if class_name == "UNetSpatioTemporalConditionModel": - assert diffusers_version >= diffusers_0240_v, ( - "SVD not support in diffusers-" + str(diffusers_version) - ) + assert ( + diffusers_version >= diffusers_0240_v + ), "SVD is not supported in diffusers-" + str(diffusers_version) library_name = ( "onediffx.deep_cache.models.unet_spatio_temporal_condition" ) @@ -45,7 +45,6 @@ def get_class_obj_and_candidates( return class_obj, class_candidates - else: def get_class_obj_and_candidates( @@ -79,9 +78,9 @@ def get_class_obj_and_candidates( library_name = "onediffx.deep_cache.models.unet_2d_condition" if class_name == "UNetSpatioTemporalConditionModel": - assert diffusers_version >= diffusers_0240_v, ( - "SVD not support in diffusers-" + str(diffusers_version) - ) + assert ( + diffusers_version >= 
diffusers_0240_v + ), "SVD is not supported in diffusers-" + str(diffusers_version) library_name = ( "onediffx.deep_cache.models.unet_spatio_temporal_condition" ) @@ -104,12 +103,17 @@ def get_class_obj_and_candidates( ORIGIN_3D_GET_UP_BLOCK = None if diffusers_version >= diffusers_0260_v: - from diffusers.models.unets import unet_2d_condition as diffusers_unet_2d_condition - from diffusers.models.unets import unet_spatio_temporal_condition as diffusers_unet_spatio_temporal_condition + from diffusers.models.unets import ( + unet_2d_condition as diffusers_unet_2d_condition, + unet_spatio_temporal_condition as diffusers_unet_spatio_temporal_condition, + ) else: from diffusers.models import unet_2d_condition as diffusers_unet_2d_condition + if diffusers_version >= diffusers_0240_v: - from diffusers.models import unet_spatio_temporal_condition as diffusers_unet_spatio_temporal_condition + from diffusers.models import ( + unet_spatio_temporal_condition as diffusers_unet_spatio_temporal_condition, + ) def enable_deep_cache_pipeline(): @@ -124,7 +128,7 @@ def enable_deep_cache_pipeline(): if diffusers_version >= diffusers_0240_v: assert ORIGIN_3D_GET_DOWN_BLOCK is None assert ORIGIN_3D_GET_UP_BLOCK is None - + if diffusers_version < diffusers_0270_v: ORIGIN_DIFFUDION_GET_CLC_OBJ_CANDIDATES = ( diffusers.pipelines.pipeline_utils.get_class_obj_and_candidates ) @@ -156,9 +160,7 @@ def enable_deep_cache_pipeline(): ORIGIN_3D_GET_DOWN_BLOCK = ( diffusers_unet_spatio_temporal_condition.get_down_block ) - diffusers_unet_spatio_temporal_condition.get_down_block = ( - get_3d_down_block - ) + diffusers_unet_spatio_temporal_condition.get_down_block = get_3d_down_block from .unet_3d_blocks import get_up_block as get_3d_up_block @@ -192,7 +194,9 @@ def disable_deep_cache_pipeline(): diffusers_unet_2d_condition.get_down_block = ORIGIN_2D_GET_DOWN_BLOCK diffusers_unet_2d_condition.get_up_block = ORIGIN_2D_GET_UP_BLOCK if diffusers_version >= diffusers_0240_v: - diffusers_unet_spatio_temporal_condition.get_down_block = ORIGIN_3D_GET_DOWN_BLOCK + diffusers_unet_spatio_temporal_condition.get_down_block = ( + ORIGIN_3D_GET_DOWN_BLOCK + ) diffusers_unet_spatio_temporal_condition.get_up_block = ORIGIN_3D_GET_UP_BLOCK diff --git a/onediff_diffusers_extensions/onediffx/deep_cache/models/unet_2d_blocks.py b/onediff_diffusers_extensions/onediffx/deep_cache/models/unet_2d_blocks.py index 8888040b1..5641f1176 100644 --- a/onediff_diffusers_extensions/onediffx/deep_cache/models/unet_2d_blocks.py +++ b/onediff_diffusers_extensions/onediffx/deep_cache/models/unet_2d_blocks.py @@ -11,14 +11,15 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Any, Dict, Optional, Tuple +import importlib.metadata import types +from typing import Any, Dict, Optional, Tuple + import torch from oneflow.nn.graph.proxy import ProxyModule from packaging import version -import importlib.metadata diffusers_0210_v = version.parse("0.21.0") diffusers_0260_v = version.parse("0.26.0") @@ -37,6 +38,7 @@ if diffusers_version >= diffusers_0210_v: + class CrossAttnDownBlock2D(diffusers_unet_2d_blocks.CrossAttnDownBlock2D): def forward( self, @@ -72,11 +74,16 @@ def custom_forward(*inputs): return custom_forward - ckpt_kwargs: Dict[str, Any] = { - "use_reentrant": False - } if is_torch_version(">=", "1.11.0") else {} + ckpt_kwargs: Dict[str, Any] = ( + {"use_reentrant": False} + if is_torch_version(">=", "1.11.0") + else {} + ) hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), hidden_states, temb, **ckpt_kwargs, + create_custom_forward(resnet), + hidden_states, + temb, + **ckpt_kwargs, ) hidden_states = attn( hidden_states, @@ -116,10 +123,13 @@ def custom_forward(*inputs): return hidden_states, output_states - class DownBlock2D(diffusers_unet_2d_blocks.DownBlock2D): def forward( - self, hidden_states, temb=None, scale: float = 1.0, exist_block_number=None, + self, + hidden_states, + temb=None, + scale: float = 1.0, + exist_block_number=None, ): # print("exist_block_number:", exist_block_number, type(self)) output_states = () @@ -162,7 +172,6 @@ def custom_forward(*inputs): return hidden_states, output_states - class CrossAttnUpBlock2D(diffusers_unet_2d_blocks.CrossAttnUpBlock2D): def forward( self, @@ -211,11 +220,16 @@ def custom_forward(*inputs): return custom_forward - ckpt_kwargs: Dict[str, Any] = { - "use_reentrant": False - } if is_torch_version(">=", "1.11.0") else {} + ckpt_kwargs: Dict[str, Any] = ( + {"use_reentrant": False} + if is_torch_version(">=", "1.11.0") + else {} + ) hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), hidden_states, temb, **ckpt_kwargs, + create_custom_forward(resnet), + hidden_states, + temb, + **ckpt_kwargs, ) hidden_states = attn( hidden_states, @@ -240,7 +254,10 @@ def custom_forward(*inputs): for upsampler in self.upsamplers: if isinstance(self, ProxyModule): hidden_states = upsampler( - hidden_states, upsample_size, output_like=output_like, scale=lora_scale + hidden_states, + upsample_size, + output_like=output_like, + scale=lora_scale, ) else: hidden_states = upsampler( @@ -249,7 +266,6 @@ def custom_forward(*inputs): return hidden_states, prv_f - class UpBlock2D(diffusers_unet_2d_blocks.UpBlock2D): def forward( self, @@ -303,11 +319,21 @@ def custom_forward(*inputs): if self.upsamplers is not None: for upsampler in self.upsamplers: if isinstance(self, ProxyModule): - hidden_states = upsampler(hidden_states, upsample_size, scale=scale, output_like=output_like,) + hidden_states = upsampler( + hidden_states, + upsample_size, + scale=scale, + output_like=output_like, + ) else: - hidden_states = upsampler(hidden_states, upsample_size, scale=scale,) + hidden_states = upsampler( + hidden_states, + upsample_size, + scale=scale, + ) return hidden_states, prv_f + else: class CrossAttnDownBlock2D(diffusers_unet_2d_blocks.CrossAttnDownBlock2D): @@ -326,7 +352,9 @@ def forward( if diffusers_version >= diffusers_0270_v: if cross_attention_kwargs is not None: if cross_attention_kwargs.get("scale", None) is not None: - logger.warning("Passing `scale` to `cross_attention_kwargs` is depcrecated. 
`scale` will be ignored.") + logger.warning( + "Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored." + ) output_states = () @@ -344,11 +372,16 @@ def custom_forward(*inputs): return custom_forward - ckpt_kwargs: Dict[str, Any] = { - "use_reentrant": False - } if is_torch_version(">=", "1.11.0") else {} + ckpt_kwargs: Dict[str, Any] = ( + {"use_reentrant": False} + if is_torch_version(">=", "1.11.0") + else {} + ) hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), hidden_states, temb, **ckpt_kwargs, + create_custom_forward(resnet), + hidden_states, + temb, + **ckpt_kwargs, ) hidden_states = attn( hidden_states, @@ -388,10 +421,12 @@ def custom_forward(*inputs): return hidden_states, output_states - class DownBlock2D(diffusers_unet_2d_blocks.DownBlock2D): def forward( - self, hidden_states, temb=None, exist_block_number=None, + self, + hidden_states, + temb=None, + exist_block_number=None, ): # print("exist_block_number:", exist_block_number, type(self)) output_states = () @@ -434,7 +469,6 @@ def custom_forward(*inputs): return hidden_states, output_states - class CrossAttnUpBlock2D(diffusers_unet_2d_blocks.CrossAttnUpBlock2D): def forward( self, @@ -453,7 +487,9 @@ def forward( if diffusers_version >= diffusers_0270_v: if cross_attention_kwargs is not None: if cross_attention_kwargs.get("scale", None) is not None: - logger.warning("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.") + logger.warning( + "Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored." + ) prv_f = [] for i, (resnet, attn) in enumerate(zip(self.resnets, self.attentions)): @@ -482,11 +518,16 @@ def custom_forward(*inputs): return custom_forward - ckpt_kwargs: Dict[str, Any] = { - "use_reentrant": False - } if is_torch_version(">=", "1.11.0") else {} + ckpt_kwargs: Dict[str, Any] = ( + {"use_reentrant": False} + if is_torch_version(">=", "1.11.0") + else {} + ) hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), hidden_states, temb, **ckpt_kwargs, + create_custom_forward(resnet), + hidden_states, + temb, + **ckpt_kwargs, ) hidden_states = attn( hidden_states, @@ -514,13 +555,10 @@ def custom_forward(*inputs): hidden_states, upsample_size, output_like ) else: - hidden_states = upsampler( - hidden_states, upsample_size - ) + hidden_states = upsampler(hidden_states, upsample_size) return hidden_states, prv_f - class UpBlock2D(diffusers_unet_2d_blocks.UpBlock2D): def forward( self, @@ -573,7 +611,9 @@ def custom_forward(*inputs): if self.upsamplers is not None: for upsampler in self.upsamplers: if isinstance(self, ProxyModule): - hidden_states = upsampler(hidden_states, upsample_size, output_like) + hidden_states = upsampler( + hidden_states, upsample_size, output_like + ) else: hidden_states = upsampler(hidden_states, upsample_size) @@ -581,10 +621,10 @@ def custom_forward(*inputs): update_cls = { - "CrossAttnDownBlock2D": CrossAttnDownBlock2D, - "DownBlock2D": DownBlock2D, - "CrossAttnUpBlock2D": CrossAttnUpBlock2D, - "UpBlock2D": UpBlock2D, + "CrossAttnDownBlock2D": CrossAttnDownBlock2D, + "DownBlock2D": DownBlock2D, + "CrossAttnUpBlock2D": CrossAttnUpBlock2D, + "UpBlock2D": UpBlock2D, } if diffusers_version >= diffusers_0260_v: @@ -594,11 +634,15 @@ def custom_forward(*inputs): src_get_down_block = diffusers.models.unet_2d_blocks.get_down_block src_get_up_block = diffusers.models.unet_2d_blocks.get_up_block -down_globals = {k : v for k, v in 
src_get_down_block.__globals__.items()} +down_globals = {k: v for k, v in src_get_down_block.__globals__.items()} down_globals.update(update_cls) -get_down_block = types.FunctionType(src_get_down_block.__code__, down_globals, argdefs=src_get_down_block.__defaults__) +get_down_block = types.FunctionType( + src_get_down_block.__code__, down_globals, argdefs=src_get_down_block.__defaults__ +) -up_globals = {k : v for k, v in src_get_up_block.__globals__.items()} +up_globals = {k: v for k, v in src_get_up_block.__globals__.items()} up_globals.update(update_cls) -get_up_block = types.FunctionType(src_get_up_block.__code__, up_globals, argdefs=src_get_up_block.__defaults__) +get_up_block = types.FunctionType( + src_get_up_block.__code__, up_globals, argdefs=src_get_up_block.__defaults__ +) diff --git a/onediff_diffusers_extensions/onediffx/deep_cache/models/unet_2d_condition.py b/onediff_diffusers_extensions/onediffx/deep_cache/models/unet_2d_condition.py index 0204c9632..82ad04c4d 100644 --- a/onediff_diffusers_extensions/onediffx/deep_cache/models/unet_2d_condition.py +++ b/onediff_diffusers_extensions/onediffx/deep_cache/models/unet_2d_condition.py @@ -11,28 +11,33 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import importlib.metadata from dataclasses import dataclass from typing import Any, Dict, List, Optional, Tuple, Union -from packaging import version -import importlib.metadata from oneflow.nn.graph.proxy import ProxyModule +from packaging import version + diffusers_0210_v = version.parse("0.21.0") diffusers_0260_v = version.parse("0.26.0") diffusers_0270_v = version.parse("0.27.0") diffusers_version = version.parse(importlib.metadata.version("diffusers")) -import torch import diffusers +import torch +from diffusers.models.modeling_utils import ModelMixin from diffusers.utils import BaseOutput, logging -from diffusers.models.modeling_utils import ModelMixin if diffusers_version >= diffusers_0260_v: - from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel as DiffusersUNet2DConditionModel + from diffusers.models.unets.unet_2d_condition import ( + UNet2DConditionModel as DiffusersUNet2DConditionModel, + ) else: - from diffusers.models.unet_2d_condition import UNet2DConditionModel as DiffusersUNet2DConditionModel + from diffusers.models.unet_2d_condition import ( + UNet2DConditionModel as DiffusersUNet2DConditionModel, + ) try: USE_PEFT_BACKEND = diffusers.utils.USE_PEFT_BACKEND @@ -108,7 +113,7 @@ def forward( # The overall upsampling factor is equal to 2 ** (# num of upsampling layers). # However, the upsampling interpolation output size can be forced to fit any upsampling size # on the fly if necessary. 
- default_overall_up_factor = 2 ** self.num_upsamplers + default_overall_up_factor = 2**self.num_upsamplers # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` # forward_upsample_size = False @@ -261,15 +266,17 @@ def forward( hint = added_cond_kwargs.get("hint") aug_emb, hint = self.add_embedding(image_embs, hint) sample = torch.cat([sample, hint], dim=1) - + else: aug_emb = self.get_aug_embed( - emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs + emb=emb, + encoder_hidden_states=encoder_hidden_states, + added_cond_kwargs=added_cond_kwargs, ) if self.config.addition_embed_type == "image_hint": aug_emb, hint = aug_emb sample = torch.cat([sample, hint], dim=1) - + emb = emb + aug_emb if aug_emb is not None else emb if self.time_embed_act is not None: @@ -323,7 +330,8 @@ def forward( ) else: encoder_hidden_states = self.process_encoder_hidden_states( - encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs + encoder_hidden_states=encoder_hidden_states, + added_cond_kwargs=added_cond_kwargs, ) # 2. pre-process @@ -400,7 +408,10 @@ def forward( **additional_residuals, ) else: - if diffusers_version < diffusers_0210_v or diffusers_version >= diffusers_0270_v: + if ( + diffusers_version < diffusers_0210_v + or diffusers_version >= diffusers_0270_v + ): sample, res_samples = downsample_block( hidden_states=sample, temb=emb ) @@ -505,7 +516,10 @@ def forward( encoder_attention_mask=encoder_attention_mask, ) else: - if diffusers_version < diffusers_0210_v or diffusers_version >= diffusers_0270_v: + if ( + diffusers_version < diffusers_0210_v + or diffusers_version >= diffusers_0270_v + ): sample, current_record_f = upsample_block( hidden_states=sample, temb=emb, diff --git a/onediff_diffusers_extensions/onediffx/deep_cache/models/unet_3d_blocks.py b/onediff_diffusers_extensions/onediffx/deep_cache/models/unet_3d_blocks.py index 915483f4d..1e710f330 100644 --- a/onediff_diffusers_extensions/onediffx/deep_cache/models/unet_3d_blocks.py +++ b/onediff_diffusers_extensions/onediffx/deep_cache/models/unet_3d_blocks.py @@ -1,14 +1,14 @@ -from typing import Any, Dict, Optional, Tuple, Union +import importlib.metadata import types +from typing import Any, Dict, Optional, Tuple, Union from packaging import version -import importlib.metadata diffusers_0260_v = version.parse("0.26.0") diffusers_version = version.parse(importlib.metadata.version("diffusers")) -import torch import diffusers +import torch from diffusers.utils import is_torch_version @@ -54,7 +54,9 @@ def custom_forward(*inputs): ) else: hidden_states = resnet( - hidden_states, temb, image_only_indicator=image_only_indicator, + hidden_states, + temb, + image_only_indicator=image_only_indicator, ) output_states = output_states + (hidden_states,) @@ -74,7 +76,9 @@ def custom_forward(*inputs): return hidden_states, output_states -class CrossAttnDownBlockSpatioTemporal(diffusers_unet_3d_blocks.CrossAttnDownBlockSpatioTemporal): +class CrossAttnDownBlockSpatioTemporal( + diffusers_unet_3d_blocks.CrossAttnDownBlockSpatioTemporal +): def forward( self, hidden_states: torch.FloatTensor, @@ -99,9 +103,9 @@ def custom_forward(*inputs): return custom_forward - ckpt_kwargs: Dict[str, Any] = { - "use_reentrant": False - } if is_torch_version(">=", "1.11.0") else {} + ckpt_kwargs: Dict[str, Any] = ( + {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + ) hidden_states = torch.utils.checkpoint.checkpoint( 
create_custom_forward(resnet), hidden_states, @@ -118,7 +122,9 @@ def custom_forward(*inputs): )[0] else: hidden_states = resnet( - hidden_states, temb, image_only_indicator=image_only_indicator, + hidden_states, + temb, + image_only_indicator=image_only_indicator, ) hidden_states = attn( hidden_states, @@ -189,7 +195,9 @@ def custom_forward(*inputs): ) else: hidden_states = resnet( - hidden_states, temb, image_only_indicator=image_only_indicator, + hidden_states, + temb, + image_only_indicator=image_only_indicator, ) if self.upsamplers is not None: @@ -199,7 +207,9 @@ def custom_forward(*inputs): return hidden_states, prv_f -class CrossAttnUpBlockSpatioTemporal(diffusers_unet_3d_blocks.CrossAttnUpBlockSpatioTemporal): +class CrossAttnUpBlockSpatioTemporal( + diffusers_unet_3d_blocks.CrossAttnUpBlockSpatioTemporal +): def forward( self, hidden_states: torch.FloatTensor, @@ -232,9 +242,9 @@ def custom_forward(*inputs): return custom_forward - ckpt_kwargs: Dict[str, Any] = { - "use_reentrant": False - } if is_torch_version(">=", "1.11.0") else {} + ckpt_kwargs: Dict[str, Any] = ( + {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + ) hidden_states = torch.utils.checkpoint.checkpoint( create_custom_forward(resnet), hidden_states, @@ -250,7 +260,9 @@ def custom_forward(*inputs): )[0] else: hidden_states = resnet( - hidden_states, temb, image_only_indicator=image_only_indicator, + hidden_states, + temb, + image_only_indicator=image_only_indicator, ) hidden_states = attn( hidden_states, diff --git a/onediff_diffusers_extensions/onediffx/deep_cache/models/unet_spatio_temporal_condition.py b/onediff_diffusers_extensions/onediffx/deep_cache/models/unet_spatio_temporal_condition.py index a7cc9c186..e84a6a435 100644 --- a/onediff_diffusers_extensions/onediffx/deep_cache/models/unet_spatio_temporal_condition.py +++ b/onediff_diffusers_extensions/onediffx/deep_cache/models/unet_spatio_temporal_condition.py @@ -1,10 +1,11 @@ +import importlib.metadata from dataclasses import dataclass from typing import Optional, Tuple, Union -from packaging import version -import importlib.metadata from oneflow.nn.graph.proxy import ProxyModule +from packaging import version + diffusers_0260_v = version.parse("0.26.0") diffusers_version = version.parse(importlib.metadata.version("diffusers")) @@ -18,9 +19,9 @@ ) import torch +from diffusers.models.modeling_utils import ModelMixin from diffusers.utils import BaseOutput, logging -from diffusers.models.modeling_utils import ModelMixin logger = logging.get_logger(__name__) # pylint: disable=invalid-name diff --git a/onediff_diffusers_extensions/onediffx/deep_cache/pipeline_stable_diffusion.py b/onediff_diffusers_extensions/onediffx/deep_cache/pipeline_stable_diffusion.py index 9e8ab30e1..74e7b017c 100644 --- a/onediff_diffusers_extensions/onediffx/deep_cache/pipeline_stable_diffusion.py +++ b/onediff_diffusers_extensions/onediffx/deep_cache/pipeline_stable_diffusion.py @@ -11,10 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+import importlib.metadata from typing import Any, Callable, Dict, List, Optional, Union from packaging import version -import importlib.metadata diffusers_0202_v = version.parse("0.20.2") diffusers_0214_v = version.parse("0.21.4") @@ -26,38 +26,38 @@ diffusers_0270_v = version.parse("0.27.0") diffusers_version = version.parse(importlib.metadata.version("diffusers")) -import torch import numpy as np +import torch from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer if diffusers_version >= diffusers_0240_v: - from transformers import CLIPVisionModelWithProjection from diffusers.image_processor import PipelineImageInput - from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import retrieve_timesteps + from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import ( + retrieve_timesteps, + ) + from transformers import CLIPVisionModelWithProjection +from diffusers import StableDiffusionPipeline as DiffusersStableDiffusionPipeline from diffusers.configuration_utils import FrozenDict from diffusers.image_processor import VaeImageProcessor from diffusers.models import AutoencoderKL -from diffusers.schedulers import KarrasDiffusionSchedulers -from diffusers.utils import ( - deprecate, - logging, -) from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput +from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import ( + rescale_noise_cfg, +) from diffusers.pipelines.stable_diffusion.safety_checker import ( StableDiffusionSafetyChecker, ) +from diffusers.schedulers import KarrasDiffusionSchedulers +from diffusers.utils import deprecate, logging -from diffusers import StableDiffusionPipeline as DiffusersStableDiffusionPipeline -from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import rescale_noise_cfg - -from .models.unet_2d_condition import UNet2DConditionModel from .models.fast_unet_2d_condition import FastUNet2DConditionModel - from .models.pipeline_utils import enable_deep_cache_pipeline +from .models.unet_2d_condition import UNet2DConditionModel + enable_deep_cache_pipeline() @@ -70,7 +70,7 @@ def sample_from_quad(total_numbers, n_samples, pow=1.2): x_values = np.linspace(0, total_numbers ** (1 / pow), n_samples + 1) # Raise these values to the power of 1.5 to get a non-linear distribution - indices = np.unique(np.int32(x_values ** pow))[:-1] + indices = np.unique(np.int32(x_values**pow))[:-1] if len(indices) == n_samples: break pow -= 0.02 @@ -87,7 +87,7 @@ def sample_from_quad_center(total_numbers, n_samples, center, pow=1.2): x_values = np.linspace( (-center) ** (1 / pow), (total_numbers - center) ** (1 / pow), n_samples + 1 ) - indices = [0] + [x + center for x in np.unique(np.int32(x_values ** pow))[1:-1]] + indices = [0] + [x + center for x in np.unique(np.int32(x_values**pow))[1:-1]] if len(indices) == n_samples: break pow -= 0.02 @@ -97,7 +97,9 @@ def sample_from_quad_center(total_numbers, n_samples, center, pow=1.2): ) return indices, pow + if diffusers_version <= diffusers_0214_v: + class StableDiffusionPipeline(DiffusersStableDiffusionPipeline): _optional_components = ["safety_checker", "feature_extractor"] if diffusers_version > diffusers_0202_v: @@ -115,7 +117,10 @@ def __init__( feature_extractor: CLIPImageProcessor, requires_safety_checker: bool = True, ): - if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: + if ( + hasattr(scheduler.config, "steps_offset") + and scheduler.config.steps_offset != 1 + ): deprecation_message = ( f"The configuration 
file of this scheduler: {scheduler} is outdated. `steps_offset`" f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " @@ -124,12 +129,17 @@ def __init__( " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" " file" ) - deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) + deprecate( + "steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False + ) new_config = dict(scheduler.config) new_config["steps_offset"] = 1 scheduler._internal_dict = FrozenDict(new_config) - if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: + if ( + hasattr(scheduler.config, "clip_sample") + and scheduler.config.clip_sample is True + ): deprecation_message = ( f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." " `clip_sample` should be set to False in the configuration file. Please make sure to update the" @@ -137,7 +147,12 @@ def __init__( " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" ) - deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) + deprecate( + "clip_sample not set", + "1.0.0", + deprecation_message, + standard_warn=False, + ) new_config = dict(scheduler.config) new_config["clip_sample"] = False scheduler._internal_dict = FrozenDict(new_config) @@ -158,10 +173,16 @@ def __init__( " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." ) - is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( + is_unet_version_less_0_9_0 = hasattr( + unet.config, "_diffusers_version" + ) and version.parse( version.parse(unet.config._diffusers_version).base_version - ) < version.parse("0.9.0.dev0") - is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + ) < version.parse( + "0.9.0.dev0" + ) + is_unet_sample_size_less_64 = ( + hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + ) if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: deprecation_message = ( "The configuration file of the unet has set the default `sample_size` to smaller than" @@ -174,7 +195,9 @@ def __init__( " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" " the `unet/config.json` file" ) - deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) + deprecate( + "sample_size<64", "1.0.0", deprecation_message, standard_warn=False + ) new_config = dict(unet.config) new_config["sample_size"] = 64 unet._internal_dict = FrozenDict(new_config) @@ -189,10 +212,12 @@ def __init__( feature_extractor=feature_extractor, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) - self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.image_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor + ) self.register_to_config(requires_safety_checker=requires_safety_checker) self.fast_unet = FastUNet2DConditionModel(self.unet) - + @torch.no_grad() def __call__( self, @@ -227,7 +252,13 @@ def __call__( # 1. Check inputs. 
Raise error if not correct self.check_inputs( - prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds + prompt, + height, + width, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, ) # 2. Define call parameters @@ -246,7 +277,9 @@ def __call__( # 3. Encode input prompt text_encoder_lora_scale = ( - cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + cross_attention_kwargs.get("scale", None) + if cross_attention_kwargs is not None + else None ) if diffusers_version > diffusers_0202_v: prompt_embeds, negative_prompt_embeds = self.encode_prompt( @@ -297,7 +330,9 @@ def __call__( extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) # 7. Denoising loop - num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + num_warmup_steps = ( + len(timesteps) - num_inference_steps * self.scheduler.order + ) prv_features = None latents_list = [latents] @@ -318,12 +353,18 @@ def __call__( # interval_seq, pow = sample_from_quad(num_inference_steps, num_inference_steps//cache_interval, pow=pow)#[0, 3, 6, 9, 12, 16, 22, 28, 35, 43,] interval_seq = sorted(interval_seq) - + with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): # expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents - latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + latent_model_input = ( + torch.cat([latents] * 2) + if do_classifier_free_guidance + else latents + ) + latent_model_input = self.scheduler.scale_model_input( + latent_model_input, t + ) if i in interval_seq or cache_interval == 1: prv_features = None @@ -351,27 +392,42 @@ def __call__( cache_block_id=cache_block_id, return_dict=False, ) - + # perform guidance if do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + noise_pred = noise_pred_uncond + guidance_scale * ( + noise_pred_text - noise_pred_uncond + ) if do_classifier_free_guidance and guidance_rescale > 0.0: # Based on 3.4. 
in https://arxiv.org/pdf/2305.08891.pdf - noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) + noise_pred = rescale_noise_cfg( + noise_pred, + noise_pred_text, + guidance_rescale=guidance_rescale, + ) # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + latents = self.scheduler.step( + noise_pred, t, latents, **extra_step_kwargs, return_dict=False + )[0] # call the callback, if provided - if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + if i == len(timesteps) - 1 or ( + (i + 1) > num_warmup_steps + and (i + 1) % self.scheduler.order == 0 + ): progress_bar.update() if callback is not None and i % callback_steps == 0: callback(i, t, latents) if not output_type == "latent": - image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] - image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + image = self.vae.decode( + latents / self.vae.config.scaling_factor, return_dict=False + )[0] + image, has_nsfw_concept = self.run_safety_checker( + image, device, prompt_embeds.dtype + ) else: image = latents has_nsfw_concept = None @@ -381,22 +437,30 @@ def __call__( else: do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] - image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + image = self.image_processor.postprocess( + image, output_type=output_type, do_denormalize=do_denormalize + ) if diffusers_version > diffusers_0202_v: # Offload all models self.maybe_free_model_hooks() else: # Offload last model to CPU - if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + if ( + hasattr(self, "final_offload_hook") + and self.final_offload_hook is not None + ): self.final_offload_hook.offload() if not return_dict: return (image, has_nsfw_concept) - return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) + return StableDiffusionPipelineOutput( + images=image, nsfw_content_detected=has_nsfw_concept + ) elif diffusers_version <= diffusers_0231_v: + class StableDiffusionPipeline(DiffusersStableDiffusionPipeline): _optional_components = ["safety_checker", "feature_extractor"] model_cpu_offload_seq = "text_encoder->unet->vae" @@ -414,7 +478,10 @@ def __init__( feature_extractor: CLIPImageProcessor, requires_safety_checker: bool = True, ): - if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: + if ( + hasattr(scheduler.config, "steps_offset") + and scheduler.config.steps_offset != 1 + ): deprecation_message = ( f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" f" should be set to 1 instead of {scheduler.config.steps_offset}. 
Please make sure " @@ -423,12 +490,17 @@ def __init__( " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" " file" ) - deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) + deprecate( + "steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False + ) new_config = dict(scheduler.config) new_config["steps_offset"] = 1 scheduler._internal_dict = FrozenDict(new_config) - if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: + if ( + hasattr(scheduler.config, "clip_sample") + and scheduler.config.clip_sample is True + ): deprecation_message = ( f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." " `clip_sample` should be set to False in the configuration file. Please make sure to update the" @@ -436,7 +508,12 @@ def __init__( " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" ) - deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) + deprecate( + "clip_sample not set", + "1.0.0", + deprecation_message, + standard_warn=False, + ) new_config = dict(scheduler.config) new_config["clip_sample"] = False scheduler._internal_dict = FrozenDict(new_config) @@ -457,10 +534,16 @@ def __init__( " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." ) - is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( + is_unet_version_less_0_9_0 = hasattr( + unet.config, "_diffusers_version" + ) and version.parse( version.parse(unet.config._diffusers_version).base_version - ) < version.parse("0.9.0.dev0") - is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + ) < version.parse( + "0.9.0.dev0" + ) + is_unet_sample_size_less_64 = ( + hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + ) if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: deprecation_message = ( "The configuration file of the unet has set the default `sample_size` to smaller than" @@ -473,7 +556,9 @@ def __init__( " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" " the `unet/config.json` file" ) - deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) + deprecate( + "sample_size<64", "1.0.0", deprecation_message, standard_warn=False + ) new_config = dict(unet.config) new_config["sample_size"] = 64 unet._internal_dict = FrozenDict(new_config) @@ -488,14 +573,16 @@ def __init__( feature_extractor=feature_extractor, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) - self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.image_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor + ) self.register_to_config(requires_safety_checker=requires_safety_checker) self.fast_unet = FastUNet2DConditionModel(self.unet) - + @torch.no_grad() def __call__( self, - prompt: Union[str, List[str]] = None, + prompt: Union[str, List[str]] = None, height: Optional[int] = None, width: Optional[int] = None, num_inference_steps: int = 50, @@ -571,7 +658,9 @@ def __call__( # 3. 
Encode input prompt lora_scale = ( - self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None + self.cross_attention_kwargs.get("scale", None) + if self.cross_attention_kwargs is not None + else None ) prompt_embeds, negative_prompt_embeds = self.encode_prompt( @@ -615,13 +704,18 @@ def __call__( # 6.5 Optionally get Guidance Scale Embedding timestep_cond = None if self.unet.config.time_cond_proj_dim is not None: - guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) + guidance_scale_tensor = torch.tensor( + self.guidance_scale - 1 + ).repeat(batch_size * num_images_per_prompt) timestep_cond = self.get_guidance_scale_embedding( - guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim + guidance_scale_tensor, + embedding_dim=self.unet.config.time_cond_proj_dim, ).to(device=device, dtype=latents.dtype) # 7. Denoising loop - num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + num_warmup_steps = ( + len(timesteps) - num_inference_steps * self.scheduler.order + ) self._num_timesteps = len(timesteps) prv_features = None @@ -643,12 +737,18 @@ def __call__( # interval_seq, pow = sample_from_quad(num_inference_steps, num_inference_steps//cache_interval, pow=pow)#[0, 3, 6, 9, 12, 16, 22, 28, 35, 43,] interval_seq = sorted(interval_seq) - + with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): # expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents - latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + latent_model_input = ( + torch.cat([latents] * 2) + if self.do_classifier_free_guidance + else latents + ) + latent_model_input = self.scheduler.scale_model_input( + latent_model_input, t + ) if i in interval_seq or cache_interval == 1: prv_features = None @@ -702,43 +802,66 @@ def __call__( cache_block_id=cache_block_id, return_dict=False, ) - + # perform guidance if self.do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + noise_pred = noise_pred_uncond + self.guidance_scale * ( + noise_pred_text - noise_pred_uncond + ) if self.do_classifier_free_guidance and self.guidance_rescale > 0.0: # Based on 3.4. 
in https://arxiv.org/pdf/2305.08891.pdf - noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale) + noise_pred = rescale_noise_cfg( + noise_pred, + noise_pred_text, + guidance_rescale=self.guidance_rescale, + ) # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + latents = self.scheduler.step( + noise_pred, t, latents, **extra_step_kwargs, return_dict=False + )[0] if callback_on_step_end is not None: callback_kwargs = {} for k in callback_on_step_end_tensor_inputs: callback_kwargs[k] = locals()[k] - callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + callback_outputs = callback_on_step_end( + self, i, t, callback_kwargs + ) latents = callback_outputs.pop("latents", latents) - prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) - negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + prompt_embeds = callback_outputs.pop( + "prompt_embeds", prompt_embeds + ) + negative_prompt_embeds = callback_outputs.pop( + "negative_prompt_embeds", negative_prompt_embeds + ) # call the callback, if provided - if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + if i == len(timesteps) - 1 or ( + (i + 1) > num_warmup_steps + and (i + 1) % self.scheduler.order == 0 + ): progress_bar.update() if callback is not None and i % callback_steps == 0: step_idx = i // getattr(self.scheduler, "order", 1) callback(step_idx, t, latents) if not output_type == "latent": if diffusers_version > diffusers_0223_v: - image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[ - 0 - ] + image = self.vae.decode( + latents / self.vae.config.scaling_factor, + return_dict=False, + generator=generator, + )[0] else: - image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] - image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + image = self.vae.decode( + latents / self.vae.config.scaling_factor, return_dict=False + )[0] + image, has_nsfw_concept = self.run_safety_checker( + image, device, prompt_embeds.dtype + ) else: image = latents has_nsfw_concept = None @@ -748,22 +871,30 @@ def __call__( else: do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] - image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + image = self.image_processor.postprocess( + image, output_type=output_type, do_denormalize=do_denormalize + ) if diffusers_version > diffusers_0202_v: # Offload all models self.maybe_free_model_hooks() else: # Offload last model to CPU - if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + if ( + hasattr(self, "final_offload_hook") + and self.final_offload_hook is not None + ): self.final_offload_hook.offload() if not return_dict: return (image, has_nsfw_concept) - return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) + return StableDiffusionPipelineOutput( + images=image, nsfw_content_detected=has_nsfw_concept + ) elif diffusers_version < diffusers_0270_v: + class StableDiffusionPipeline(DiffusersStableDiffusionPipeline): if diffusers_version > diffusers_0240_v: model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae" @@ -785,7 +916,10 @@ def __init__( image_encoder: CLIPVisionModelWithProjection = None, 
requires_safety_checker: bool = True, ): - if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: + if ( + hasattr(scheduler.config, "steps_offset") + and scheduler.config.steps_offset != 1 + ): deprecation_message = ( f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " @@ -794,12 +928,17 @@ def __init__( " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" " file" ) - deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) + deprecate( + "steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False + ) new_config = dict(scheduler.config) new_config["steps_offset"] = 1 scheduler._internal_dict = FrozenDict(new_config) - if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: + if ( + hasattr(scheduler.config, "clip_sample") + and scheduler.config.clip_sample is True + ): deprecation_message = ( f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." " `clip_sample` should be set to False in the configuration file. Please make sure to update the" @@ -807,7 +946,12 @@ def __init__( " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" ) - deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) + deprecate( + "clip_sample not set", + "1.0.0", + deprecation_message, + standard_warn=False, + ) new_config = dict(scheduler.config) new_config["clip_sample"] = False scheduler._internal_dict = FrozenDict(new_config) @@ -828,10 +972,16 @@ def __init__( " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." 
) - is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( + is_unet_version_less_0_9_0 = hasattr( + unet.config, "_diffusers_version" + ) and version.parse( version.parse(unet.config._diffusers_version).base_version - ) < version.parse("0.9.0.dev0") - is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + ) < version.parse( + "0.9.0.dev0" + ) + is_unet_sample_size_less_64 = ( + hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + ) if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: deprecation_message = ( "The configuration file of the unet has set the default `sample_size` to smaller than" @@ -844,7 +994,9 @@ def __init__( " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" " the `unet/config.json` file" ) - deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) + deprecate( + "sample_size<64", "1.0.0", deprecation_message, standard_warn=False + ) new_config = dict(unet.config) new_config["sample_size"] = 64 unet._internal_dict = FrozenDict(new_config) @@ -860,14 +1012,16 @@ def __init__( image_encoder=image_encoder, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) - self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.image_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor + ) self.register_to_config(requires_safety_checker=requires_safety_checker) self.fast_unet = FastUNet2DConditionModel(self.unet) - + @torch.no_grad() def __call__( self, - prompt: Union[str, List[str]] = None, + prompt: Union[str, List[str]] = None, height: Optional[int] = None, width: Optional[int] = None, num_inference_steps: int = 50, @@ -947,7 +1101,9 @@ def __call__( # 3. Encode input prompt lora_scale = ( - self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None + self.cross_attention_kwargs.get("scale", None) + if self.cross_attention_kwargs is not None + else None ) prompt_embeds, negative_prompt_embeds = self.encode_prompt( @@ -974,17 +1130,28 @@ def __call__( ) else: if diffusers_version > diffusers_0240_v: - output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True + output_hidden_state = ( + False + if isinstance(self.unet.encoder_hid_proj, ImageProjection) + else True + ) image_embeds, negative_image_embeds = self.encode_image( - ip_adapter_image, device, num_images_per_prompt, output_hidden_state + ip_adapter_image, + device, + num_images_per_prompt, + output_hidden_state, ) else: - image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt) + image_embeds, negative_image_embeds = self.encode_image( + ip_adapter_image, device, num_images_per_prompt + ) if self.do_classifier_free_guidance: image_embeds = torch.cat([negative_image_embeds, image_embeds]) # 4. Prepare timesteps - timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, device, timesteps + ) # 5. Prepare latent variables num_channels_latents = self.unet.config.in_channels @@ -1002,20 +1169,26 @@ def __call__( # 6. Prepare extra step kwargs. 
TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) - # 6.1 Add image embeds for IP-Adapter - added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None + added_cond_kwargs = ( + {"image_embeds": image_embeds} if ip_adapter_image is not None else None + ) # 6.2 Optionally get Guidance Scale Embedding timestep_cond = None if self.unet.config.time_cond_proj_dim is not None: - guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) + guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat( + batch_size * num_images_per_prompt + ) timestep_cond = self.get_guidance_scale_embedding( - guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim + guidance_scale_tensor, + embedding_dim=self.unet.config.time_cond_proj_dim, ).to(device=device, dtype=latents.dtype) # 7. Denoising loop - num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + num_warmup_steps = ( + len(timesteps) - num_inference_steps * self.scheduler.order + ) self._num_timesteps = len(timesteps) prv_features = None @@ -1037,15 +1210,21 @@ def __call__( # interval_seq, pow = sample_from_quad(num_inference_steps, num_inference_steps//cache_interval, pow=pow)#[0, 3, 6, 9, 12, 16, 22, 28, 35, 43,] interval_seq = sorted(interval_seq) - + with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): if diffusers_version > diffusers_0240_v: if self.interrupt: continue # expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents - latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + latent_model_input = ( + torch.cat([latents] * 2) + if self.do_classifier_free_guidance + else latents + ) + latent_model_input = self.scheduler.scale_model_input( + latent_model_input, t + ) if i in interval_seq or cache_interval == 1: prv_features = None @@ -1103,43 +1282,66 @@ def __call__( cache_block_id=cache_block_id, return_dict=False, ) - + # perform guidance if self.do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + noise_pred = noise_pred_uncond + self.guidance_scale * ( + noise_pred_text - noise_pred_uncond + ) if self.do_classifier_free_guidance and self.guidance_rescale > 0.0: # Based on 3.4. 
in https://arxiv.org/pdf/2305.08891.pdf - noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale) + noise_pred = rescale_noise_cfg( + noise_pred, + noise_pred_text, + guidance_rescale=self.guidance_rescale, + ) # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + latents = self.scheduler.step( + noise_pred, t, latents, **extra_step_kwargs, return_dict=False + )[0] if callback_on_step_end is not None: callback_kwargs = {} for k in callback_on_step_end_tensor_inputs: callback_kwargs[k] = locals()[k] - callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + callback_outputs = callback_on_step_end( + self, i, t, callback_kwargs + ) latents = callback_outputs.pop("latents", latents) - prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) - negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + prompt_embeds = callback_outputs.pop( + "prompt_embeds", prompt_embeds + ) + negative_prompt_embeds = callback_outputs.pop( + "negative_prompt_embeds", negative_prompt_embeds + ) # call the callback, if provided - if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + if i == len(timesteps) - 1 or ( + (i + 1) > num_warmup_steps + and (i + 1) % self.scheduler.order == 0 + ): progress_bar.update() if callback is not None and i % callback_steps == 0: step_idx = i // getattr(self.scheduler, "order", 1) callback(step_idx, t, latents) if not output_type == "latent": if diffusers_version > diffusers_0223_v: - image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[ - 0 - ] + image = self.vae.decode( + latents / self.vae.config.scaling_factor, + return_dict=False, + generator=generator, + )[0] else: - image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] - image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + image = self.vae.decode( + latents / self.vae.config.scaling_factor, return_dict=False + )[0] + image, has_nsfw_concept = self.run_safety_checker( + image, device, prompt_embeds.dtype + ) else: image = latents has_nsfw_concept = None @@ -1149,20 +1351,28 @@ def __call__( else: do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] - image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + image = self.image_processor.postprocess( + image, output_type=output_type, do_denormalize=do_denormalize + ) if diffusers_version > diffusers_0202_v: # Offload all models self.maybe_free_model_hooks() else: # Offload last model to CPU - if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + if ( + hasattr(self, "final_offload_hook") + and self.final_offload_hook is not None + ): self.final_offload_hook.offload() if not return_dict: return (image, has_nsfw_concept) - return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) + return StableDiffusionPipelineOutput( + images=image, nsfw_content_detected=has_nsfw_concept + ) + else: class StableDiffusionPipeline(DiffusersStableDiffusionPipeline): @@ -1183,7 +1393,10 @@ def __init__( image_encoder: CLIPVisionModelWithProjection = None, requires_safety_checker: bool = True, ): - if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: + if ( + 
hasattr(scheduler.config, "steps_offset") + and scheduler.config.steps_offset != 1 + ): deprecation_message = ( f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " @@ -1192,12 +1405,17 @@ def __init__( " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" " file" ) - deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) + deprecate( + "steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False + ) new_config = dict(scheduler.config) new_config["steps_offset"] = 1 scheduler._internal_dict = FrozenDict(new_config) - if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: + if ( + hasattr(scheduler.config, "clip_sample") + and scheduler.config.clip_sample is True + ): deprecation_message = ( f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." " `clip_sample` should be set to False in the configuration file. Please make sure to update the" @@ -1205,7 +1423,12 @@ def __init__( " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" ) - deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) + deprecate( + "clip_sample not set", + "1.0.0", + deprecation_message, + standard_warn=False, + ) new_config = dict(scheduler.config) new_config["clip_sample"] = False scheduler._internal_dict = FrozenDict(new_config) @@ -1226,10 +1449,16 @@ def __init__( " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." 
) - is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( + is_unet_version_less_0_9_0 = hasattr( + unet.config, "_diffusers_version" + ) and version.parse( version.parse(unet.config._diffusers_version).base_version - ) < version.parse("0.9.0.dev0") - is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + ) < version.parse( + "0.9.0.dev0" + ) + is_unet_sample_size_less_64 = ( + hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + ) if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: deprecation_message = ( "The configuration file of the unet has set the default `sample_size` to smaller than" @@ -1242,7 +1471,9 @@ def __init__( " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" " the `unet/config.json` file" ) - deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) + deprecate( + "sample_size<64", "1.0.0", deprecation_message, standard_warn=False + ) new_config = dict(unet.config) new_config["sample_size"] = 64 unet._internal_dict = FrozenDict(new_config) @@ -1258,14 +1489,16 @@ def __init__( image_encoder=image_encoder, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) - self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.image_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor + ) self.register_to_config(requires_safety_checker=requires_safety_checker) self.fast_unet = FastUNet2DConditionModel(self.unet) @torch.no_grad() def __call__( self, - prompt: Union[str, List[str]] = None, + prompt: Union[str, List[str]] = None, height: Optional[int] = None, width: Optional[int] = None, num_inference_steps: int = 50, @@ -1347,7 +1580,9 @@ def __call__( # 3. Encode input prompt lora_scale = ( - self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None + self.cross_attention_kwargs.get("scale", None) + if self.cross_attention_kwargs is not None + else None ) prompt_embeds, negative_prompt_embeds = self.encode_prompt( @@ -1377,7 +1612,9 @@ def __call__( ) # 4. Prepare timesteps - timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, device, timesteps + ) # 5. Prepare latent variables num_channels_latents = self.unet.config.in_channels @@ -1395,7 +1632,6 @@ def __call__( # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) - # 6.1 Add image embeds for IP-Adapter added_cond_kwargs = ( {"image_embeds": image_embeds} @@ -1406,13 +1642,18 @@ def __call__( # 6.2 Optionally get Guidance Scale Embedding timestep_cond = None if self.unet.config.time_cond_proj_dim is not None: - guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) + guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat( + batch_size * num_images_per_prompt + ) timestep_cond = self.get_guidance_scale_embedding( - guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim + guidance_scale_tensor, + embedding_dim=self.unet.config.time_cond_proj_dim, ).to(device=device, dtype=latents.dtype) # 7. 
Denoising loop - num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + num_warmup_steps = ( + len(timesteps) - num_inference_steps * self.scheduler.order + ) self._num_timesteps = len(timesteps) prv_features = None @@ -1434,14 +1675,20 @@ def __call__( # interval_seq, pow = sample_from_quad(num_inference_steps, num_inference_steps//cache_interval, pow=pow)#[0, 3, 6, 9, 12, 16, 22, 28, 35, 43,] interval_seq = sorted(interval_seq) - + with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): if self.interrupt: continue # expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents - latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + latent_model_input = ( + torch.cat([latents] * 2) + if self.do_classifier_free_guidance + else latents + ) + latent_model_input = self.scheduler.scale_model_input( + latent_model_input, t + ) if i in interval_seq or cache_interval == 1: prv_features = None @@ -1473,40 +1720,61 @@ def __call__( cache_block_id=cache_block_id, return_dict=False, ) - + # perform guidance if self.do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + noise_pred = noise_pred_uncond + self.guidance_scale * ( + noise_pred_text - noise_pred_uncond + ) if self.do_classifier_free_guidance and self.guidance_rescale > 0.0: # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf - noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale) + noise_pred = rescale_noise_cfg( + noise_pred, + noise_pred_text, + guidance_rescale=self.guidance_rescale, + ) # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + latents = self.scheduler.step( + noise_pred, t, latents, **extra_step_kwargs, return_dict=False + )[0] if callback_on_step_end is not None: callback_kwargs = {} for k in callback_on_step_end_tensor_inputs: callback_kwargs[k] = locals()[k] - callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + callback_outputs = callback_on_step_end( + self, i, t, callback_kwargs + ) latents = callback_outputs.pop("latents", latents) - prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) - negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + prompt_embeds = callback_outputs.pop( + "prompt_embeds", prompt_embeds + ) + negative_prompt_embeds = callback_outputs.pop( + "negative_prompt_embeds", negative_prompt_embeds + ) # call the callback, if provided - if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + if i == len(timesteps) - 1 or ( + (i + 1) > num_warmup_steps + and (i + 1) % self.scheduler.order == 0 + ): progress_bar.update() if callback is not None and i % callback_steps == 0: step_idx = i // getattr(self.scheduler, "order", 1) callback(step_idx, t, latents) if not output_type == "latent": - image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[ - 0 - ] - image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + image = self.vae.decode( + latents / self.vae.config.scaling_factor, + return_dict=False, + generator=generator, + )[0] + image, 
has_nsfw_concept = self.run_safety_checker( + image, device, prompt_embeds.dtype + ) else: image = latents has_nsfw_concept = None @@ -1516,7 +1784,9 @@ def __call__( else: do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] - image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + image = self.image_processor.postprocess( + image, output_type=output_type, do_denormalize=do_denormalize + ) # Offload all models self.maybe_free_model_hooks() @@ -1524,4 +1794,6 @@ def __call__( if not return_dict: return (image, has_nsfw_concept) - return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) + return StableDiffusionPipelineOutput( + images=image, nsfw_content_detected=has_nsfw_concept + ) diff --git a/onediff_diffusers_extensions/onediffx/deep_cache/pipeline_stable_diffusion_xl.py b/onediff_diffusers_extensions/onediffx/deep_cache/pipeline_stable_diffusion_xl.py index bec571a18..55962a185 100644 --- a/onediff_diffusers_extensions/onediffx/deep_cache/pipeline_stable_diffusion_xl.py +++ b/onediff_diffusers_extensions/onediffx/deep_cache/pipeline_stable_diffusion_xl.py @@ -11,9 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import importlib.metadata from typing import Any, Callable, Dict, List, Optional, Tuple, Union -import importlib.metadata from packaging import version diffusers_0202_v = version.parse("0.20.2") @@ -27,22 +27,22 @@ diffusers_version = version.parse(importlib.metadata.version("diffusers")) import torch -from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer + +from diffusers import StableDiffusionXLPipeline as DiffusersStableDiffusionXLPipeline from diffusers.image_processor import VaeImageProcessor from diffusers.models import AutoencoderKL -from diffusers.schedulers import KarrasDiffusionSchedulers -from diffusers.utils import ( - is_invisible_watermark_available, - logging, -) from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipelineOutput +from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import ( + rescale_noise_cfg, +) +from diffusers.schedulers import KarrasDiffusionSchedulers +from diffusers.utils import is_invisible_watermark_available, logging +from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer -from diffusers import StableDiffusionXLPipeline as DiffusersStableDiffusionXLPipeline -from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import rescale_noise_cfg +from .models.fast_unet_2d_condition import FastUNet2DConditionModel from .models.unet_2d_condition import UNet2DConditionModel -from .models.fast_unet_2d_condition import FastUNet2DConditionModel if diffusers_version > diffusers_0214_v: from diffusers.utils import is_torch_xla_available @@ -55,9 +55,11 @@ XLA_AVAILABLE = False if diffusers_version >= diffusers_0240_v: - from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection from diffusers.image_processor import PipelineImageInput - from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import retrieve_timesteps + from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import ( + retrieve_timesteps, + ) + from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection from .models.pipeline_utils import enable_deep_cache_pipeline @@ -72,13 +74,14 
@@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name + def sample_from_quad_center(total_numbers, n_samples, center, pow=1.2): while pow > 1: # Generate linearly spaced values between 0 and a max value x_values = np.linspace( (-center) ** (1 / pow), (total_numbers - center) ** (1 / pow), n_samples + 1 ) - indices = [0] + [x + center for x in np.unique(np.int32(x_values ** pow))[1:-1]] + indices = [0] + [x + center for x in np.unique(np.int32(x_values**pow))[1:-1]] if len(indices) == n_samples: break pow -= 0.02 @@ -90,6 +93,7 @@ def sample_from_quad_center(total_numbers, n_samples, center, pow=1.2): if diffusers_version <= diffusers_0202_v: + class StableDiffusionXLPipeline(DiffusersStableDiffusionXLPipeline): def __init__( self, @@ -112,12 +116,20 @@ def __init__( unet=unet, scheduler=scheduler, ) - self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) + self.register_to_config( + force_zeros_for_empty_prompt=force_zeros_for_empty_prompt + ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) - self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.image_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor + ) self.default_sample_size = self.unet.config.sample_size - add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available() + add_watermarker = ( + add_watermarker + if add_watermarker is not None + else is_invisible_watermark_available() + ) self.fast_unet = FastUNet2DConditionModel(self.unet) @@ -132,11 +144,11 @@ def __init__( self.watermark = StableDiffusionXLWatermarker() else: self.watermark = None - + def upcast_vae(self): super().upcast_vae() self.vae_upcasted = True - + @torch.no_grad() def __call__( self, @@ -212,7 +224,9 @@ def __call__( # 3. Encode input prompt text_encoder_lora_scale = ( - cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + cross_attention_kwargs.get("scale", None) + if cross_attention_kwargs is not None + else None ) ( prompt_embeds, @@ -258,32 +272,50 @@ def __call__( # 7. Prepare added time ids & embeddings add_text_embeds = pooled_prompt_embeds add_time_ids = self._get_add_time_ids( - original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype + original_size, + crops_coords_top_left, + target_size, + dtype=prompt_embeds.dtype, ) if do_classifier_free_guidance: - prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) - add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) + prompt_embeds = torch.cat( + [negative_prompt_embeds, prompt_embeds], dim=0 + ) + add_text_embeds = torch.cat( + [negative_pooled_prompt_embeds, add_text_embeds], dim=0 + ) add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0) prompt_embeds = prompt_embeds.to(device) add_text_embeds = add_text_embeds.to(device) - add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) + add_time_ids = add_time_ids.to(device).repeat( + batch_size * num_images_per_prompt, 1 + ) # 8. 
Denoising loop - num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + num_warmup_steps = max( + len(timesteps) - num_inference_steps * self.scheduler.order, 0 + ) # 7.1 Apply denoising_end - if denoising_end is not None and type(denoising_end) == float and denoising_end > 0 and denoising_end < 1: + if ( + denoising_end is not None + and type(denoising_end) == float + and denoising_end > 0 + and denoising_end < 1 + ): discrete_timestep_cutoff = int( round( self.scheduler.config.num_train_timesteps - (denoising_end * self.scheduler.config.num_train_timesteps) ) ) - num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) + num_inference_steps = len( + list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)) + ) timesteps = timesteps[:num_inference_steps] - + if cache_interval == 1: interval_seq = list(range(num_inference_steps)) else: @@ -300,12 +332,21 @@ def __call__( with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): # expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = ( + torch.cat([latents] * 2) + if do_classifier_free_guidance + else latents + ) - latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + latent_model_input = self.scheduler.scale_model_input( + latent_model_input, t + ) # predict the noise residual - added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + added_cond_kwargs = { + "text_embeds": add_text_embeds, + "time_ids": add_time_ids, + } if i in interval_seq or cache_interval == 1: prv_features = None @@ -334,21 +375,32 @@ def __call__( cache_block_id=cache_block_id, return_dict=False, ) - + # perform guidance if do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + noise_pred = noise_pred_uncond + guidance_scale * ( + noise_pred_text - noise_pred_uncond + ) if do_classifier_free_guidance and guidance_rescale > 0.0: # Based on 3.4. 
in https://arxiv.org/pdf/2305.08891.pdf - noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) + noise_pred = rescale_noise_cfg( + noise_pred, + noise_pred_text, + guidance_rescale=guidance_rescale, + ) # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + latents = self.scheduler.step( + noise_pred, t, latents, **extra_step_kwargs, return_dict=False + )[0] # call the callback, if provided - if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + if i == len(timesteps) - 1 or ( + (i + 1) > num_warmup_steps + and (i + 1) % self.scheduler.order == 0 + ): progress_bar.update() if callback is not None and i % callback_steps == 0: callback(i, t, latents) @@ -360,7 +412,9 @@ def __call__( latents = latents.to(dtype) if not output_type == "latent": - image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image = self.vae.decode( + latents / self.vae.config.scaling_factor, return_dict=False + )[0] else: image = latents return StableDiffusionXLPipelineOutput(images=image) @@ -372,7 +426,10 @@ def __call__( image = self.image_processor.postprocess(image, output_type=output_type) # Offload last model to CPU - if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + if ( + hasattr(self, "final_offload_hook") + and self.final_offload_hook is not None + ): self.final_offload_hook.offload() if not return_dict: @@ -381,8 +438,10 @@ def __call__( return StableDiffusionXLPipelineOutput(images=image) elif diffusers_version <= diffusers_0214_v: + class StableDiffusionXLPipeline(DiffusersStableDiffusionXLPipeline): model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae" + def __init__( self, vae: AutoencoderKL, @@ -404,12 +463,20 @@ def __init__( unet=unet, scheduler=scheduler, ) - self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) + self.register_to_config( + force_zeros_for_empty_prompt=force_zeros_for_empty_prompt + ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) - self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.image_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor + ) self.default_sample_size = self.unet.config.sample_size - add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available() + add_watermarker = ( + add_watermarker + if add_watermarker is not None + else is_invisible_watermark_available() + ) self.fast_unet = FastUNet2DConditionModel(self.unet) @@ -424,11 +491,11 @@ def __init__( self.watermark = StableDiffusionXLWatermarker() else: self.watermark = None - + def upcast_vae(self): super().upcast_vae() self.vae_upcasted = True - + @torch.no_grad() def __call__( self, @@ -507,7 +574,9 @@ def __call__( # 3. Encode input prompt text_encoder_lora_scale = ( - cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + cross_attention_kwargs.get("scale", None) + if cross_attention_kwargs is not None + else None ) ( prompt_embeds, @@ -553,7 +622,10 @@ def __call__( # 7. 
Prepare added time ids & embeddings add_text_embeds = pooled_prompt_embeds add_time_ids = self._get_add_time_ids( - original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype + original_size, + crops_coords_top_left, + target_size, + dtype=prompt_embeds.dtype, ) if negative_original_size is not None and negative_target_size is not None: @@ -567,28 +639,43 @@ def __call__( negative_add_time_ids = add_time_ids if do_classifier_free_guidance: - prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) - add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) + prompt_embeds = torch.cat( + [negative_prompt_embeds, prompt_embeds], dim=0 + ) + add_text_embeds = torch.cat( + [negative_pooled_prompt_embeds, add_text_embeds], dim=0 + ) add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0) prompt_embeds = prompt_embeds.to(device) add_text_embeds = add_text_embeds.to(device) - add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) + add_time_ids = add_time_ids.to(device).repeat( + batch_size * num_images_per_prompt, 1 + ) # 8. Denoising loop - num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + num_warmup_steps = max( + len(timesteps) - num_inference_steps * self.scheduler.order, 0 + ) # 7.1 Apply denoising_end - if denoising_end is not None and isinstance(denoising_end, float) and denoising_end > 0 and denoising_end < 1: + if ( + denoising_end is not None + and isinstance(denoising_end, float) + and denoising_end > 0 + and denoising_end < 1 + ): discrete_timestep_cutoff = int( round( self.scheduler.config.num_train_timesteps - (denoising_end * self.scheduler.config.num_train_timesteps) ) ) - num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) + num_inference_steps = len( + list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)) + ) timesteps = timesteps[:num_inference_steps] - + if cache_interval == 1: interval_seq = list(range(num_inference_steps)) else: @@ -605,12 +692,21 @@ def __call__( with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): # expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = ( + torch.cat([latents] * 2) + if do_classifier_free_guidance + else latents + ) - latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + latent_model_input = self.scheduler.scale_model_input( + latent_model_input, t + ) # predict the noise residual - added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + added_cond_kwargs = { + "text_embeds": add_text_embeds, + "time_ids": add_time_ids, + } if i in interval_seq or cache_interval == 1: prv_features = None @@ -639,21 +735,32 @@ def __call__( cache_block_id=cache_block_id, return_dict=False, ) - + # perform guidance if do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + noise_pred = noise_pred_uncond + guidance_scale * ( + noise_pred_text - noise_pred_uncond + ) if do_classifier_free_guidance and guidance_rescale > 0.0: # Based on 3.4. 
in https://arxiv.org/pdf/2305.08891.pdf - noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) + noise_pred = rescale_noise_cfg( + noise_pred, + noise_pred_text, + guidance_rescale=guidance_rescale, + ) # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + latents = self.scheduler.step( + noise_pred, t, latents, **extra_step_kwargs, return_dict=False + )[0] # call the callback, if provided - if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + if i == len(timesteps) - 1 or ( + (i + 1) > num_warmup_steps + and (i + 1) % self.scheduler.order == 0 + ): progress_bar.update() if callback is not None and i % callback_steps == 0: callback(i, t, latents) @@ -665,7 +772,9 @@ def __call__( dtype = next(iter(self.vae.post_quant_conv.parameters())).dtype latents = latents.to(dtype) - image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image = self.vae.decode( + latents / self.vae.config.scaling_factor, return_dict=False + )[0] else: image = latents @@ -684,10 +793,17 @@ def __call__( return (image,) return StableDiffusionXLPipelineOutput(images=image) + elif diffusers_version <= diffusers_0231_v: + class StableDiffusionXLPipeline(DiffusersStableDiffusionXLPipeline): model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae" - _optional_components = ["tokenizer", "tokenizer_2", "text_encoder", "text_encoder_2"] + _optional_components = [ + "tokenizer", + "tokenizer_2", + "text_encoder", + "text_encoder_2", + ] _callback_tensor_inputs = [ "latents", "prompt_embeds", @@ -697,6 +813,7 @@ class StableDiffusionXLPipeline(DiffusersStableDiffusionXLPipeline): "negative_pooled_prompt_embeds", "negative_add_time_ids", ] + def __init__( self, vae: AutoencoderKL, @@ -718,12 +835,20 @@ def __init__( unet=unet, scheduler=scheduler, ) - self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) + self.register_to_config( + force_zeros_for_empty_prompt=force_zeros_for_empty_prompt + ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) - self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.image_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor + ) self.default_sample_size = self.unet.config.sample_size - add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available() + add_watermarker = ( + add_watermarker + if add_watermarker is not None + else is_invisible_watermark_available() + ) self.fast_unet = FastUNet2DConditionModel(self.unet) @@ -738,11 +863,11 @@ def __init__( self.watermark = StableDiffusionXLWatermarker() else: self.watermark = None - + def upcast_vae(self): super().upcast_vae() self.vae_upcasted = True - + @torch.no_grad() def __call__( self, @@ -840,7 +965,9 @@ def __call__( # 3. 
Encode input prompt lora_scale = ( - self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None + self.cross_attention_kwargs.get("scale", None) + if self.cross_attention_kwargs is not None + else None ) ( @@ -911,16 +1038,24 @@ def __call__( negative_add_time_ids = add_time_ids if self.do_classifier_free_guidance: - prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) - add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) + prompt_embeds = torch.cat( + [negative_prompt_embeds, prompt_embeds], dim=0 + ) + add_text_embeds = torch.cat( + [negative_pooled_prompt_embeds, add_text_embeds], dim=0 + ) add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0) prompt_embeds = prompt_embeds.to(device) add_text_embeds = add_text_embeds.to(device) - add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) + add_time_ids = add_time_ids.to(device).repeat( + batch_size * num_images_per_prompt, 1 + ) # 8. Denoising loop - num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + num_warmup_steps = max( + len(timesteps) - num_inference_steps * self.scheduler.order, 0 + ) # 8.1 Apply denoising_end if ( @@ -932,21 +1067,29 @@ def __call__( discrete_timestep_cutoff = int( round( self.scheduler.config.num_train_timesteps - - (self.denoising_end * self.scheduler.config.num_train_timesteps) + - ( + self.denoising_end + * self.scheduler.config.num_train_timesteps + ) ) ) - num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) + num_inference_steps = len( + list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)) + ) timesteps = timesteps[:num_inference_steps] - + if diffusers_version > diffusers_0223_v: # 9. 
Optionally get Guidance Scale Embedding timestep_cond = None if self.unet.config.time_cond_proj_dim is not None: - guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) + guidance_scale_tensor = torch.tensor( + self.guidance_scale - 1 + ).repeat(batch_size * num_images_per_prompt) timestep_cond = self.get_guidance_scale_embedding( - guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim + guidance_scale_tensor, + embedding_dim=self.unet.config.time_cond_proj_dim, ).to(device=device, dtype=latents.dtype) - + self._num_timesteps = len(timesteps) if cache_interval == 1: interval_seq = list(range(num_inference_steps)) @@ -964,12 +1107,21 @@ def __call__( with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): # expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + latent_model_input = ( + torch.cat([latents] * 2) + if self.do_classifier_free_guidance + else latents + ) - latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + latent_model_input = self.scheduler.scale_model_input( + latent_model_input, t + ) # predict the noise residual - added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + added_cond_kwargs = { + "text_embeds": add_text_embeds, + "time_ids": add_time_ids, + } if i in interval_seq or cache_interval == 1: prv_features = None @@ -1026,37 +1178,61 @@ def __call__( cache_block_id=cache_block_id, return_dict=False, ) - + # perform guidance if self.do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + noise_pred = noise_pred_uncond + self.guidance_scale * ( + noise_pred_text - noise_pred_uncond + ) if self.do_classifier_free_guidance and self.guidance_rescale > 0.0: # Based on 3.4. 
in https://arxiv.org/pdf/2305.08891.pdf - noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale) + noise_pred = rescale_noise_cfg( + noise_pred, + noise_pred_text, + guidance_rescale=self.guidance_rescale, + ) # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + latents = self.scheduler.step( + noise_pred, t, latents, **extra_step_kwargs, return_dict=False + )[0] if callback_on_step_end is not None: callback_kwargs = {} for k in callback_on_step_end_tensor_inputs: callback_kwargs[k] = locals()[k] - callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + callback_outputs = callback_on_step_end( + self, i, t, callback_kwargs + ) latents = callback_outputs.pop("latents", latents) - prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) - negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) - add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds) + prompt_embeds = callback_outputs.pop( + "prompt_embeds", prompt_embeds + ) + negative_prompt_embeds = callback_outputs.pop( + "negative_prompt_embeds", negative_prompt_embeds + ) + add_text_embeds = callback_outputs.pop( + "add_text_embeds", add_text_embeds + ) negative_pooled_prompt_embeds = callback_outputs.pop( - "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds + "negative_pooled_prompt_embeds", + negative_pooled_prompt_embeds, + ) + add_time_ids = callback_outputs.pop( + "add_time_ids", add_time_ids + ) + negative_add_time_ids = callback_outputs.pop( + "negative_add_time_ids", negative_add_time_ids ) - add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids) - negative_add_time_ids = callback_outputs.pop("negative_add_time_ids", negative_add_time_ids) # call the callback, if provided - if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + if i == len(timesteps) - 1 or ( + (i + 1) > num_warmup_steps + and (i + 1) % self.scheduler.order == 0 + ): progress_bar.update() if callback is not None and i % callback_steps == 0: step_idx = i // getattr(self.scheduler, "order", 1) @@ -1072,7 +1248,9 @@ def __call__( dtype = next(iter(self.vae.post_quant_conv.parameters())).dtype latents = latents.to(dtype) - image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image = self.vae.decode( + latents / self.vae.config.scaling_factor, return_dict=False + )[0] else: image = latents @@ -1093,9 +1271,12 @@ def __call__( return StableDiffusionXLPipelineOutput(images=image) elif diffusers_version < diffusers_0270_v: + class StableDiffusionXLPipeline(DiffusersStableDiffusionXLPipeline): if diffusers_version > diffusers_0240_v: - model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->unet->vae" + model_cpu_offload_seq = ( + "text_encoder->text_encoder_2->image_encoder->unet->vae" + ) else: model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae" @@ -1116,6 +1297,7 @@ class StableDiffusionXLPipeline(DiffusersStableDiffusionXLPipeline): "negative_pooled_prompt_embeds", "negative_add_time_ids", ] + def __init__( self, vae: AutoencoderKL, @@ -1141,12 +1323,20 @@ def __init__( image_encoder=image_encoder, feature_extractor=feature_extractor, ) - self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) + self.register_to_config( + force_zeros_for_empty_prompt=force_zeros_for_empty_prompt + ) 
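# DeepCache wiring happens in this constructor: FastUNet2DConditionModel
# (assigned to self.fast_unet a few lines below) is the shallow path that
# reuses the deep UNet features cached by the full path. A rough sketch of
# the schedule that the denoising loop later in this class implements
# (a paraphrase with elisions "...", not the exact onediffx code; the
# sample_from_quad_center call shape is inferred from the commented-out
# sample_from_quad sibling visible in the loop):
#
#     if cache_interval == 1:
#         interval_seq = list(range(num_inference_steps))  # recompute every step
#     else:
#         # cluster the recompute steps around `center` (returns indices, pow)
#         interval_seq, pow = sample_from_quad_center(
#             num_inference_steps, num_inference_steps // cache_interval,
#             center=center, pow=pow)
#     for i, t in enumerate(timesteps):
#         if i in interval_seq or cache_interval == 1:
#             prv_features = None  # drop the cache: a full UNet pass refreshes it
#             noise_pred, prv_features = self.unet(latent_model_input, t, ...)
#         else:
#             # shallow pass that reuses the cached deep features
#             noise_pred, prv_features = self.fast_unet(latent_model_input, t, ...)
#
# For orientation: with the usual 4-entry block_out_channels, the next line
# gives vae_scale_factor = 2 ** (4 - 1) = 8, i.e. a 512x512 image decodes
# from a 64x64 latent.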
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) - self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.image_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor + ) self.default_sample_size = self.unet.config.sample_size - add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available() + add_watermarker = ( + add_watermarker + if add_watermarker is not None + else is_invisible_watermark_available() + ) self.fast_unet = FastUNet2DConditionModel(self.unet) @@ -1161,11 +1351,11 @@ def __init__( self.watermark = StableDiffusionXLWatermarker() else: self.watermark = None - + def upcast_vae(self): super().upcast_vae() self.vae_upcasted = True - + @torch.no_grad() def __call__( self, @@ -1267,7 +1457,9 @@ def __call__( # 3. Encode input prompt lora_scale = ( - self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None + self.cross_attention_kwargs.get("scale", None) + if self.cross_attention_kwargs is not None + else None ) ( @@ -1292,7 +1484,9 @@ def __call__( ) # 4. Prepare timesteps - timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, device, timesteps + ) timesteps = self.scheduler.timesteps @@ -1338,13 +1532,19 @@ def __call__( negative_add_time_ids = add_time_ids if self.do_classifier_free_guidance: - prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) - add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) + prompt_embeds = torch.cat( + [negative_prompt_embeds, prompt_embeds], dim=0 + ) + add_text_embeds = torch.cat( + [negative_pooled_prompt_embeds, add_text_embeds], dim=0 + ) add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0) prompt_embeds = prompt_embeds.to(device) add_text_embeds = add_text_embeds.to(device) - add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) + add_time_ids = add_time_ids.to(device).repeat( + batch_size * num_images_per_prompt, 1 + ) if ip_adapter_image is not None: if diffusers_version > diffusers_0251_v: @@ -1353,18 +1553,29 @@ def __call__( ) else: if diffusers_version > diffusers_0240_v: - output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True + output_hidden_state = ( + False + if isinstance(self.unet.encoder_hid_proj, ImageProjection) + else True + ) image_embeds, negative_image_embeds = self.encode_image( - ip_adapter_image, device, num_images_per_prompt, output_hidden_state + ip_adapter_image, + device, + num_images_per_prompt, + output_hidden_state, ) else: - image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt) + image_embeds, negative_image_embeds = self.encode_image( + ip_adapter_image, device, num_images_per_prompt + ) if self.do_classifier_free_guidance: image_embeds = torch.cat([negative_image_embeds, image_embeds]) image_embeds = image_embeds.to(device) - + # 8. 
Denoising loop - num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + num_warmup_steps = max( + len(timesteps) - num_inference_steps * self.scheduler.order, 0 + ) # 8.1 Apply denoising_end if ( @@ -1376,21 +1587,29 @@ def __call__( discrete_timestep_cutoff = int( round( self.scheduler.config.num_train_timesteps - - (self.denoising_end * self.scheduler.config.num_train_timesteps) + - ( + self.denoising_end + * self.scheduler.config.num_train_timesteps + ) ) ) - num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) + num_inference_steps = len( + list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)) + ) timesteps = timesteps[:num_inference_steps] - + if diffusers_version > diffusers_0223_v: # 9. Optionally get Guidance Scale Embedding timestep_cond = None if self.unet.config.time_cond_proj_dim is not None: - guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) + guidance_scale_tensor = torch.tensor( + self.guidance_scale - 1 + ).repeat(batch_size * num_images_per_prompt) timestep_cond = self.get_guidance_scale_embedding( - guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim + guidance_scale_tensor, + embedding_dim=self.unet.config.time_cond_proj_dim, ).to(device=device, dtype=latents.dtype) - + self._num_timesteps = len(timesteps) if cache_interval == 1: interval_seq = list(range(num_inference_steps)) @@ -1412,12 +1631,21 @@ def __call__( continue # expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + latent_model_input = ( + torch.cat([latents] * 2) + if self.do_classifier_free_guidance + else latents + ) - latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + latent_model_input = self.scheduler.scale_model_input( + latent_model_input, t + ) # predict the noise residual - added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + added_cond_kwargs = { + "text_embeds": add_text_embeds, + "time_ids": add_time_ids, + } if ip_adapter_image is not None: added_cond_kwargs["image_embeds"] = image_embeds @@ -1476,37 +1704,61 @@ def __call__( cache_block_id=cache_block_id, return_dict=False, ) - + # perform guidance if self.do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + noise_pred = noise_pred_uncond + self.guidance_scale * ( + noise_pred_text - noise_pred_uncond + ) if self.do_classifier_free_guidance and self.guidance_rescale > 0.0: # Based on 3.4. 
in https://arxiv.org/pdf/2305.08891.pdf - noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale) + noise_pred = rescale_noise_cfg( + noise_pred, + noise_pred_text, + guidance_rescale=self.guidance_rescale, + ) # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + latents = self.scheduler.step( + noise_pred, t, latents, **extra_step_kwargs, return_dict=False + )[0] if callback_on_step_end is not None: callback_kwargs = {} for k in callback_on_step_end_tensor_inputs: callback_kwargs[k] = locals()[k] - callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + callback_outputs = callback_on_step_end( + self, i, t, callback_kwargs + ) latents = callback_outputs.pop("latents", latents) - prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) - negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) - add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds) + prompt_embeds = callback_outputs.pop( + "prompt_embeds", prompt_embeds + ) + negative_prompt_embeds = callback_outputs.pop( + "negative_prompt_embeds", negative_prompt_embeds + ) + add_text_embeds = callback_outputs.pop( + "add_text_embeds", add_text_embeds + ) negative_pooled_prompt_embeds = callback_outputs.pop( - "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds + "negative_pooled_prompt_embeds", + negative_pooled_prompt_embeds, + ) + add_time_ids = callback_outputs.pop( + "add_time_ids", add_time_ids + ) + negative_add_time_ids = callback_outputs.pop( + "negative_add_time_ids", negative_add_time_ids ) - add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids) - negative_add_time_ids = callback_outputs.pop("negative_add_time_ids", negative_add_time_ids) # call the callback, if provided - if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + if i == len(timesteps) - 1 or ( + (i + 1) > num_warmup_steps + and (i + 1) % self.scheduler.order == 0 + ): progress_bar.update() if callback is not None and i % callback_steps == 0: step_idx = i // getattr(self.scheduler, "order", 1) @@ -1522,7 +1774,9 @@ def __call__( dtype = next(iter(self.vae.post_quant_conv.parameters())).dtype latents = latents.to(dtype) - image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image = self.vae.decode( + latents / self.vae.config.scaling_factor, return_dict=False + )[0] else: image = latents @@ -1541,7 +1795,9 @@ def __call__( return (image,) return StableDiffusionXLPipelineOutput(images=image) + else: + class StableDiffusionXLPipeline(DiffusersStableDiffusionXLPipeline): model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->unet->vae" @@ -1562,6 +1818,7 @@ class StableDiffusionXLPipeline(DiffusersStableDiffusionXLPipeline): "negative_pooled_prompt_embeds", "negative_add_time_ids", ] + def __init__( self, vae: AutoencoderKL, @@ -1587,12 +1844,20 @@ def __init__( image_encoder=image_encoder, feature_extractor=feature_extractor, ) - self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) + self.register_to_config( + force_zeros_for_empty_prompt=force_zeros_for_empty_prompt + ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) - self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.image_processor = VaeImageProcessor( + 
vae_scale_factor=self.vae_scale_factor + ) self.default_sample_size = self.unet.config.sample_size - add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available() + add_watermarker = ( + add_watermarker + if add_watermarker is not None + else is_invisible_watermark_available() + ) self.fast_unet = FastUNet2DConditionModel(self.unet) @@ -1607,11 +1872,11 @@ def __init__( self.watermark = StableDiffusionXLWatermarker() else: self.watermark = None - + def upcast_vae(self): super().upcast_vae() self.vae_upcasted = True - + @torch.no_grad() def __call__( self, @@ -1715,7 +1980,9 @@ def __call__( # 3. Encode input prompt lora_scale = ( - self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None + self.cross_attention_kwargs.get("scale", None) + if self.cross_attention_kwargs is not None + else None ) ( @@ -1740,7 +2007,9 @@ def __call__( ) # 4. Prepare timesteps - timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, device, timesteps + ) timesteps = self.scheduler.timesteps @@ -1786,13 +2055,19 @@ def __call__( negative_add_time_ids = add_time_ids if self.do_classifier_free_guidance: - prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) - add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) + prompt_embeds = torch.cat( + [negative_prompt_embeds, prompt_embeds], dim=0 + ) + add_text_embeds = torch.cat( + [negative_pooled_prompt_embeds, add_text_embeds], dim=0 + ) add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0) prompt_embeds = prompt_embeds.to(device) add_text_embeds = add_text_embeds.to(device) - add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) + add_time_ids = add_time_ids.to(device).repeat( + batch_size * num_images_per_prompt, 1 + ) if ip_adapter_image is not None or ip_adapter_image_embeds is not None: image_embeds = self.prepare_ip_adapter_image_embeds( @@ -1802,9 +2077,11 @@ def __call__( batch_size * num_images_per_prompt, self.do_classifier_free_guidance, ) - + # 8. Denoising loop - num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + num_warmup_steps = max( + len(timesteps) - num_inference_steps * self.scheduler.order, 0 + ) # 8.1 Apply denoising_end if ( @@ -1816,20 +2093,28 @@ def __call__( discrete_timestep_cutoff = int( round( self.scheduler.config.num_train_timesteps - - (self.denoising_end * self.scheduler.config.num_train_timesteps) + - ( + self.denoising_end + * self.scheduler.config.num_train_timesteps + ) ) ) - num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) + num_inference_steps = len( + list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)) + ) timesteps = timesteps[:num_inference_steps] - + # 9. 
Optionally get Guidance Scale Embedding timestep_cond = None if self.unet.config.time_cond_proj_dim is not None: - guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) + guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat( + batch_size * num_images_per_prompt + ) timestep_cond = self.get_guidance_scale_embedding( - guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim + guidance_scale_tensor, + embedding_dim=self.unet.config.time_cond_proj_dim, ).to(device=device, dtype=latents.dtype) - + self._num_timesteps = len(timesteps) if cache_interval == 1: interval_seq = list(range(num_inference_steps)) @@ -1850,13 +2135,25 @@ def __call__( continue # expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + latent_model_input = ( + torch.cat([latents] * 2) + if self.do_classifier_free_guidance + else latents + ) - latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + latent_model_input = self.scheduler.scale_model_input( + latent_model_input, t + ) # predict the noise residual - added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} - if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + added_cond_kwargs = { + "text_embeds": add_text_embeds, + "time_ids": add_time_ids, + } + if ( + ip_adapter_image is not None + or ip_adapter_image_embeds is not None + ): added_cond_kwargs["image_embeds"] = image_embeds if i in interval_seq or cache_interval == 1: @@ -1875,7 +2172,7 @@ def __call__( cache_block_id=cache_block_id, return_dict=False, ) - + else: noise_pred, prv_features = self.fast_unet( latent_model_input, @@ -1889,37 +2186,61 @@ def __call__( cache_block_id=cache_block_id, return_dict=False, ) - + # perform guidance if self.do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + noise_pred = noise_pred_uncond + self.guidance_scale * ( + noise_pred_text - noise_pred_uncond + ) if self.do_classifier_free_guidance and self.guidance_rescale > 0.0: # Based on 3.4. 
in https://arxiv.org/pdf/2305.08891.pdf - noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale) + noise_pred = rescale_noise_cfg( + noise_pred, + noise_pred_text, + guidance_rescale=self.guidance_rescale, + ) # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + latents = self.scheduler.step( + noise_pred, t, latents, **extra_step_kwargs, return_dict=False + )[0] if callback_on_step_end is not None: callback_kwargs = {} for k in callback_on_step_end_tensor_inputs: callback_kwargs[k] = locals()[k] - callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + callback_outputs = callback_on_step_end( + self, i, t, callback_kwargs + ) latents = callback_outputs.pop("latents", latents) - prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) - negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) - add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds) + prompt_embeds = callback_outputs.pop( + "prompt_embeds", prompt_embeds + ) + negative_prompt_embeds = callback_outputs.pop( + "negative_prompt_embeds", negative_prompt_embeds + ) + add_text_embeds = callback_outputs.pop( + "add_text_embeds", add_text_embeds + ) negative_pooled_prompt_embeds = callback_outputs.pop( - "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds + "negative_pooled_prompt_embeds", + negative_pooled_prompt_embeds, + ) + add_time_ids = callback_outputs.pop( + "add_time_ids", add_time_ids + ) + negative_add_time_ids = callback_outputs.pop( + "negative_add_time_ids", negative_add_time_ids ) - add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids) - negative_add_time_ids = callback_outputs.pop("negative_add_time_ids", negative_add_time_ids) # call the callback, if provided - if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + if i == len(timesteps) - 1 or ( + (i + 1) > num_warmup_steps + and (i + 1) % self.scheduler.order == 0 + ): progress_bar.update() if callback is not None and i % callback_steps == 0: step_idx = i // getattr(self.scheduler, "order", 1) @@ -1937,16 +2258,29 @@ def __call__( # unscale/denormalize the latents # denormalize with the mean and std if available and not None - has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None - has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None + has_latents_mean = ( + hasattr(self.vae.config, "latents_mean") + and self.vae.config.latents_mean is not None + ) + has_latents_std = ( + hasattr(self.vae.config, "latents_std") + and self.vae.config.latents_std is not None + ) if has_latents_mean and has_latents_std: latents_mean = ( - torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype) + torch.tensor(self.vae.config.latents_mean) + .view(1, 4, 1, 1) + .to(latents.device, latents.dtype) ) latents_std = ( - torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype) + torch.tensor(self.vae.config.latents_std) + .view(1, 4, 1, 1) + .to(latents.device, latents.dtype) + ) + latents = ( + latents * latents_std / self.vae.config.scaling_factor + + latents_mean ) - latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean else: latents = latents / self.vae.config.scaling_factor diff --git 
a/onediff_diffusers_extensions/onediffx/deep_cache/pipeline_stable_video_diffusion.py b/onediff_diffusers_extensions/onediffx/deep_cache/pipeline_stable_video_diffusion.py index 7314e715b..facc83b9f 100644 --- a/onediff_diffusers_extensions/onediffx/deep_cache/pipeline_stable_video_diffusion.py +++ b/onediff_diffusers_extensions/onediffx/deep_cache/pipeline_stable_video_diffusion.py @@ -12,10 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +import importlib.metadata from typing import Callable, Dict, List, Optional, Union from packaging import version -import importlib.metadata diffusers_0240_v = version.parse("0.24.0") diffusers_0251_v = version.parse("0.25.1") @@ -25,30 +25,31 @@ import numpy as np import PIL.Image import torch -from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection from diffusers.image_processor import VaeImageProcessor from diffusers.models import AutoencoderKLTemporalDecoder from diffusers.schedulers import EulerDiscreteScheduler from diffusers.utils import BaseOutput, logging from diffusers.utils.torch_utils import randn_tensor +from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection logger = logging.get_logger(__name__) # pylint: disable=invalid-name -from diffusers import StableVideoDiffusionPipeline as DiffusersStableVideoDiffusionPipeline +from diffusers import ( + StableVideoDiffusionPipeline as DiffusersStableVideoDiffusionPipeline, +) from diffusers.pipelines.stable_video_diffusion import ( StableVideoDiffusionPipelineOutput, ) - -from .models.unet_spatio_temporal_condition import UNetSpatioTemporalConditionModel from .models.fast_unet_spatio_temporal_condition import ( FastUNetSpatioTemporalConditionModel, ) - from .models.pipeline_utils import enable_deep_cache_pipeline +from .models.unet_spatio_temporal_condition import UNetSpatioTemporalConditionModel + enable_deep_cache_pipeline() @@ -154,12 +155,14 @@ def __call__( # corresponds to doing no classifier free guidance. if diffusers_version > diffusers_0240_v: self._guidance_scale = max_guidance_scale - else: + else: do_classifier_free_guidance = max_guidance_scale > 1.0 # 3. Encode input image if diffusers_version > diffusers_0240_v: - image_embeddings = self._encode_image(image, device, num_videos_per_prompt, self.do_classifier_free_guidance) + image_embeddings = self._encode_image( + image, device, num_videos_per_prompt, self.do_classifier_free_guidance + ) else: image_embeddings = self._encode_image( image, device, num_videos_per_prompt, do_classifier_free_guidance @@ -172,8 +175,12 @@ def __call__( # 4. 
Encode input image using VAE if diffusers_version > diffusers_0251_v: - image = self.image_processor.preprocess(image, height=height, width=width).to(device) - noise = randn_tensor(image.shape, generator=generator, device=device, dtype=image.dtype) + image = self.image_processor.preprocess( + image, height=height, width=width + ).to(device) + noise = randn_tensor( + image.shape, generator=generator, device=device, dtype=image.dtype + ) else: image = self.image_processor.preprocess(image, height=height, width=width) noise = randn_tensor( @@ -195,7 +202,9 @@ def __call__( do_classifier_free_guidance=self.do_classifier_free_guidance, ) if diffusers_version > diffusers_0240_v: - image_latents = self._encode_vae_image(image, device, num_videos_per_prompt, self.do_classifier_free_guidance) + image_latents = self._encode_vae_image( + image, device, num_videos_per_prompt, self.do_classifier_free_guidance + ) else: image_latents = self._encode_vae_image( image, device, num_videos_per_prompt, do_classifier_free_guidance @@ -272,10 +281,16 @@ def __call__( for i, t in enumerate(timesteps): # expand the latents if we are doing classifier free guidance if diffusers_version > diffusers_0240_v: - latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + latent_model_input = ( + torch.cat([latents] * 2) + if self.do_classifier_free_guidance + else latents + ) else: latent_model_input = ( - torch.cat([latents] * 2) if do_classifier_free_guidance else latents + torch.cat([latents] * 2) + if do_classifier_free_guidance + else latents ) latent_model_input = self.scheduler.scale_model_input( latent_model_input, t @@ -313,7 +328,9 @@ def __call__( if diffusers_version > diffusers_0240_v: if self.do_classifier_free_guidance: noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond) + noise_pred = noise_pred_uncond + self.guidance_scale * ( + noise_pred_cond - noise_pred_uncond + ) else: if do_classifier_free_guidance: noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2) diff --git a/onediff_diffusers_extensions/onediffx/lora/__init__.py b/onediff_diffusers_extensions/onediffx/lora/__init__.py index 5d99001bc..0fe74f159 100644 --- a/onediff_diffusers_extensions/onediffx/lora/__init__.py +++ b/onediff_diffusers_extensions/onediffx/lora/__init__.py @@ -1,9 +1,11 @@ +from onediff.infer_compiler.backends.oneflow.param_utils import ( + update_graph_with_constant_folding_info, +) + from .lora import ( - load_and_fuse_lora, - unfuse_lora, - set_and_fuse_adapters, delete_adapters, get_active_adapters, + load_and_fuse_lora, + set_and_fuse_adapters, + unfuse_lora, ) - -from onediff.infer_compiler.backends.oneflow.param_utils import update_graph_with_constant_folding_info diff --git a/onediff_diffusers_extensions/onediffx/lora/lora.py b/onediff_diffusers_extensions/onediffx/lora/lora.py index 8e7896094..e5c9ee0d2 100644 --- a/onediff_diffusers_extensions/onediffx/lora/lora.py +++ b/onediff_diffusers_extensions/onediffx/lora/lora.py @@ -1,14 +1,14 @@ +from collections import defaultdict, OrderedDict from pathlib import Path -from typing import Optional, Union, Dict, Tuple, List -from collections import OrderedDict, defaultdict -from packaging import version +from typing import Dict, List, Optional, Tuple, Union + +import diffusers import torch +from diffusers.loaders import LoraLoaderMixin from onediff.utils import logger - -import diffusers -from diffusers.loaders import LoraLoaderMixin 
+from packaging import version if version.parse(diffusers.__version__) >= version.parse("0.21.0"): from diffusers.models.lora import PatchedLoraProjection @@ -16,19 +16,21 @@ from diffusers.loaders import PatchedLoraProjection +from .text_encoder import load_lora_into_text_encoder +from .unet import load_lora_into_unet from .utils import ( - _unfuse_lora, - _set_adapter, _delete_adapter, _maybe_map_sgm_blocks_to_diffusers, + _set_adapter, + _unfuse_lora, is_peft_available, ) -from .text_encoder import load_lora_into_text_encoder -from .unet import load_lora_into_unet if is_peft_available(): import peft -is_onediffx_lora_available = version.parse(diffusers.__version__) >= version.parse("0.19.3") +is_onediffx_lora_available = version.parse(diffusers.__version__) >= version.parse( + "0.19.3" +) USE_PEFT_BACKEND = False @@ -57,15 +59,21 @@ def load_and_fuse_lora( if use_cache: state_dict, network_alphas = load_state_dict_cached( - pretrained_model_name_or_path_or_dict, unet_config=self.unet.config, **kwargs, + pretrained_model_name_or_path_or_dict, + unet_config=self.unet.config, + **kwargs, ) else: # for diffusers <= 0.20 if hasattr(LoraLoaderMixin, "_map_sgm_blocks_to_diffusers"): orig_func = getattr(LoraLoaderMixin, "_map_sgm_blocks_to_diffusers") - LoraLoaderMixin._map_sgm_blocks_to_diffusers = _maybe_map_sgm_blocks_to_diffusers + LoraLoaderMixin._map_sgm_blocks_to_diffusers = ( + _maybe_map_sgm_blocks_to_diffusers + ) state_dict, network_alphas = LoraLoaderMixin.lora_state_dict( - pretrained_model_name_or_path_or_dict, unet_config=self.unet.config, **kwargs, + pretrained_model_name_or_path_or_dict, + unet_config=self.unet.config, + **kwargs, ) if hasattr(LoraLoaderMixin, "_map_sgm_blocks_to_diffusers"): LoraLoaderMixin._map_sgm_blocks_to_diffusers = orig_func @@ -87,7 +95,9 @@ def load_and_fuse_lora( ) # load lora weights into text encoder - text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k} + text_encoder_state_dict = { + k: v for k, v in state_dict.items() if "text_encoder." in k + } if len(text_encoder_state_dict) > 0: load_lora_into_text_encoder( self, @@ -100,7 +110,9 @@ def load_and_fuse_lora( _pipeline=self, ) - text_encoder_2_state_dict = {k: v for k, v in state_dict.items() if "text_encoder_2." in k} + text_encoder_2_state_dict = { + k: v for k, v in state_dict.items() if "text_encoder_2." 
in k + } if len(text_encoder_2_state_dict) > 0 and hasattr(self, "text_encoder_2"): load_lora_into_text_encoder( self, @@ -119,7 +131,8 @@ def _unfuse_lora_apply(m: torch.nn.Module): if isinstance(m, (torch.nn.Linear, PatchedLoraProjection, torch.nn.Conv2d)): _unfuse_lora(m) elif is_peft_available() and isinstance( - m, (peft.tuners.lora.layer.Linear, peft.tuners.lora.layer.Conv2d), + m, + (peft.tuners.lora.layer.Linear, peft.tuners.lora.layer.Conv2d), ): _unfuse_lora(m.base_layer) @@ -142,19 +155,26 @@ def set_and_fuse_adapters( adapter_names = [adapter_names] if adapter_weights is None: - adapter_weights = [1.0, ] * len(adapter_names) + adapter_weights = [ + 1.0, + ] * len(adapter_names) elif isinstance(adapter_weights, float): - adapter_weights = [adapter_weights, ] * len(adapter_names) + adapter_weights = [ + adapter_weights, + ] * len(adapter_names) _init_adapters_info(pipeline) pipeline._adapter_names |= set(adapter_names) - pipeline._active_adapter_names = {k: v for k, v in zip(adapter_names, adapter_weights)} + pipeline._active_adapter_names = { + k: v for k, v in zip(adapter_names, adapter_weights) + } def set_adapters_apply(m): if isinstance(m, (torch.nn.Linear, torch.nn.Conv2d, PatchedLoraProjection)): _set_adapter(m, adapter_names, adapter_weights) elif is_peft_available() and isinstance( - m, (peft.tuners.lora.layer.Linear, peft.tuners.lora.layer.Conv2d), + m, + (peft.tuners.lora.layer.Linear, peft.tuners.lora.layer.Conv2d), ): _set_adapter(m.base_layer, adapter_names, adapter_weights) @@ -179,7 +199,8 @@ def delete_adapters_apply(m): if isinstance(m, (torch.nn.Linear, torch.nn.Conv2d, PatchedLoraProjection)): _delete_adapter(m, adapter_names) elif is_peft_available() and isinstance( - m, (peft.tuners.lora.layer.Linear, peft.tuners.lora.layer.Conv2d), + m, + (peft.tuners.lora.layer.Linear, peft.tuners.lora.layer.Conv2d), ): _delete_adapter(m.base_layer, adapter_names) @@ -221,7 +242,8 @@ def __setitem__(self, key, value): def load_state_dict_cached( - lora: Union[str, Path, Dict[str, torch.Tensor]], **kwargs, + lora: Union[str, Path, Dict[str, torch.Tensor]], + **kwargs, ) -> Tuple[Dict, Dict]: assert isinstance(lora, (str, Path, dict)) if isinstance(lora, dict): @@ -233,10 +255,15 @@ def load_state_dict_cached( lora_name = str(lora) + (f"/{weight_name}" if weight_name else "") if lora_name in CachedLoRAs: - logger.debug(f"[OneDiffX Cached LoRA] get cached lora of name: {str(lora_name)}") + logger.debug( + f"[OneDiffX Cached LoRA] get cached lora of name: {str(lora_name)}" + ) return CachedLoRAs[lora_name] - state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(lora, **kwargs,) + state_dict, network_alphas = LoraLoaderMixin.lora_state_dict( + lora, + **kwargs, + ) CachedLoRAs[lora_name] = (state_dict, network_alphas) logger.debug(f"[OneDiffX Cached LoRA] create cached lora of name: {str(lora_name)}") return state_dict, network_alphas diff --git a/onediff_diffusers_extensions/onediffx/lora/text_encoder.py b/onediff_diffusers_extensions/onediffx/lora/text_encoder.py index df8f17ebe..dbea9c7cf 100644 --- a/onediff_diffusers_extensions/onediffx/lora/text_encoder.py +++ b/onediff_diffusers_extensions/onediffx/lora/text_encoder.py @@ -1,9 +1,10 @@ from collections import defaultdict -from packaging import version -import torch import diffusers +import torch +from packaging import version + if version.parse(diffusers.__version__) >= version.parse("0.22.0"): from diffusers.utils import convert_state_dict_to_diffusers else: @@ -16,9 +17,8 @@ ) else: from diffusers.loaders import 
text_encoder_attn_modules, text_encoder_mlp_modules -from diffusers.utils import is_accelerate_available - from diffusers.models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT +from diffusers.utils import is_accelerate_available from onediff.utils import logger from .utils import fuse_lora, get_adapter_names @@ -68,7 +68,9 @@ def load_lora_into_text_encoder( `default_{i}` where i is the total number of adapters being loaded. """ low_cpu_mem_usage = ( - low_cpu_mem_usage if low_cpu_mem_usage is not None else _LOW_CPU_MEM_USAGE_DEFAULT + low_cpu_mem_usage + if low_cpu_mem_usage is not None + else _LOW_CPU_MEM_USAGE_DEFAULT ) if adapter_name is None: @@ -95,9 +97,13 @@ def load_lora_into_text_encoder( # Safe prefix to check with. if any(cls.text_encoder_name in key for key in keys): # Load the layers corresponding to text encoder and make necessary adjustments. - text_encoder_keys = [k for k in keys if k.startswith(prefix) and k.split(".")[0] == prefix] + text_encoder_keys = [ + k for k in keys if k.startswith(prefix) and k.split(".")[0] == prefix + ] text_encoder_lora_state_dict = { - k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys + k.replace(f"{prefix}.", ""): v + for k, v in state_dict.items() + if k in text_encoder_keys } if len(text_encoder_lora_state_dict) > 0: @@ -117,26 +123,40 @@ def load_lora_into_text_encoder( rank_key = f"{name}.out_proj.lora_B.weight" rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1] - patch_mlp = any(".mlp." in key for key in text_encoder_lora_state_dict.keys()) + patch_mlp = any( + ".mlp." in key for key in text_encoder_lora_state_dict.keys() + ) if patch_mlp: for name, _ in text_encoder_mlp_modules(text_encoder): rank_key_fc1 = f"{name}.fc1.lora_B.weight" rank_key_fc2 = f"{name}.fc2.lora_B.weight" - rank[rank_key_fc1] = text_encoder_lora_state_dict[rank_key_fc1].shape[1] - rank[rank_key_fc2] = text_encoder_lora_state_dict[rank_key_fc2].shape[1] + rank[rank_key_fc1] = text_encoder_lora_state_dict[ + rank_key_fc1 + ].shape[1] + rank[rank_key_fc2] = text_encoder_lora_state_dict[ + rank_key_fc2 + ].shape[1] else: for name, _ in text_encoder_attn_modules(text_encoder): rank_key = f"{name}.out_proj.lora_linear_layer.up.weight" - rank.update({rank_key: text_encoder_lora_state_dict[rank_key].shape[1]}) + rank.update( + {rank_key: text_encoder_lora_state_dict[rank_key].shape[1]} + ) - patch_mlp = any(".mlp." in key for key in text_encoder_lora_state_dict.keys()) + patch_mlp = any( + ".mlp." 
in key for key in text_encoder_lora_state_dict.keys() + ) if patch_mlp: for name, _ in text_encoder_mlp_modules(text_encoder): rank_key_fc1 = f"{name}.fc1.lora_linear_layer.up.weight" rank_key_fc2 = f"{name}.fc2.lora_linear_layer.up.weight" - rank[rank_key_fc1] = text_encoder_lora_state_dict[rank_key_fc1].shape[1] - rank[rank_key_fc2] = text_encoder_lora_state_dict[rank_key_fc2].shape[1] + rank[rank_key_fc1] = text_encoder_lora_state_dict[ + rank_key_fc1 + ].shape[1] + rank[rank_key_fc2] = text_encoder_lora_state_dict[ + rank_key_fc2 + ].shape[1] # group text encoder lora state_dict te_lora_grouped_dict = defaultdict(dict) @@ -194,13 +214,23 @@ def load_lora_into_text_encoder( is_network_alphas_populated = len(network_alphas) > 0 for name, attn_module in text_encoder_attn_modules(text_encoder): - query_alpha = network_alphas.pop(name + ".to_q_lora.down.weight.alpha", None) - key_alpha = network_alphas.pop(name + ".to_k_lora.down.weight.alpha", None) - value_alpha = network_alphas.pop(name + ".to_v_lora.down.weight.alpha", None) - out_alpha = network_alphas.pop(name + ".to_out_lora.down.weight.alpha", None) + query_alpha = network_alphas.pop( + name + ".to_q_lora.down.weight.alpha", None + ) + key_alpha = network_alphas.pop( + name + ".to_k_lora.down.weight.alpha", None + ) + value_alpha = network_alphas.pop( + name + ".to_v_lora.down.weight.alpha", None + ) + out_alpha = network_alphas.pop( + name + ".to_out_lora.down.weight.alpha", None + ) if isinstance(rank, dict): - current_rank = rank.pop(f"{name}.out_proj.lora_linear_layer.up.weight") + current_rank = rank.pop( + f"{name}.out_proj.lora_linear_layer.up.weight" + ) else: current_rank = rank @@ -250,8 +280,12 @@ def load_lora_into_text_encoder( name + ".fc2.lora_linear_layer.down.weight.alpha", None ) - current_rank_fc1 = rank.pop(f"{name}.fc1.lora_linear_layer.up.weight") - current_rank_fc2 = rank.pop(f"{name}.fc2.lora_linear_layer.up.weight") + current_rank_fc1 = rank.pop( + f"{name}.fc1.lora_linear_layer.up.weight" + ) + current_rank_fc2 = rank.pop( + f"{name}.fc2.lora_linear_layer.up.weight" + ) fuse_lora( mlp_module.fc1, diff --git a/onediff_diffusers_extensions/onediffx/lora/unet.py b/onediff_diffusers_extensions/onediffx/lora/unet.py index 98834eeaa..53809eb7a 100644 --- a/onediff_diffusers_extensions/onediffx/lora/unet.py +++ b/onediff_diffusers_extensions/onediffx/lora/unet.py @@ -1,16 +1,14 @@ -from packaging import version -from typing import Union, Dict from collections import defaultdict +from typing import Dict, Union import torch +from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear +from diffusers.utils import is_accelerate_available from onediff.infer_compiler import DeployableModule from onediff.utils import logger -from diffusers.models.lora import ( - LoRACompatibleConv, - LoRACompatibleLinear, -) +from packaging import version + from .utils import fuse_lora, get_adapter_names, is_peft_available -from diffusers.utils import is_accelerate_available if is_peft_available(): import peft diff --git a/onediff_diffusers_extensions/onediffx/lora/utils.py b/onediff_diffusers_extensions/onediffx/lora/utils.py index eeb838e16..282766667 100644 --- a/onediff_diffusers_extensions/onediffx/lora/utils.py +++ b/onediff_diffusers_extensions/onediffx/lora/utils.py @@ -1,11 +1,12 @@ import os -from typing import Dict, Union, List -from packaging import version from collections import OrderedDict +from typing import Dict, List, Union -import torch import diffusers +import torch +from packaging import version + if 
version.parse(diffusers.__version__) >= version.parse("0.22.0"): from diffusers.utils.import_utils import is_peft_available @@ -14,7 +15,9 @@ else: is_peft_available = lambda: False -from onediff.infer_compiler.backends.oneflow.param_utils import update_graph_related_tensor +from onediff.infer_compiler.backends.oneflow.param_utils import ( + update_graph_related_tensor, +) if version.parse(diffusers.__version__) <= version.parse("0.20.0"): from diffusers.loaders import PatchedLoraProjection @@ -80,7 +83,9 @@ def get_delta_weight( weight: float, ): if weight == 0: - return torch.zeros_like(self.weight, dtype=self.weight.dtype, device=self.weight.device) + return torch.zeros_like( + self.weight, dtype=self.weight.dtype, device=self.weight.device + ) if isinstance(self, (torch.nn.Linear, PatchedLoraProjection)): lora_weight = torch.bmm(w_up[None, :], w_down[None, :])[0] @@ -116,7 +121,9 @@ def _set_adapter(self, adapter_names, adapter_weights): if adapter_weights is None: adapter_weights = 1.0 if isinstance(adapter_weights, float): - adapter_weights = [adapter_weights,] * len(adapter_names) + adapter_weights = [ + adapter_weights, + ] * len(adapter_names) _unfuse_lora(self) dtype, device = self.weight.data.dtype, self.weight.data.device @@ -131,13 +138,9 @@ def _set_adapter(self, adapter_names, adapter_weights): w_down = self.lora_A[adapter].float().to(device) w_up = self.lora_B[adapter].float().to(device) if delta_weight is None: - delta_weight = get_delta_weight( - self, w_up, w_down, self.scaling[adapter] - ) + delta_weight = get_delta_weight(self, w_up, w_down, self.scaling[adapter]) else: - delta_weight += get_delta_weight( - self, w_up, w_down, self.scaling[adapter] - ) + delta_weight += get_delta_weight(self, w_up, w_down, self.scaling[adapter]) if delta_weight is not None: fused_weight = self.weight.data.float() + delta_weight diff --git a/onediff_diffusers_extensions/onediffx/utils/patch_image_processor.py b/onediff_diffusers_extensions/onediffx/utils/patch_image_processor.py index 55f0078e7..ea8d83daf 100644 --- a/onediff_diffusers_extensions/onediffx/utils/patch_image_processor.py +++ b/onediff_diffusers_extensions/onediffx/utils/patch_image_processor.py @@ -1,11 +1,12 @@ import warnings from typing import List, Optional -import torch + import numpy as np import PIL.Image -from PIL import Image +import torch from diffusers.image_processor import VaeImageProcessor from diffusers.utils import deprecate +from PIL import Image def patch_image_prcessor(processor): diff --git a/onediff_diffusers_extensions/tests/profile_lora.py b/onediff_diffusers_extensions/tests/profile_lora.py index 1ecdc3535..2d318c820 100644 --- a/onediff_diffusers_extensions/tests/profile_lora.py +++ b/onediff_diffusers_extensions/tests/profile_lora.py @@ -1,9 +1,10 @@ import time from pathlib import Path -import torch import pandas as pd import safetensors.torch + +import torch from diffusers import DiffusionPipeline from onediff.infer_compiler import oneflow_compile @@ -74,7 +75,9 @@ def __exit__(self, exc_type, exc_value, traceback): load_and_fuse_lora_time = [] for i, (name, lora) in enumerate(loras.items()): with TimerContextManager("load_and_fuse_lora", Path(name).stem): - load_and_fuse_lora(pipe, lora.copy(), adapter_name=Path(name).stem, lora_scale=1.0) + load_and_fuse_lora( + pipe, lora.copy(), adapter_name=Path(name).stem, lora_scale=1.0 + ) unfuse_lora(pipe) load_and_fuse_lora_time.append(_time) diff --git a/onediff_diffusers_extensions/tests/profile_multi_lora.py 
b/onediff_diffusers_extensions/tests/profile_multi_lora.py
index e50b6a750..7667a5fe8 100644
--- a/onediff_diffusers_extensions/tests/profile_multi_lora.py
+++ b/onediff_diffusers_extensions/tests/profile_multi_lora.py
@@ -2,17 +2,20 @@
 from pathlib import Path
 
 import pandas as pd
-import torch
 import safetensors.torch
+import torch
 from diffusers import DiffusionPipeline
 from diffusers.utils.constants import USE_PEFT_BACKEND
 from onediff.infer_compiler import oneflow_compile
 from onediff.torch_utils import TensorInplaceAssign
-from onediffx.lora import load_and_fuse_lora, unfuse_lora, set_and_fuse_adapters
+from onediffx.lora import load_and_fuse_lora, set_and_fuse_adapters, unfuse_lora
 
 if not USE_PEFT_BACKEND:
-    raise RuntimeError("The profile if for PEFT APIs, please make sure you have installed peft>=0.6.0 and transformers >= 4.34.0")
+    raise RuntimeError(
+        "The profile is for PEFT APIs; please make sure you have installed peft>=0.6.0 and transformers >= 4.34.0"
+    )
+
 
 class TimerContextManager:
     def __init__(self, msg, lora):
@@ -52,7 +55,13 @@ def __exit__(self, exc_type, exc_value, traceback):
 
 # for OneDiffX APIs
 for i, (name, lora) in enumerate(loras.items()):
-    load_and_fuse_lora(pipe, lora.copy(), adapter_name=Path(name).stem, lora_scale=1.0, offload_device="cuda")
+    load_and_fuse_lora(
+        pipe,
+        lora.copy(),
+        adapter_name=Path(name).stem,
+        lora_scale=1.0,
+        offload_device="cuda",
+    )
     unfuse_lora(pipe)
 
 multi_loras = []
@@ -96,12 +105,10 @@ def __exit__(self, exc_type, exc_value, traceback):
 data = {
     "LoRA names": lora_names,
     "PEFT set_adapter": [f"{x:.2f} s" for x in peft_set_adapter_time],
-    "OneDiffX set_adapter": [
-        f"{x:.2f} s" for x in set_adapter_time
-    ],
+    "OneDiffX set_adapter": [f"{x:.2f} s" for x in set_adapter_time],
 }
 
 df = pd.DataFrame(data)
 print(df)
 with open("result.md", "w") as file:
-    file.write(df.to_markdown(index=False))
\ No newline at end of file
+    file.write(df.to_markdown(index=False))
diff --git a/onediff_diffusers_extensions/tests/test_lora.py b/onediff_diffusers_extensions/tests/test_lora.py
index 072ee49eb..fb574dfb5 100644
--- a/onediff_diffusers_extensions/tests/test_lora.py
+++ b/onediff_diffusers_extensions/tests/test_lora.py
@@ -1,28 +1,44 @@
-import pytest
 import random
-from PIL import Image
 from collections import defaultdict
-from typing import Dict, List, Tuple
 from pathlib import Path
+from typing import Dict, List, Tuple
 
-import torch
-from torch import Tensor
 import numpy as np
+import pytest
 import safetensors.torch
-from skimage.metrics import structural_similarity
+
+import torch
 from diffusers import DiffusionPipeline
 from onediff.infer_compiler import oneflow_compile
-from onediffx.lora import load_and_fuse_lora, unfuse_lora, set_and_fuse_adapters, get_active_adapters, delete_adapters
+from onediffx.lora import (
+    delete_adapters,
+    get_active_adapters,
+    load_and_fuse_lora,
+    set_and_fuse_adapters,
+    unfuse_lora,
+)
+from PIL import Image
+from skimage.metrics import structural_similarity
+from torch import Tensor
 
 HEIGHT = 1024
 WIDTH = 1024
 NUM_STEPS = 30
 LORA_SCALE = 0.5
-LATENTS = torch.randn(1, 4, 128, 128, generator=torch.cuda.manual_seed(0), dtype=torch.float16, device="cuda")
+LATENTS = torch.randn(
+    1,
+    4,
+    128,
+    128,
+    generator=torch.cuda.manual_seed(0),
+    dtype=torch.float16,
+    device="cuda",
+)
 image_file_prefix = "/share_nfs/onediff_ci/diffusers/images/1.0"
 
+
 @pytest.fixture
 def prepare_loras() -> Dict[str, Dict[str, Tensor]]:
     loras = [
@@ -35,12 +51,15 @@ def prepare_loras() -> Dict[str, Dict[str, Tensor]]:
     loras = {x: safetensors.torch.load_file(x) for x in loras}
     return loras
 
+
 @pytest.fixture
 def get_loras(prepare_loras) -> Dict[str, Dict[str, Tensor]]:
     def _get_loras():
         return {name: lora_dict.copy() for name, lora_dict in prepare_loras.items()}
+
     return _get_loras
 
+
 @pytest.fixture
 def get_multi_loras(prepare_loras) -> Dict[str, Dict[str, Tensor]]:
     def _get_multi_loras():
@@ -52,6 +71,7 @@ def _get_multi_loras():
             current_lora.append(lora_dict)
         multi_lora[tuple(current_name)] = current_lora
         return multi_lora
+
     return _get_multi_loras
@@ -63,6 +83,7 @@ def pipe():
     ).to("cuda")
     return pipeline
 
+
 def generate_image(pipe):
     image = pipe(
         "masterpiece, best quality, mountain",
@@ -74,9 +95,11 @@ def generate_image(pipe):
     ).images[0]
     return image
 
+
 def prepare_target_images(pipe, loras):
     target_images_list = [
-        f"{image_file_prefix}/test_sdxl_lora_{str(Path(name).stem)}_{HEIGHT}_{WIDTH}.png" for name in loras
+        f"{image_file_prefix}/test_sdxl_lora_{str(Path(name).stem)}_{HEIGHT}_{WIDTH}.png"
+        for name in loras
     ]
     if all(Path(x).exists() for x in target_images_list):
         return
@@ -88,12 +111,16 @@ def prepare_target_images(pipe, loras):
         image = generate_image(pipe)
         pipe.unfuse_lora()
         pipe.unload_lora_weights()
-        image.save(f"{image_file_prefix}/test_sdxl_lora_{str(Path(name).stem)}_{HEIGHT}_{WIDTH}.png")
+        image.save(
+            f"{image_file_prefix}/test_sdxl_lora_{str(Path(name).stem)}_{HEIGHT}_{WIDTH}.png"
+        )
         torch.cuda.empty_cache()
 
+
 def prepare_target_images_multi_lora(pipe, loras, multi_loras):
     target_images_list = [
-        f"{image_file_prefix}/test_sdxl_multi_lora_{'_'.join([str(Path(name).stem) for name in names])}_{HEIGHT}_{WIDTH}.png" for names in multi_loras
+        f"{image_file_prefix}/test_sdxl_multi_lora_{'_'.join([str(Path(name).stem) for name in names])}_{HEIGHT}_{WIDTH}.png"
+        for names in multi_loras
     ]
     if all(Path(x).exists() for x in target_images_list):
         return
@@ -105,9 +132,15 @@ def prepare_target_images_multi_lora(pipe, loras, multi_loras):
     print("Didn't find target images, try to generate...")
     for names, loras in multi_loras.items():
         names = [str(Path(name).stem) for name in names]
-        pipe.set_adapters(names, [LORA_SCALE, ] * len(names))
+        pipe.set_adapters(
+            names,
+            [
+                LORA_SCALE,
+            ]
+            * len(names),
+        )
         image = generate_image(pipe)
-        image_name = f"{image_file_prefix}/test_sdxl_multi_lora_{'_'.join([str(Path(name).stem) for name in names])}_{HEIGHT}_{WIDTH}.png"
+        image_name = f"{image_file_prefix}/test_sdxl_multi_lora_{'_'.join([str(Path(name).stem) for name in names])}_{HEIGHT}_{WIDTH}.png"
         image.save(image_name)
         pipe.unload_lora_weights()
         torch.cuda.empty_cache()
@@ -116,14 +149,21 @@ def preload_multi_loras(pipe, loras):
     for name, lora in loras.items():
         load_and_fuse_lora(
-            pipe, lora.copy(), adapter_name=Path(name).stem,
+            pipe,
+            lora.copy(),
+            adapter_name=Path(name).stem,
         )
         unfuse_lora(pipe)
 
 
 def test_lora_loading(pipe, get_loras):
     pipe.unet = oneflow_compile(pipe.unet)
-    pipe("a cat", height=HEIGHT, width=WIDTH, num_inference_steps=NUM_STEPS,).images[0]
+    pipe(
+        "a cat",
+        height=HEIGHT,
+        width=WIDTH,
+        num_inference_steps=NUM_STEPS,
+    ).images[0]
 
     loras = get_loras()
     prepare_target_images(pipe, loras)
@@ -131,14 +171,18 @@
         load_and_fuse_lora(pipe, lora.copy())
         images_fusion = generate_image(pipe)
         target_image = np.array(
-            Image.open(f"{image_file_prefix}/test_sdxl_lora_{str(Path(name).stem)}_{HEIGHT}_{WIDTH}.png")
+            Image.open(
+                f"{image_file_prefix}/test_sdxl_lora_{str(Path(name).stem)}_{HEIGHT}_{WIDTH}.png"
+            )
         )
         curr_image = np.array(images_fusion)
         ssim = structural_similarity(
             curr_image, target_image, channel_axis=-1, data_range=255
         )
         unfuse_lora(pipe)
-        images_fusion.save(f"./test_sdxl_lora_{str(Path(name).stem)}_{HEIGHT}_{WIDTH}.png")
+        images_fusion.save(
+            f"./test_sdxl_lora_{str(Path(name).stem)}_{HEIGHT}_{WIDTH}.png"
+        )
         print(f"lora {name} ssim {ssim}")
         assert ssim > 0.92, f"LoRA {name} ssim too low"
@@ -152,12 +196,21 @@ def test_multi_lora_loading(pipe, get_multi_loras, get_loras):
 
     for names, loras in multi_loras.items():
         names = [str(Path(name).stem) for name in names]
-        set_and_fuse_adapters(pipe, names, [LORA_SCALE, ] * len(names))
+        set_and_fuse_adapters(
+            pipe,
+            names,
+            [
+                LORA_SCALE,
+            ]
+            * len(names),
+        )
         images_fusion = generate_image(pipe)
-        image_name = '_'.join([str(Path(name).stem) for name in names])
+        image_name = "_".join([str(Path(name).stem) for name in names])
         target_image = np.array(
-            Image.open(f"{image_file_prefix}/test_sdxl_multi_lora_{image_name}_{HEIGHT}_{WIDTH}.png")
+            Image.open(
+                f"{image_file_prefix}/test_sdxl_multi_lora_{image_name}_{HEIGHT}_{WIDTH}.png"
+            )
         )
         images_fusion.save(f"./test_sdxl_multi_lora_{image_name}_{HEIGHT}_{WIDTH}.png")
         curr_image = np.array(images_fusion)
@@ -186,5 +239,7 @@ def test_delete_adapters(pipe, get_multi_loras):
         set_and_fuse_adapters(pipe, names)
         delete_adapters(pipe, names_to_delete)
         active_adapters = get_active_adapters(pipe)
-    print(f"current adapters: {active_adapters}, target adapters: {list(set(names) - set(names_to_delete))}")
+    print(
+        f"current adapters: {active_adapters}, target adapters: {list(set(names) - set(names_to_delete))}"
+    )
     assert set(active_adapters) == set(names) - set(names_to_delete)
diff --git a/onediff_diffusers_extensions/tools/quantization/quantize-sd-fast.py b/onediff_diffusers_extensions/tools/quantization/quantize-sd-fast.py
index 61e920bc8..f4d332a55 100644
--- a/onediff_diffusers_extensions/tools/quantization/quantize-sd-fast.py
+++ b/onediff_diffusers_extensions/tools/quantization/quantize-sd-fast.py
@@ -1,19 +1,20 @@
 import argparse
 import time
-import torch
-from PIL import Image
 from pathlib import Path
 
+import torch
+
 from diffusers import (
-    AutoPipelineForText2Image,
     AutoPipelineForImage2Image,
-    StableDiffusionXLPipeline,
-    StableDiffusionXLImg2ImgPipeline,
-    StableDiffusionPipeline,
+    AutoPipelineForText2Image,
     StableDiffusionImg2ImgPipeline,
+    StableDiffusionPipeline,
+    StableDiffusionXLImg2ImgPipeline,
+    StableDiffusionXLPipeline,
 )
 from onediff.quantization import QuantPipeline
+from PIL import Image
 
 parser = argparse.ArgumentParser()
@@ -64,12 +65,22 @@
 parser.add_argument("--cache_dir", type=str, default=None)
 args = parser.parse_args()
 
-pipeline_cls = AutoPipelineForText2Image if args.input_image is None else AutoPipelineForImage2Image
-is_safetensors_model = Path(args.model).is_file and Path(args.model).suffix == ".safetensors"
+pipeline_cls = (
+    AutoPipelineForText2Image
+    if args.input_image is None
+    else AutoPipelineForImage2Image
+)
+is_safetensors_model = (
+    Path(args.model).is_file() and Path(args.model).suffix == ".safetensors"
+)
 if is_safetensors_model:
-    try: # check if safetensors is SDXL
-        pipeline_cls = StableDiffusionXLPipeline if args.input_image is None else StableDiffusionXLImg2ImgPipeline
+    try:  # check if safetensors is SDXL
+        pipeline_cls = (
+            StableDiffusionXLPipeline
+            if args.input_image is None
+            else StableDiffusionXLImg2ImgPipeline
+        )
         pipe = 
QuantPipeline.from_single_file( pipeline_cls, args.model, @@ -78,7 +89,11 @@ use_safetensors=True, ) except: - pipeline_cls = StableDiffusionPipeline if args.input_image is None else StableDiffusionImg2ImgPipeline + pipeline_cls = ( + StableDiffusionPipeline + if args.input_image is None + else StableDiffusionImg2ImgPipeline + ) pipe = QuantPipeline.from_single_file( pipeline_cls, args.model, diff --git a/onediff_diffusers_extensions/tools/quantization/quantize-svd-fast.py b/onediff_diffusers_extensions/tools/quantization/quantize-svd-fast.py index 36f073773..6e29fdd0b 100644 --- a/onediff_diffusers_extensions/tools/quantization/quantize-svd-fast.py +++ b/onediff_diffusers_extensions/tools/quantization/quantize-svd-fast.py @@ -1,11 +1,12 @@ import argparse import time + import torch -from PIL import Image from diffusers import StableVideoDiffusionPipeline from diffusers.utils import load_image from onediff.quantization import QuantPipeline +from PIL import Image parser = argparse.ArgumentParser() diff --git a/onediff_sd_webui_extensions/README.md b/onediff_sd_webui_extensions/README.md index 573b77dff..f2a60e873 100644 --- a/onediff_sd_webui_extensions/README.md +++ b/onediff_sd_webui_extensions/README.md @@ -55,7 +55,7 @@ Accessing http://server:7860/ from a web browser. ## Extensions Usage -To activate OneDiff extension acceleration, follow these steps: +To activate OneDiff extension acceleration, follow these steps: Select `onediff_diffusion_model` from the Script menu, enter a prompt in the text box (e.g., "a black dog"), and then click the "Generate" button. ![onediff_script](images/onediff_script.jpg) diff --git a/onediff_sd_webui_extensions/compile/__init__.py b/onediff_sd_webui_extensions/compile/__init__.py index 60827fd87..99c7ba3d1 100644 --- a/onediff_sd_webui_extensions/compile/__init__.py +++ b/onediff_sd_webui_extensions/compile/__init__.py @@ -2,11 +2,11 @@ from .compile import get_compiled_graph from .sd2 import SD21CompileCtx from .utils import ( - OneDiffCompiledGraph, get_onediff_backend, init_backend, is_nexfort_backend, is_oneflow_backend, + OneDiffCompiledGraph, ) from .vae import VaeCompileCtx diff --git a/onediff_sd_webui_extensions/compile/compile.py b/onediff_sd_webui_extensions/compile/compile.py index 22a4d8628..c41ccaa40 100644 --- a/onediff_sd_webui_extensions/compile/compile.py +++ b/onediff_sd_webui_extensions/compile/compile.py @@ -1,13 +1,14 @@ -from compile import OneDiffBackend from modules.sd_hijack import apply_optimizations from onediff.infer_compiler import compile, oneflow_compile +from compile import OneDiffBackend + from .utils import ( - OneDiffCompiledGraph, disable_unet_checkpointing, is_nexfort_backend, is_oneflow_backend, + OneDiffCompiledGraph, ) diff --git a/onediff_sd_webui_extensions/compile/nexfort/utils.py b/onediff_sd_webui_extensions/compile/nexfort/utils.py index b25a91313..50f330afe 100644 --- a/onediff_sd_webui_extensions/compile/nexfort/utils.py +++ b/onediff_sd_webui_extensions/compile/nexfort/utils.py @@ -9,9 +9,9 @@ from modules.hypernetworks import hypernetwork from modules.sd_hijack_optimizations import SdOptimization from modules.sd_hijack_utils import CondFunc -from onediff_utils import singleton_decorator from onediff.utils.import_utils import is_nexfort_available +from onediff_utils import singleton_decorator @singleton_decorator diff --git a/onediff_sd_webui_extensions/compile/oneflow/mock/common.py b/onediff_sd_webui_extensions/compile/oneflow/mock/common.py index c77f5c3d1..7e0db1341 100644 --- 
a/onediff_sd_webui_extensions/compile/oneflow/mock/common.py +++ b/onediff_sd_webui_extensions/compile/oneflow/mock/common.py @@ -1,7 +1,7 @@ import math from inspect import isfunction -import oneflow as flow +import oneflow as flow # usort: skip from oneflow import nn @@ -58,7 +58,7 @@ def __init__( inner_dim = dim_head * heads context_dim = default(context_dim, query_dim) - self.scale = dim_head ** -0.5 + self.scale = dim_head**-0.5 self.heads = heads self.to_q = nn.Linear(query_dim, inner_dim, bias=False) diff --git a/onediff_sd_webui_extensions/compile/oneflow/mock/controlnet.py b/onediff_sd_webui_extensions/compile/oneflow/mock/controlnet.py index ed4a934bc..c5657bf3c 100644 --- a/onediff_sd_webui_extensions/compile/oneflow/mock/controlnet.py +++ b/onediff_sd_webui_extensions/compile/oneflow/mock/controlnet.py @@ -1,4 +1,4 @@ -import oneflow as flow +import oneflow as flow # usort: skip from compile.oneflow.mock.common import timestep_embedding from ldm.modules.diffusionmodules.openaimodel import UNetModel from modules import devices @@ -30,6 +30,7 @@ def aligned_adding(base, x, require_channel_alignment): x = flow.nn.functional.interpolate(x, size=(base_h, base_w), mode="nearest") return base + x + def cat(tensors, *args, **kwargs): if len(tensors) == 2: a, b = tensors diff --git a/onediff_sd_webui_extensions/compile/oneflow/mock/ldm.py b/onediff_sd_webui_extensions/compile/oneflow/mock/ldm.py index 8e9295ca5..1dde9e23c 100644 --- a/onediff_sd_webui_extensions/compile/oneflow/mock/ldm.py +++ b/onediff_sd_webui_extensions/compile/oneflow/mock/ldm.py @@ -1,4 +1,4 @@ -import oneflow as flow +import oneflow as flow # usort: skip from ldm.modules.attention import CrossAttention, SpatialTransformer from ldm.modules.diffusionmodules.openaimodel import UNetModel from ldm.modules.diffusionmodules.util import GroupNorm32 diff --git a/onediff_sd_webui_extensions/compile/oneflow/mock/sgm.py b/onediff_sd_webui_extensions/compile/oneflow/mock/sgm.py index a071bd7e5..c25a390b0 100644 --- a/onediff_sd_webui_extensions/compile/oneflow/mock/sgm.py +++ b/onediff_sd_webui_extensions/compile/oneflow/mock/sgm.py @@ -1,10 +1,9 @@ -import oneflow as flow +import oneflow as flow # usort: skip +from onediff.infer_compiler.backends.oneflow.transform import proxy_class from sgm.modules.attention import CrossAttention, SpatialTransformer from sgm.modules.diffusionmodules.openaimodel import UNetModel from sgm.modules.diffusionmodules.util import GroupNorm32 -from onediff.infer_compiler.backends.oneflow.transform import proxy_class - from .common import CrossAttentionOflow, GroupNorm32Oflow, timestep_embedding diff --git a/onediff_sd_webui_extensions/compile/oneflow/utils.py b/onediff_sd_webui_extensions/compile/oneflow/utils.py index 006dfd894..05f6217d7 100644 --- a/onediff_sd_webui_extensions/compile/oneflow/utils.py +++ b/onediff_sd_webui_extensions/compile/oneflow/utils.py @@ -1,7 +1,6 @@ -from onediff_utils import singleton_decorator - from onediff.infer_compiler.backends.oneflow.transform import register from onediff.utils.import_utils import is_oneflow_available +from onediff_utils import singleton_decorator @singleton_decorator diff --git a/onediff_sd_webui_extensions/compile/utils.py b/onediff_sd_webui_extensions/compile/utils.py index abbc44391..690b4b23b 100644 --- a/onediff_sd_webui_extensions/compile/utils.py +++ b/onediff_sd_webui_extensions/compile/utils.py @@ -4,9 +4,9 @@ import torch from ldm.modules.diffusionmodules.openaimodel import UNetModel as LdmUNetModel from modules import 
sd_models_types, shared -from sgm.modules.diffusionmodules.openaimodel import UNetModel as SgmUNetModel from onediff.infer_compiler import DeployableModule +from sgm.modules.diffusionmodules.openaimodel import UNetModel as SgmUNetModel from .backend import OneDiffBackend diff --git a/onediff_sd_webui_extensions/compile/vae.py b/onediff_sd_webui_extensions/compile/vae.py index 172578501..1166f2d94 100644 --- a/onediff_sd_webui_extensions/compile/vae.py +++ b/onediff_sd_webui_extensions/compile/vae.py @@ -1,10 +1,10 @@ -from compile.utils import get_onediff_backend from modules import shared -from modules.sd_vae_approx import model as get_vae_model -from modules.sd_vae_approx import sd_vae_approx_models +from modules.sd_vae_approx import model as get_vae_model, sd_vae_approx_models from onediff.infer_compiler import compile +from compile.utils import get_onediff_backend + __all__ = ["VaeCompileCtx"] compiled_models = {} diff --git a/onediff_sd_webui_extensions/onediff_controlnet/hijack.py b/onediff_sd_webui_extensions/onediff_controlnet/hijack.py index 6f7df7871..fa5ecfa06 100644 --- a/onediff_sd_webui_extensions/onediff_controlnet/hijack.py +++ b/onediff_sd_webui_extensions/onediff_controlnet/hijack.py @@ -473,7 +473,7 @@ def forward(self, x, timesteps=None, context=None, y=None, **kwargs): # sdxl's attention hacking is highly unstable. # We have no other methods but to reduce the style_fidelity a bit. # By default, 0.5 ** 3.0 = 0.125 - outer.current_style_fidelity = outer.current_style_fidelity ** 3.0 + outer.current_style_fidelity = outer.current_style_fidelity**3.0 if param.cfg_injection: outer.current_style_fidelity = 1.0 diff --git a/onediff_sd_webui_extensions/onediff_utils.py b/onediff_sd_webui_extensions/onediff_utils.py index 4e5172dcf..33f33348c 100644 --- a/onediff_sd_webui_extensions/onediff_utils.py +++ b/onediff_sd_webui_extensions/onediff_utils.py @@ -11,7 +11,7 @@ from onediff.utils.import_utils import is_oneflow_available if is_oneflow_available(): - import oneflow as flow + import oneflow as flow # usort: skip from compile import init_backend, is_oneflow_backend from modules import shared diff --git a/onediff_sd_webui_extensions/scripts/onediff.py b/onediff_sd_webui_extensions/scripts/onediff.py index 06dd0fc11..7afb3a444 100644 --- a/onediff_sd_webui_extensions/scripts/onediff.py +++ b/onediff_sd_webui_extensions/scripts/onediff.py @@ -1,3 +1,6 @@ +"""oneflow_compiled UNetModel""" + + from pathlib import Path import gradio as gr @@ -7,16 +10,18 @@ import onediff_controlnet import onediff_shared from compile import ( + get_compiled_graph, + get_onediff_backend, OneDiffBackend, SD21CompileCtx, VaeCompileCtx, - get_compiled_graph, - get_onediff_backend, ) from compile.nexfort.utils import add_nexfort_optimizer from modules import script_callbacks from modules.processing import process_images from modules.ui_common import create_refresh_button + +from onediff.utils import logger, parse_boolean_from_env from onediff_hijack import do_hijack as onediff_do_hijack from onediff_lora import HijackLoraActivate @@ -33,10 +38,6 @@ varify_can_use_quantization, ) -from onediff.utils import logger, parse_boolean_from_env - -"""oneflow_compiled UNetModel""" - class UnetCompileCtx(object): """The unet model is stored in a global variable. 
@@ -149,7 +150,9 @@ def run( if need_recompile: if not onediff_shared.controlnet_enabled: onediff_shared.current_unet_graph = get_compiled_graph( - shared.sd_model, quantization=quantization, backend=backend, + shared.sd_model, + quantization=quantization, + backend=backend, ) load_graph(onediff_shared.current_unet_graph, compiler_cache) else: diff --git a/pyproject.toml b/pyproject.toml deleted file mode 100644 index 555e8f938..000000000 --- a/pyproject.toml +++ /dev/null @@ -1,6 +0,0 @@ -# pyproject.toml - -# pip install black==19.10b0 -# black --config pyproject.toml file -[tool.black] -line-length = 88 diff --git a/setup.py b/setup.py index cf3ab2356..841341c3b 100644 --- a/setup.py +++ b/setup.py @@ -17,7 +17,7 @@ def get_version(): url="https://github.com/siliconflow/onediff", author="OneDiff contributors", license="Apache-2.0", - license_files=('LICENSE',), + license_files=("LICENSE",), author_email="contact@siliconflow.com", package_dir={"": "src"}, packages=find_packages("src"), @@ -43,4 +43,11 @@ def get_version(): ], long_description=open("README.md").read(), long_description_content_type="text/markdown", + extras_require={ + # optional dependencies, required by some features + # dev dependencies. Install them by `pip3 install 'onediff[dev]'` + "dev": [ + "pre-commit", + ], + }, ) diff --git a/src/infer_compiler_registry/register_diffusers/__init__.py b/src/infer_compiler_registry/register_diffusers/__init__.py index 292ffcdaf..05d2cae19 100644 --- a/src/infer_compiler_registry/register_diffusers/__init__.py +++ b/src/infer_compiler_registry/register_diffusers/__init__.py @@ -1,7 +1,8 @@ +import importlib.metadata + from onediff.infer_compiler.backends.oneflow.transform import register from packaging import version -import importlib.metadata diffusers_version = version.parse(importlib.metadata.version("diffusers")) @@ -10,41 +11,41 @@ Attention, AttnProcessor, AttnProcessor2_0, + LoRAAttnProcessor2_0, ) -from diffusers.models.attention_processor import LoRAAttnProcessor2_0 if diffusers_version < version.parse("0.26.00"): - from diffusers.models.unet_2d_condition import UNet2DConditionModel + from diffusers.models.transformer_2d import Transformer2DModel from diffusers.models.unet_2d_blocks import ( AttnUpBlock2D, CrossAttnUpBlock2D, UpBlock2D, ) - from diffusers.models.transformer_2d import Transformer2DModel + from diffusers.models.unet_2d_condition import UNet2DConditionModel else: - from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel + from diffusers.models.transformers.transformer_2d import Transformer2DModel from diffusers.models.unets.unet_2d_blocks import ( AttnUpBlock2D, CrossAttnUpBlock2D, UpBlock2D, ) - from diffusers.models.transformers.transformer_2d import Transformer2DModel + from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel if diffusers_version >= version.parse("0.25.00"): from diffusers.models.upsampling import Upsample2D else: from diffusers.models.resnet import Upsample2D if diffusers_version >= version.parse("0.24.00"): - from diffusers.models.resnet import SpatioTemporalResBlock from diffusers.models.attention import TemporalBasicTransformerBlock + from diffusers.models.resnet import SpatioTemporalResBlock if diffusers_version >= version.parse("0.26.00"): - from diffusers.models.unets.unet_spatio_temporal_condition import ( - UNetSpatioTemporalConditionModel, - ) from diffusers.models.transformers.transformer_temporal import ( TransformerSpatioTemporalModel, ) + from 
diffusers.models.unets.unet_spatio_temporal_condition import ( + UNetSpatioTemporalConditionModel, + ) else: from diffusers.models.transformer_temporal import TransformerSpatioTemporalModel from diffusers.models.unet_spatio_temporal_condition import ( @@ -60,27 +61,25 @@ from .spatio_temporal_oflow import ( SpatioTemporalResBlock as SpatioTemporalResBlockOflow, - ) - from .spatio_temporal_oflow import TemporalDecoder as TemporalDecoderOflow - from .spatio_temporal_oflow import ( - TransformerSpatioTemporalModel as TransformerSpatioTemporalModelOflow, - ) - from .spatio_temporal_oflow import ( TemporalBasicTransformerBlock as TemporalBasicTransformerBlockOflow, - ) - from .spatio_temporal_oflow import ( + TemporalDecoder as TemporalDecoderOflow, + TransformerSpatioTemporalModel as TransformerSpatioTemporalModelOflow, UNetSpatioTemporalConditionModel as UNetSpatioTemporalConditionModelOflow, ) -from .attention_processor_oflow import Attention as AttentionOflow -from .attention_processor_oflow import AttnProcessor as AttnProcessorOflow -from .attention_processor_oflow import LoRAAttnProcessor2_0 as LoRAAttnProcessorOflow -from .unet_2d_condition_oflow import UNet2DConditionModel as UNet2DConditionModelOflow -from .unet_2d_blocks_oflow import AttnUpBlock2D as AttnUpBlock2DOflow -from .unet_2d_blocks_oflow import CrossAttnUpBlock2D as CrossAttnUpBlock2DOflow -from .unet_2d_blocks_oflow import UpBlock2D as UpBlock2DOflow +from .attention_processor_oflow import ( + Attention as AttentionOflow, + AttnProcessor as AttnProcessorOflow, + LoRAAttnProcessor2_0 as LoRAAttnProcessorOflow, +) from .resnet_oflow import Upsample2D as Upsample2DOflow from .transformer_2d_oflow import Transformer2DModel as Transformer2DModelOflow +from .unet_2d_blocks_oflow import ( + AttnUpBlock2D as AttnUpBlock2DOflow, + CrossAttnUpBlock2D as CrossAttnUpBlock2DOflow, + UpBlock2D as UpBlock2DOflow, +) +from .unet_2d_condition_oflow import UNet2DConditionModel as UNet2DConditionModelOflow # For CI if diffusers_version >= version.parse("0.24.00"): diff --git a/src/infer_compiler_registry/register_diffusers/attention_processor_oflow.py b/src/infer_compiler_registry/register_diffusers/attention_processor_oflow.py index 68f3a1c2e..70684b16c 100644 --- a/src/infer_compiler_registry/register_diffusers/attention_processor_oflow.py +++ b/src/infer_compiler_registry/register_diffusers/attention_processor_oflow.py @@ -11,17 +11,17 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+import os from typing import Callable, Optional, Union +import diffusers + import oneflow as torch import oneflow.nn.functional as F -from oneflow import nn -import os - -import diffusers from diffusers.utils import deprecate, logging from onediff.utils import parse_boolean_from_env, set_boolean_env_var +from oneflow import nn def is_xformers_available(): @@ -101,7 +101,7 @@ def __init__( self._from_deprecated_attn_block = _from_deprecated_attn_block self.scale_qk = scale_qk - self.scale = dim_head ** -0.5 if self.scale_qk else 1.0 + self.scale = dim_head**-0.5 if self.scale_qk else 1.0 self.heads = heads # for slice_size > 0 the attention score computation @@ -193,7 +193,8 @@ def set_use_memory_efficient_attention_xformers( attention_op: Optional[Callable] = None, ): is_lora = hasattr(self, "processor") and isinstance( - self.processor, LORA_ATTENTION_PROCESSORS, + self.processor, + LORA_ATTENTION_PROCESSORS, ) is_custom_diffusion = hasattr(self, "processor") and isinstance( self.processor, @@ -414,7 +415,11 @@ def get_attention_scores(self, query, key, attention_mask=None): beta = 1 attention_scores = torch.baddbmm( - baddbmm_input, query, key.transpose(-1, -2), beta=beta, alpha=self.scale, + baddbmm_input, + query, + key.transpose(-1, -2), + beta=beta, + alpha=self.scale, ) del baddbmm_input @@ -2077,7 +2082,9 @@ class SpatialNorm(nn.Module): """ def __init__( - self, f_channels, zq_channels, + self, + f_channels, + zq_channels, ): super().__init__() self.norm_layer = nn.GroupNorm( diff --git a/src/infer_compiler_registry/register_diffusers/resnet_oflow.py b/src/infer_compiler_registry/register_diffusers/resnet_oflow.py index 3133cabab..09a2129e3 100644 --- a/src/infer_compiler_registry/register_diffusers/resnet_oflow.py +++ b/src/infer_compiler_registry/register_diffusers/resnet_oflow.py @@ -1,11 +1,12 @@ +import importlib.metadata from typing import Optional + import oneflow as torch import oneflow.nn as nn import oneflow.nn.functional as F -from packaging import version -import importlib.metadata from onediff.infer_compiler.backends.oneflow.transform import transform_mgr +from packaging import version transformed_diffusers = transform_mgr.transform_package("diffusers") @@ -66,7 +67,6 @@ def forward( return hidden_states - else: class Upsample2D(transformed_diffusers.models.resnet.Upsample2D): diff --git a/src/infer_compiler_registry/register_diffusers/spatio_temporal_oflow.py b/src/infer_compiler_registry/register_diffusers/spatio_temporal_oflow.py index fd4aacb54..e60a0c04a 100644 --- a/src/infer_compiler_registry/register_diffusers/spatio_temporal_oflow.py +++ b/src/infer_compiler_registry/register_diffusers/spatio_temporal_oflow.py @@ -11,19 +11,19 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+import functools +import importlib.metadata +from collections import OrderedDict +from typing import Any, Dict, Optional, Tuple, Union + import oneflow as torch import oneflow.nn.functional as F from oneflow import nn -import functools -from typing import Any, Optional, Tuple, Union, Dict -from collections import OrderedDict +from packaging import version from .attention_processor_oflow import AttentionProcessor -from packaging import version -import importlib.metadata - diffusers_version = version.parse(importlib.metadata.version("diffusers")) diffusers_0240_v = version.parse("0.24.0") @@ -290,7 +290,8 @@ def forward( ) else: hidden_states = block( - hidden_states, encoder_hidden_states=encoder_hidden_states, + hidden_states, + encoder_hidden_states=encoder_hidden_states, ) hidden_states_mix = hidden_states diff --git a/src/infer_compiler_registry/register_diffusers/transformer_2d_oflow.py b/src/infer_compiler_registry/register_diffusers/transformer_2d_oflow.py index 2c3b2298f..07512fc12 100644 --- a/src/infer_compiler_registry/register_diffusers/transformer_2d_oflow.py +++ b/src/infer_compiler_registry/register_diffusers/transformer_2d_oflow.py @@ -1,12 +1,12 @@ +import importlib.metadata from dataclasses import dataclass from typing import Any, Dict, Optional -from packaging import version -import importlib.metadata import oneflow as torch import oneflow.nn.functional as F -from oneflow import nn from onediff.infer_compiler.backends.oneflow.transform import transform_mgr +from oneflow import nn +from packaging import version transformed_diffusers = transform_mgr.transform_package("diffusers") @@ -370,7 +370,6 @@ def forward( return Transformer2DModelOutput(sample=output) - elif diffusers_version < diffusers_02499_v: ConfigMixin = transformed_diffusers.configuration_utils.ConfigMixin register_to_config = transformed_diffusers.configuration_utils.register_to_config @@ -616,7 +615,7 @@ def __init__( inner_dim, elementwise_affine=False, eps=1e-6 ) self.scale_shift_table = nn.Parameter( - torch.randn(2, inner_dim) / inner_dim ** 0.5 + torch.randn(2, inner_dim) / inner_dim**0.5 ) self.proj_out = nn.Linear( inner_dim, patch_size * patch_size * self.out_channels @@ -886,7 +885,6 @@ def forward( return Transformer2DModelOutput(sample=output) - else: transformed_diffusers = transform_mgr.transform_package("diffusers") ConfigMixin = transformed_diffusers.configuration_utils.ConfigMixin @@ -1080,9 +1078,11 @@ def custom_forward(*inputs): return custom_forward - ckpt_kwargs: Dict[str, Any] = { - "use_reentrant": False - } if is_torch_version(">=", "1.11.0") else {} + ckpt_kwargs: Dict[str, Any] = ( + {"use_reentrant": False} + if is_torch_version(">=", "1.11.0") + else {} + ) hidden_states = torch.utils.checkpoint.checkpoint( create_custom_forward(block), hidden_states, diff --git a/src/infer_compiler_registry/register_diffusers/unet_2d_blocks_oflow.py b/src/infer_compiler_registry/register_diffusers/unet_2d_blocks_oflow.py index 54ae20ae3..ffc8368f2 100644 --- a/src/infer_compiler_registry/register_diffusers/unet_2d_blocks_oflow.py +++ b/src/infer_compiler_registry/register_diffusers/unet_2d_blocks_oflow.py @@ -1,8 +1,9 @@ -from typing import Any, Dict, List, Optional, Tuple, Union -from packaging import version import importlib.metadata +from typing import Any, Dict, List, Optional, Tuple, Union + import oneflow as torch from onediff.infer_compiler.backends.oneflow.transform import transform_mgr +from packaging import version diffusers_0210_v = version.parse("0.21.0") diffusers_version = 
version.parse(importlib.metadata.version("diffusers")) @@ -68,11 +69,11 @@ def custom_forward(*inputs): return custom_forward - ckpt_kwargs: Dict[str, Any] = { - "use_reentrant": False - } if transformed_diffusers.utils.is_torch_version( - ">=", "1.11.0" - ) else {} + ckpt_kwargs: Dict[str, Any] = ( + {"use_reentrant": False} + if transformed_diffusers.utils.is_torch_version(">=", "1.11.0") + else {} + ) hidden_states = torch.utils.checkpoint.checkpoint( create_custom_forward(resnet), hidden_states, @@ -147,7 +148,6 @@ def custom_forward(*inputs): return hidden_states - else: class AttnUpBlock2D(transformed_diffusers.models.unet_2d_blocks.AttnUpBlock2D): @@ -236,11 +236,11 @@ def custom_forward(*inputs): return custom_forward - ckpt_kwargs: Dict[str, Any] = { - "use_reentrant": False - } if transformed_diffusers.utils.is_torch_version( - ">=", "1.11.0" - ) else {} + ckpt_kwargs: Dict[str, Any] = ( + {"use_reentrant": False} + if transformed_diffusers.utils.is_torch_version(">=", "1.11.0") + else {} + ) hidden_states = torch.utils.checkpoint.checkpoint( create_custom_forward(resnet), hidden_states, diff --git a/src/infer_compiler_registry/register_diffusers/unet_2d_condition_oflow.py b/src/infer_compiler_registry/register_diffusers/unet_2d_condition_oflow.py index 5fb16e84e..8154a3147 100644 --- a/src/infer_compiler_registry/register_diffusers/unet_2d_condition_oflow.py +++ b/src/infer_compiler_registry/register_diffusers/unet_2d_condition_oflow.py @@ -1,8 +1,9 @@ -from typing import Any, Dict, List, Optional, Tuple, Union -from packaging import version import importlib.metadata +from typing import Any, Dict, List, Optional, Tuple, Union + import oneflow as torch from onediff.infer_compiler.backends.oneflow.transform import transform_mgr +from packaging import version diffusers_0210_v = version.parse("0.21.0") diffusers_version = version.parse(importlib.metadata.version("diffusers")) @@ -97,7 +98,7 @@ def forward( # The overall upsampling factor is equal to 2 ** (# num of upsampling layers). # However, the upsampling interpolation output size can be forced to fit any upsampling size # on the fly if necessary. 
- default_overall_up_factor = 2 ** self.num_upsamplers + default_overall_up_factor = 2**self.num_upsamplers # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` # forward_upsample_size = False diff --git a/src/infer_compiler_registry/register_diffusers_enterprise_lite/__init__.py b/src/infer_compiler_registry/register_diffusers_enterprise_lite/__init__.py index fb2028b40..b11e0997d 100644 --- a/src/infer_compiler_registry/register_diffusers_enterprise_lite/__init__.py +++ b/src/infer_compiler_registry/register_diffusers_enterprise_lite/__init__.py @@ -1,6 +1,6 @@ from onediff.infer_compiler.backends.oneflow.transform import register -import oneflow as flow +import oneflow as flow # usort: skip import diffusers_enterprise_lite torch2oflow_class_map = { diff --git a/src/infer_compiler_registry/register_onediff_quant/__init__.py b/src/infer_compiler_registry/register_onediff_quant/__init__.py index dd5a37a26..3f3fd9c1d 100644 --- a/src/infer_compiler_registry/register_onediff_quant/__init__.py +++ b/src/infer_compiler_registry/register_onediff_quant/__init__.py @@ -1,6 +1,6 @@ from onediff.infer_compiler.backends.oneflow.transform import register -import oneflow as flow +import oneflow as flow # usort: skip import onediff_quant torch2oflow_class_map = { diff --git a/src/onediff/infer_compiler/__init__.py b/src/onediff/infer_compiler/__init__.py index 7110e897e..cbe9bb890 100644 --- a/src/onediff/infer_compiler/__init__.py +++ b/src/onediff/infer_compiler/__init__.py @@ -1,4 +1,5 @@ import os + import torch from .backends import * diff --git a/src/onediff/infer_compiler/backends/__init__.py b/src/onediff/infer_compiler/backends/__init__.py index 1dd8ed60d..29bf3c775 100644 --- a/src/onediff/infer_compiler/backends/__init__.py +++ b/src/onediff/infer_compiler/backends/__init__.py @@ -1,4 +1,3 @@ +from .compiler import compile, oneflow_compile from .deployable_module import DeployableModule -from .compiler import compile -from .compiler import oneflow_compile from .env_var import OneflowCompileOptions diff --git a/src/onediff/infer_compiler/backends/compiler.py b/src/onediff/infer_compiler/backends/compiler.py index 2793df7c8..61b172fde 100644 --- a/src/onediff/infer_compiler/backends/compiler.py +++ b/src/onediff/infer_compiler/backends/compiler.py @@ -1,4 +1,5 @@ from typing import Callable, Optional + import torch from .deployable_module import DeployableModule diff --git a/src/onediff/infer_compiler/backends/deployable_module.py b/src/onediff/infer_compiler/backends/deployable_module.py index 51464df63..da483cc8a 100644 --- a/src/onediff/infer_compiler/backends/deployable_module.py +++ b/src/onediff/infer_compiler/backends/deployable_module.py @@ -1,4 +1,5 @@ from typing import Any + import torch diff --git a/src/onediff/infer_compiler/backends/env_var.py b/src/onediff/infer_compiler/backends/env_var.py index 9003f5432..e75983704 100644 --- a/src/onediff/infer_compiler/backends/env_var.py +++ b/src/onediff/infer_compiler/backends/env_var.py @@ -1,8 +1,9 @@ import dataclasses import os -import torch from typing import Optional +import torch + from onediff.utils import set_boolean_env_var, set_integer_env_var diff --git a/src/onediff/infer_compiler/backends/nexfort/README.md b/src/onediff/infer_compiler/backends/nexfort/README.md index d52deab97..aed8cd4e8 100644 --- a/src/onediff/infer_compiler/backends/nexfort/README.md +++ b/src/onediff/infer_compiler/backends/nexfort/README.md @@ -1,5 +1,5 @@ ## OneDiff Nexfort compiler backend(Beta Release) 
-OneDiff Nexfort is a lightweight [torch 2.0 compiler backend](https://pytorch.org/docs/stable/torch.compiler.html) strongly optimized for Diffusion Models. +OneDiff Nexfort is a lightweight [torch 2.0 compiler backend](https://pytorch.org/docs/stable/torch.compiler.html) strongly optimized for Diffusion Models. Currently, it is tailored especially for DiT (Diffusion Transformer) models, which are the backbone of [SD3](https://stability.ai/news/stable-diffusion-3) and [Sora](https://openai.com/sora/). diff --git a/src/onediff/infer_compiler/backends/nexfort/deployable_module.py b/src/onediff/infer_compiler/backends/nexfort/deployable_module.py index 8ff727036..fa9c07fdf 100644 --- a/src/onediff/infer_compiler/backends/nexfort/deployable_module.py +++ b/src/onediff/infer_compiler/backends/nexfort/deployable_module.py @@ -1,7 +1,9 @@ from types import FunctionType from typing import Type, Union + import torch from torch import nn + from ..deployable_module import DeployableModule diff --git a/src/onediff/infer_compiler/backends/oneflow/__init__.py b/src/onediff/infer_compiler/backends/oneflow/__init__.py index 4c7042454..d680370c2 100644 --- a/src/onediff/infer_compiler/backends/oneflow/__init__.py +++ b/src/onediff/infer_compiler/backends/oneflow/__init__.py @@ -1,3 +1,3 @@ +from ..env_var import OneflowCompileOptions from . import oneflow as _oneflow_backend from .deployable_module import OneflowDeployableModule -from ..env_var import OneflowCompileOptions diff --git a/src/onediff/infer_compiler/backends/oneflow/args_tree_util.py b/src/onediff/infer_compiler/backends/oneflow/args_tree_util.py index efab5da01..246b1e50f 100644 --- a/src/onediff/infer_compiler/backends/oneflow/args_tree_util.py +++ b/src/onediff/infer_compiler/backends/oneflow/args_tree_util.py @@ -1,6 +1,7 @@ import torch -import oneflow as flow +import oneflow as flow # usort: skip from oneflow.framework.args_tree import ArgsTree + from onediff.utils import logger from .utils.hash_utils import generate_input_structure_key diff --git a/src/onediff/infer_compiler/backends/oneflow/deployable_module.py b/src/onediff/infer_compiler/backends/oneflow/deployable_module.py index 7fc6adc15..57e38f633 100644 --- a/src/onediff/infer_compiler/backends/oneflow/deployable_module.py +++ b/src/onediff/infer_compiler/backends/oneflow/deployable_module.py @@ -1,29 +1,30 @@ import types -import torch from functools import wraps -import oneflow as flow +import torch + +import oneflow as flow # usort: skip from onediff.utils import logger from onediff.utils.chache_utils import LRUCache from ..deployable_module import DeployableModule - -from .transform.manager import transform_mgr -from .transform.builtin_transform import torch2oflow +from ..env_var import OneflowCompileOptions +from .args_tree_util import input_output_processor from .dual_module import DualModule, get_mixed_dual_module +from .graph_management_utils import graph_file_management from .oneflow_exec_mode import oneflow_exec_mode, oneflow_exec_mode_enabled -from .args_tree_util import input_output_processor +from .online_quantization_utils import quantize_and_deploy_wrapper from .param_utils import ( - parse_device, check_device, generate_constant_folding_info, + parse_device, update_graph_with_constant_folding_info, ) -from .graph_management_utils import graph_file_management -from .online_quantization_utils import quantize_and_deploy_wrapper -from ..env_var import OneflowCompileOptions +from .transform.builtin_transform import torch2oflow + +from .transform.manager import transform_mgr
@torch2oflow.register @@ -60,7 +61,11 @@ def get_oneflow_graph(model, size=9, dynamic_graph=True): class OneflowDeployableModule(DeployableModule): def __init__( - self, torch_module, oneflow_module, dynamic=True, options=None, + self, + torch_module, + oneflow_module, + dynamic=True, + options=None, ): torch.nn.Module.__init__(self) object.__setattr__( @@ -210,7 +215,7 @@ def extra_repr(self) -> str: return self._deployable_module_model.extra_repr() def set_graph_file(self, file_path: str) -> None: - """ Sets the path of the graph file. + """Sets the path of the graph file. If the new file path is different from the old one, clears old graph data. diff --git a/src/onediff/infer_compiler/backends/oneflow/dual_module.py b/src/onediff/infer_compiler/backends/oneflow/dual_module.py index df4696447..cd587943d 100644 --- a/src/onediff/infer_compiler/backends/oneflow/dual_module.py +++ b/src/onediff/infer_compiler/backends/oneflow/dual_module.py @@ -1,15 +1,15 @@ import os import types -from typing import Any from itertools import chain +from typing import Any import torch -import oneflow as flow +import oneflow as flow # usort: skip from oneflow.utils.tensor import to_torch from onediff.utils import logger -from .transform.builtin_transform import torch2oflow from .oneflow_exec_mode import oneflow_exec_mode, oneflow_exec_mode_enabled +from .transform.builtin_transform import torch2oflow class DualModule(torch.nn.Module): @@ -57,7 +57,10 @@ def _align_tensor(torch_module, oneflow_module): + [x for x, _ in oneflow_module.named_buffers()] ) for name, tensor in chain.from_iterable( - [torch_module.named_parameters(), torch_module.named_buffers(),] + [ + torch_module.named_parameters(), + torch_module.named_buffers(), + ] ): if name not in oneflow_tensor_list: tensor.data = tensor.to(*args, **kwargs) @@ -110,7 +113,7 @@ def __setattr__(self, name: str, value: Any) -> None: torch_obj = getattr(module, name) - if hasattr(torch_obj, 'copy_'): + if hasattr(torch_obj, "copy_"): torch_obj.copy_(value) else: setattr(module, name, value) diff --git a/src/onediff/infer_compiler/backends/oneflow/graph.py b/src/onediff/infer_compiler/backends/oneflow/graph.py index 301270832..adbbe848b 100644 --- a/src/onediff/infer_compiler/backends/oneflow/graph.py +++ b/src/onediff/infer_compiler/backends/oneflow/graph.py @@ -1,8 +1,8 @@ -import oneflow as flow +import oneflow as flow # usort: skip from onediff.utils import logger -from .transform.manager import transform_mgr from .transform.builtin_transform import reverse_proxy_class +from .transform.manager import transform_mgr from .utils.cost_util import cost_cnt diff --git a/src/onediff/infer_compiler/backends/oneflow/graph_management_utils.py b/src/onediff/infer_compiler/backends/oneflow/graph_management_utils.py index 3b2a2d888..2ce67034d 100644 --- a/src/onediff/infer_compiler/backends/oneflow/graph_management_utils.py +++ b/src/onediff/infer_compiler/backends/oneflow/graph_management_utils.py @@ -1,19 +1,21 @@ import importlib import os from typing import Dict + import torch -import oneflow as flow -from pathlib import Path +import oneflow as flow # usort: skip from functools import wraps +from pathlib import Path + from oneflow.framework.args_tree import ArgsTree + +from onediff.utils import logger +from ..env_var import OneflowCompileOptions from .transform.builtin_transform import torch2oflow from .transform.manager import transform_mgr from .utils.cost_util import cost_time -from ..env_var import OneflowCompileOptions from .utils.hash_utils import 
generate_input_structure_key, generate_model_structure_key -from onediff.utils import logger - def _prepare_file_path(file_path): if isinstance(file_path, Path): diff --git a/src/onediff/infer_compiler/backends/oneflow/import_tools/__init__.py b/src/onediff/infer_compiler/backends/oneflow/import_tools/__init__.py index 7f12b73a0..4f79c86f4 100644 --- a/src/onediff/infer_compiler/backends/oneflow/import_tools/__init__.py +++ b/src/onediff/infer_compiler/backends/oneflow/import_tools/__init__.py @@ -1,3 +1,3 @@ """ Tools for importing modules and packages""" -from .importer import LazyMocker, DynamicModuleLoader from .import_module_utils import import_module_from_path +from .importer import DynamicModuleLoader, LazyMocker diff --git a/src/onediff/infer_compiler/backends/oneflow/import_tools/dyn_mock_mod.py b/src/onediff/infer_compiler/backends/oneflow/import_tools/dyn_mock_mod.py index 8ac3ae0c9..b10ae4735 100644 --- a/src/onediff/infer_compiler/backends/oneflow/import_tools/dyn_mock_mod.py +++ b/src/onediff/infer_compiler/backends/oneflow/import_tools/dyn_mock_mod.py @@ -1,16 +1,18 @@ -from inspect import ismodule, signature -from types import ModuleType -from copy import deepcopy -from contextlib import contextmanager -from typing import List, Dict import importlib import inspect import os +from contextlib import contextmanager +from copy import deepcopy +from inspect import ismodule, signature +from types import ModuleType +from typing import Dict, List + import torch from oneflow.mock_torch import enable from oneflow.mock_torch.mock_importer import _importer -from .import_module_utils import import_module_from_path + from onediff.utils import logger +from .import_module_utils import import_module_from_path from .patch_for_compiler import * __all__ = ["DynamicMockModule"] @@ -116,7 +118,10 @@ def _update_module(full_names, main_pkg_enable_context): class DynamicMockModule(ModuleType): def __init__( - self, pkg_name: str, obj_entity: ModuleType, main_pkg_enable: callable, + self, + pkg_name: str, + obj_entity: ModuleType, + main_pkg_enable: callable, ): self._pkg_name = pkg_name self._obj_entity = obj_entity # ModuleType or _LazyModule diff --git a/src/onediff/infer_compiler/backends/oneflow/import_tools/format_utils.py b/src/onediff/infer_compiler/backends/oneflow/import_tools/format_utils.py index 3866aed82..d172ccb09 100644 --- a/src/onediff/infer_compiler/backends/oneflow/import_tools/format_utils.py +++ b/src/onediff/infer_compiler/backends/oneflow/import_tools/format_utils.py @@ -1,6 +1,6 @@ -from typing import Union -from types import FunctionType import inspect +from types import FunctionType +from typing import Union class MockEntityNameFormatter: diff --git a/src/onediff/infer_compiler/backends/oneflow/import_tools/import_module_utils.py b/src/onediff/infer_compiler/backends/oneflow/import_tools/import_module_utils.py index 22bff0929..0d3e362fb 100644 --- a/src/onediff/infer_compiler/backends/oneflow/import_tools/import_module_utils.py +++ b/src/onediff/infer_compiler/backends/oneflow/import_tools/import_module_utils.py @@ -1,9 +1,9 @@ -from pathlib import Path -from typing import Union -from types import ModuleType +import importlib import os import sys -import importlib +from pathlib import Path +from types import ModuleType +from typing import Union def import_module_from_path(module_path: Union[str, Path]) -> ModuleType: diff --git a/src/onediff/infer_compiler/backends/oneflow/import_tools/importer.py b/src/onediff/infer_compiler/backends/oneflow/import_tools/importer.py 
index 854a7577b..53e445e4a 100644 --- a/src/onediff/infer_compiler/backends/oneflow/import_tools/importer.py +++ b/src/onediff/infer_compiler/backends/oneflow/import_tools/importer.py @@ -1,15 +1,16 @@ +import importlib import os import sys -import importlib -from inspect import ismodule -from typing import Optional, Union from functools import lru_cache -from types import FunctionType, ModuleType -from pathlib import Path from importlib.metadata import requires -from .format_utils import MockEntityNameFormatter -from .dyn_mock_mod import DynamicMockModule +from inspect import ismodule +from pathlib import Path +from types import FunctionType, ModuleType +from typing import Optional, Union + from onediff.utils import logger +from .dyn_mock_mod import DynamicMockModule +from .format_utils import MockEntityNameFormatter __all__ = ["LazyMocker", "is_need_mock"] diff --git a/src/onediff/infer_compiler/backends/oneflow/import_tools/patch_for_compiler.py b/src/onediff/infer_compiler/backends/oneflow/import_tools/patch_for_compiler.py index 8e7f7e40b..8ba01f0cf 100644 --- a/src/onediff/infer_compiler/backends/oneflow/import_tools/patch_for_compiler.py +++ b/src/onediff/infer_compiler/backends/oneflow/import_tools/patch_for_compiler.py @@ -1,6 +1,7 @@ import math + import torch -import oneflow as flow +import oneflow as flow # usort: skip import oneflow.nn.functional as F diff --git a/src/onediff/infer_compiler/backends/oneflow/oneflow.py b/src/onediff/infer_compiler/backends/oneflow/oneflow.py index 0cf54914b..9b878c330 100644 --- a/src/onediff/infer_compiler/backends/oneflow/oneflow.py +++ b/src/onediff/infer_compiler/backends/oneflow/oneflow.py @@ -20,17 +20,17 @@ def compile(torch_module: torch.nn.Module, *, options=None): - 'graph_file' (None) generates a compilation cache file. If the file exists, loading occurs; if not, the compilation result is saved after the first run. - 'graph_file_device' (None) sets the device for the graph file, default None. If set, the compilation result will be converted to the specified device. 
""" - from .deployable_module import OneflowDeployableModule, get_mixed_deployable_module from ..env_var import ( + OneflowCompileOptions, set_oneflow_default_env_vars, set_oneflow_env_vars, - OneflowCompileOptions, ) + from .deployable_module import get_mixed_deployable_module, OneflowDeployableModule from .param_utils import ( - state_update_hook, - init_state_update_attr, - forward_pre_check_and_update_state_hook, forward_generate_constant_folding_info_hook, + forward_pre_check_and_update_state_hook, + init_state_update_attr, + state_update_hook, ) from .transform.custom_transform import set_default_registry diff --git a/src/onediff/infer_compiler/backends/oneflow/oneflow_exec_mode.py b/src/onediff/infer_compiler/backends/oneflow/oneflow_exec_mode.py index 13635c05f..6fff643d1 100644 --- a/src/onediff/infer_compiler/backends/oneflow/oneflow_exec_mode.py +++ b/src/onediff/infer_compiler/backends/oneflow/oneflow_exec_mode.py @@ -9,7 +9,7 @@ def __init__(self, enabled=None): self.enabled = True def __enter__(self): - import oneflow as flow + import oneflow as flow # usort: skip global _ONEFLOW_EXEC_MODE self.prev_mode = _ONEFLOW_EXEC_MODE @@ -18,7 +18,7 @@ def __enter__(self): _ = flow.set_grad_enabled(False) def __exit__(self, exc_type, exc_val, exc_tb): - import oneflow as flow + import oneflow as flow # usort: skip global _ONEFLOW_EXEC_MODE _ONEFLOW_EXEC_MODE = self.prev_mode diff --git a/src/onediff/infer_compiler/backends/oneflow/online_quantization_utils.py b/src/onediff/infer_compiler/backends/oneflow/online_quantization_utils.py index 1a537dfc9..7b251730d 100644 --- a/src/onediff/infer_compiler/backends/oneflow/online_quantization_utils.py +++ b/src/onediff/infer_compiler/backends/oneflow/online_quantization_utils.py @@ -20,8 +20,8 @@ def online_quantize_model( """ from onediff_quant.quantization import ( - OnlineQuantModule, create_quantization_calculator, + OnlineQuantModule, ) if getattr(quant_config, "quantization_calculator", None): diff --git a/src/onediff/infer_compiler/backends/oneflow/param_utils.py b/src/onediff/infer_compiler/backends/oneflow/param_utils.py index f972686e7..40de2870a 100644 --- a/src/onediff/infer_compiler/backends/oneflow/param_utils.py +++ b/src/onediff/infer_compiler/backends/oneflow/param_utils.py @@ -1,8 +1,9 @@ import re import types + import torch -import oneflow as flow -from typing import List, Dict, Any, Union +import oneflow as flow # usort: skip +from typing import Any, Dict, List, Union from onediff.utils import logger diff --git a/src/onediff/infer_compiler/backends/oneflow/transform/__init__.py b/src/onediff/infer_compiler/backends/oneflow/transform/__init__.py index 38ba22c4f..ec7d4736f 100644 --- a/src/onediff/infer_compiler/backends/oneflow/transform/__init__.py +++ b/src/onediff/infer_compiler/backends/oneflow/transform/__init__.py @@ -1,11 +1,11 @@ """Module to convert PyTorch code to OneFlow.""" -from .manager import transform_mgr -from .builtin_transform import torch2oflow, default_converter -from .custom_transform import register - from .builtin_transform import ( - ProxySubmodule, - proxy_class, - map_args, + default_converter, get_attr, + map_args, + proxy_class, + ProxySubmodule, + torch2oflow, ) +from .custom_transform import register +from .manager import transform_mgr diff --git a/src/onediff/infer_compiler/backends/oneflow/transform/builtin_transform.py b/src/onediff/infer_compiler/backends/oneflow/transform/builtin_transform.py index 83a2b9dd6..fee4b07dc 100644 --- 
a/src/onediff/infer_compiler/backends/oneflow/transform/builtin_transform.py +++ b/src/onediff/infer_compiler/backends/oneflow/transform/builtin_transform.py @@ -1,22 +1,23 @@ """Convert torch object to oneflow object.""" -import os import importlib -import types +import os import traceback -from functools import singledispatch, partial +import types from collections import OrderedDict from collections.abc import Iterable -from typing import Union, Any +from functools import partial, singledispatch +from typing import Any, Union + import torch -import oneflow as flow +import oneflow as flow # usort: skip -from .manager import transform_mgr from onediff.utils import logger -from .patch_for_diffusers import diffusers_checker from ..import_tools.importer import is_need_mock +from .manager import transform_mgr from .patch_for_comfy import PatchForComfy +from .patch_for_diffusers import diffusers_checker __all__ = [ "proxy_class", @@ -246,7 +247,7 @@ def proxy_getattr(self, attr): if verbose: logger.info( f""" - Warning: {type(of_mod)} is in training mode + Warning: {type(of_mod)} is in training mode and is turned into eval mode which is good for inference optimization. """ ) diff --git a/src/onediff/infer_compiler/backends/oneflow/transform/custom_transform.py b/src/onediff/infer_compiler/backends/oneflow/transform/custom_transform.py index 2ea176e82..b14b90ddb 100644 --- a/src/onediff/infer_compiler/backends/oneflow/transform/custom_transform.py +++ b/src/onediff/infer_compiler/backends/oneflow/transform/custom_transform.py @@ -1,13 +1,14 @@ """A module for registering custom torch2oflow functions and classes.""" -import inspect import importlib.util -from pathlib import Path +import inspect import sys +from pathlib import Path from typing import Callable, Dict, List, Optional, Union + +from onediff.utils import logger from ..import_tools import import_module_from_path -from .manager import transform_mgr from .builtin_transform import torch2oflow -from onediff.utils import logger +from .manager import transform_mgr __all__ = ["register"] diff --git a/src/onediff/infer_compiler/backends/oneflow/transform/manager.py b/src/onediff/infer_compiler/backends/oneflow/transform/manager.py index dd572c11d..e3e8bbaa7 100644 --- a/src/onediff/infer_compiler/backends/oneflow/transform/manager.py +++ b/src/onediff/infer_compiler/backends/oneflow/transform/manager.py @@ -1,10 +1,11 @@ import importlib +import logging import os import types import warnings -import logging -from typing import Dict, List, Union from pathlib import Path +from typing import Dict, List, Union + from onediff.utils import logger from ..import_tools.importer import LazyMocker diff --git a/src/onediff/infer_compiler/backends/oneflow/transform/patch_for_diffusers.py b/src/onediff/infer_compiler/backends/oneflow/transform/patch_for_diffusers.py index e5cb43cbf..2303c650c 100644 --- a/src/onediff/infer_compiler/backends/oneflow/transform/patch_for_diffusers.py +++ b/src/onediff/infer_compiler/backends/oneflow/transform/patch_for_diffusers.py @@ -1,5 +1,6 @@ # TODO: remove this file to diffusers/src/infer_compiler_registry/register_diffusers from abc import ABC, abstractmethod + from onediff.utils import logger try: diff --git a/src/onediff/infer_compiler/backends/oneflow/utils/cost_util.py b/src/onediff/infer_compiler/backends/oneflow/utils/cost_util.py index 4cb1575f5..c5fe58200 100644 --- a/src/onediff/infer_compiler/backends/oneflow/utils/cost_util.py +++ b/src/onediff/infer_compiler/backends/oneflow/utils/cost_util.py @@ -1,7 +1,8 @@
from functools import wraps -import oneflow as flow -import time +import oneflow as flow # usort: skip import inspect +import time + from onediff.utils import logger __all__ = ["cost_cnt", "cost_time"] diff --git a/src/onediff/infer_compiler/backends/oneflow/utils/hash_utils.py b/src/onediff/infer_compiler/backends/oneflow/utils/hash_utils.py index 6cab64b12..e1ab38d1a 100644 --- a/src/onediff/infer_compiler/backends/oneflow/utils/hash_utils.py +++ b/src/onediff/infer_compiler/backends/oneflow/utils/hash_utils.py @@ -1,4 +1,5 @@ import hashlib + from oneflow.framework.args_tree import ArgsTree diff --git a/src/onediff/infer_compiler/backends/oneflow/utils/version_util.py b/src/onediff/infer_compiler/backends/oneflow/utils/version_util.py index 5e0d22a8e..f1f135eef 100644 --- a/src/onediff/infer_compiler/backends/oneflow/utils/version_util.py +++ b/src/onediff/infer_compiler/backends/oneflow/utils/version_util.py @@ -1,4 +1,5 @@ from importlib_metadata import version + from onediff.utils import logger diff --git a/src/onediff/infer_compiler/backends/registry.py b/src/onediff/infer_compiler/backends/registry.py index bbf0e24bf..24abe554b 100644 --- a/src/onediff/infer_compiler/backends/registry.py +++ b/src/onediff/infer_compiler/backends/registry.py @@ -9,7 +9,8 @@ def register_backend( - name: Optional[str] = None, tags: Sequence[str] = (), + name: Optional[str] = None, + tags: Sequence[str] = (), ): def wrapper(compiler_fn: Optional[Any] = None): if compiler_fn is None: diff --git a/src/onediff/optimization/attention_processor.py b/src/onediff/optimization/attention_processor.py index c57dcc602..ae9f7d438 100644 --- a/src/onediff/optimization/attention_processor.py +++ b/src/onediff/optimization/attention_processor.py @@ -1,5 +1,5 @@ import os -import oneflow as flow +import oneflow as flow # usort: skip import oneflow.nn as nn import oneflow.nn.functional as F @@ -84,10 +84,7 @@ def __call__( hidden_states = flow.bmm(attention_probs, value) hidden_states = attn.batch_to_head_dim(hidden_states) else: - from onediff.utils import ( - parse_boolean_from_env, - set_boolean_env_var, - ) + from onediff.utils import parse_boolean_from_env, set_boolean_env_var if attn.upcast_attention and parse_boolean_from_env( "ONEFLOW_ATTENTION_ALLOW_HALF_PRECISION_ACCUMULATION", True diff --git a/src/onediff/optimization/quant_optimizer.py b/src/onediff/optimization/quant_optimizer.py index 9a00b883b..c127c9a82 100644 --- a/src/onediff/optimization/quant_optimizer.py +++ b/src/onediff/optimization/quant_optimizer.py @@ -1,14 +1,16 @@ import time +from copy import deepcopy + import torch import torch.nn as nn -from copy import deepcopy -from onediff.utils import logger + +from onediff.infer_compiler.backends.oneflow.transform.manager import transform_mgr +from onediff.infer_compiler.backends.oneflow.utils.cost_util import cost_cnt from onediff.infer_compiler.backends.oneflow.utils.version_util import ( is_quantization_enabled, ) -from onediff.infer_compiler.backends.oneflow.utils.cost_util import cost_cnt -from onediff.infer_compiler.backends.oneflow.transform.manager import transform_mgr from onediff.torch_utils.module_operations import modify_sub_module +from onediff.utils import logger __all__ = ["quantize_model", "varify_can_use_quantization"] @@ -36,9 +38,12 @@ def quantize_model( if varify_can_use_quantization() is False: return model - from onediff_quant.utils import symm_quantize_sub_module, find_quantizable_modules - from onediff_quant.utils import get_quantize_module from onediff_quant import 
Quantizer + from onediff_quant.utils import ( + find_quantizable_modules, + get_quantize_module, + symm_quantize_sub_module, + ) quantize_conv_cnt, quantize_linear_cnt = 0, 0 diff --git a/src/onediff/optimization/rewrite_self_attention.py b/src/onediff/optimization/rewrite_self_attention.py index efe6911ea..bc05d5c49 100644 --- a/src/onediff/optimization/rewrite_self_attention.py +++ b/src/onediff/optimization/rewrite_self_attention.py @@ -1,15 +1,20 @@ import os + import torch import torch.nn as nn -from diffusers.models.attention_processor import Attention -from diffusers.models.attention_processor import AttnProcessor, AttnProcessor2_0 +from diffusers.models.attention_processor import ( + Attention, + AttnProcessor, + AttnProcessor2_0, +) + from .attention_processor import FusedSelfAttnProcessor _IS_ONEDIFF_QUANT_AVAILABLE = 0 try: import onediff_quant - from onediff_quant import StaticQuantLinearModule, DynamicQuantLinearModule + from onediff_quant import DynamicQuantLinearModule, StaticQuantLinearModule _IS_ONEDIFF_QUANT_AVAILABLE = 1 except ImportError as e: diff --git a/src/onediff/quantization/README.md b/src/onediff/quantization/README.md index 5f5181ed4..4012c0579 100644 --- a/src/onediff/quantization/README.md +++ b/src/onediff/quantization/README.md @@ -104,7 +104,7 @@ python ./src/onediff/quantization/quant_pipeline_test.py \ --linear_ssim_threshold 0.991 \ --save_as_float False \ --cache_dir "./run_sd-v1-5" \ - --quantized_model ./quantized_model + --quantized_model ./quantized_model ``` If you want to load a quantized model, you can modify the quantized_model parameter to the path of the specific model, such as the [sd-1.5-onediff-enterprise](https://huggingface.co/siliconflow/stable-diffusion-v1-5-onediff-comfy-enterprise-v1) and [sd-1.5-onediff-deepcache models](https://huggingface.co/siliconflow/stable-diffusion-v1-5-onediff-deepcache-int8). [Stable-diffusion-v2-1-onediff-enterprise](https://huggingface.co/siliconflow/stable-diffusion-v2-1-onediff-enterprise) has not been quantized yet, so it needs to be quantized first.
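For reference, the loading flow this README describes can be condensed into a few lines. The sketch below is assembled from the imports that `load_quantized_model.py` uses; the `./quantized_model` path is a placeholder, and the exact keyword arguments accepted by `QuantPipeline.from_quantized` may differ between releases:
```
import torch
from diffusers import AutoPipelineForText2Image

from onediff.infer_compiler import oneflow_compile
from onediff.quantization.quantize_pipeline import QuantPipeline

# Load an int8 model produced by quant_pipeline_test.py
# ("./quantized_model" is a placeholder path).
pipe = QuantPipeline.from_quantized(
    AutoPipelineForText2Image,
    "./quantized_model",
    torch_dtype=torch.float16,
)
pipe.to("cuda")

# Compile the UNet once; subsequent prompts reuse the compiled graph.
pipe.unet = oneflow_compile(pipe.unet)
image = pipe("a photo of an astronaut riding a horse").images[0]
image.save("output.png")
```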
diff --git a/src/onediff/quantization/__init__.py b/src/onediff/quantization/__init__.py index f84c17836..19be6a917 100644 --- a/src/onediff/quantization/__init__.py +++ b/src/onediff/quantization/__init__.py @@ -1,2 +1,2 @@ -from .quantize_utils import setup_onediff_quant, load_calibration_and_quantize_pipeline from .quantize_pipeline import QuantPipeline +from .quantize_utils import load_calibration_and_quantize_pipeline, setup_onediff_quant diff --git a/src/onediff/quantization/load_quantized_model.py b/src/onediff/quantization/load_quantized_model.py index 913466137..4878652ec 100644 --- a/src/onediff/quantization/load_quantized_model.py +++ b/src/onediff/quantization/load_quantized_model.py @@ -1,8 +1,10 @@ -from diffusers import AutoPipelineForText2Image -from onediff.quantization.quantize_pipeline import QuantPipeline import argparse + import torch +from diffusers import AutoPipelineForText2Image + from onediff.infer_compiler import oneflow_compile +from onediff.quantization.quantize_pipeline import QuantPipeline def parse_args(): diff --git a/src/onediff/quantization/quant_pipeline_test.py b/src/onediff/quantization/quant_pipeline_test.py index a23efd134..3e9f0f954 100644 --- a/src/onediff/quantization/quant_pipeline_test.py +++ b/src/onediff/quantization/quant_pipeline_test.py @@ -1,7 +1,9 @@ +import argparse + +import torch from diffusers import AutoPipelineForText2Image + from onediff.quantization.quantize_pipeline import QuantPipeline -import torch -import argparse def parse_args(): diff --git a/src/onediff/quantization/quantize_pipeline.py b/src/onediff/quantization/quantize_pipeline.py index f3141a765..e4c851fb8 100644 --- a/src/onediff/quantization/quantize_pipeline.py +++ b/src/onediff/quantization/quantize_pipeline.py @@ -1,9 +1,10 @@ import os -from typing import Any, List, Optional, Union from functools import partial +from typing import Any, List, Optional, Union from onediff_quant import quantize_pipeline, save_quantized -from .quantize_utils import setup_onediff_quant, load_calibration_and_quantize_pipeline + +from .quantize_utils import load_calibration_and_quantize_pipeline, setup_onediff_quant class QuantPipeline: @@ -16,7 +17,7 @@ def from_quantized( **kwargs ): """load a quantized model. - + - Example: ```python from diffusers import AutoPipelineForText2Image @@ -49,7 +50,7 @@ def from_pretrained( **kwargs ): """load a floating-point model that is to be quantized as int8.
- + - Example: ```python from diffusers import AutoPipelineForText2Image diff --git a/src/onediff/torch_utils/model_inplace_assign.py b/src/onediff/torch_utils/model_inplace_assign.py index c8edc6a6d..3e444550b 100644 --- a/src/onediff/torch_utils/model_inplace_assign.py +++ b/src/onediff/torch_utils/model_inplace_assign.py @@ -1,7 +1,9 @@ import warnings -from typing import Union, List from collections import defaultdict +from typing import List, Union + import torch + from onediff.infer_compiler import DeployableModule _nested_counter = defaultdict(lambda: 0) diff --git a/src/onediff/utils/__init__.py b/src/onediff/utils/__init__.py index 631812a59..6b16746d6 100644 --- a/src/onediff/utils/__init__.py +++ b/src/onediff/utils/__init__.py @@ -1,7 +1,7 @@ -from .log_utils import logger from .env_var import ( parse_boolean_from_env, - set_boolean_env_var, parse_integer_from_env, + set_boolean_env_var, set_integer_env_var, ) +from .log_utils import logger diff --git a/src/onediff/utils/chache_utils.py b/src/onediff/utils/chache_utils.py index 72c2b73f6..6aca1254d 100644 --- a/src/onediff/utils/chache_utils.py +++ b/src/onediff/utils/chache_utils.py @@ -4,7 +4,7 @@ class LRUCache(collections.OrderedDict): __slots__ = ["LEN"] - def __init__(self, capacity: int=9): + def __init__(self, capacity: int = 9): self.LEN = capacity def get(self, key: str, default=None) -> any: diff --git a/src/onediff/utils/import_utils.py b/src/onediff/utils/import_utils.py index 9d7e8b1d4..111387966 100644 --- a/src/onediff/utils/import_utils.py +++ b/src/onediff/utils/import_utils.py @@ -1,8 +1,8 @@ import importlib -import traceback -from inspect import ismodule import os import platform +import traceback +from inspect import ismodule from types import ModuleType system = platform.system() diff --git a/src/onediff/utils/log_utils.py b/src/onediff/utils/log_utils.py index 58bb22d4a..d9aef839d 100644 --- a/src/onediff/utils/log_utils.py +++ b/src/onediff/utils/log_utils.py @@ -1,6 +1,6 @@ -import time import logging import os +import time from pathlib import Path diff --git a/tests/comfyui/meta2png.py b/tests/comfyui/meta2png.py index 25034badc..989e666d8 100644 --- a/tests/comfyui/meta2png.py +++ b/tests/comfyui/meta2png.py @@ -25,7 +25,10 @@ def parse_args(): required=True, ) parser.add_argument( - "--output", type=str, default="", help="The output png filename.", + "--output", + type=str, + default="", + help="The output png filename.", ) parser.add_argument( "--key", diff --git a/tests/comfyui/test_by_api.py b/tests/comfyui/test_by_api.py index 686e6889d..13e03fdff 100644 --- a/tests/comfyui/test_by_api.py +++ b/tests/comfyui/test_by_api.py @@ -2,7 +2,6 @@ import json import os - import requests from PIL import Image diff --git a/tests/comfyui/test_by_ui.py b/tests/comfyui/test_by_ui.py index 9ee92836a..79bea534e 100644 --- a/tests/comfyui/test_by_ui.py +++ b/tests/comfyui/test_by_ui.py @@ -1,5 +1,5 @@ """ -Before running this script, you need to start the Selenium and ComfyUI services. +Before running this script, you need to start the Selenium and ComfyUI services. You can start their containers using Docker Compose. Please set the following environment variables (whose values are for reference only): @@ -19,7 +19,7 @@ **Note**: It is advisable to execute the following commands in the 'diffusers' directory -unless you are fully aware of the implications of executing them in a different +unless you are fully aware of the implications of executing them in a different directory. 
git clone https://github.com/comfyanonymous/ComfyUI.git @@ -44,28 +44,44 @@ from PIL import Image from selenium import webdriver +from selenium.common.exceptions import TimeoutException from selenium.webdriver.common.by import By -from selenium.webdriver.support.ui import WebDriverWait from selenium.webdriver.support import expected_conditions as EC -from selenium.common.exceptions import TimeoutException +from selenium.webdriver.support.ui import WebDriverWait def parse_args(): parser = argparse.ArgumentParser(description="Test ComfyUI workflow by Selenium.") parser.add_argument( - "-w", "--workflow", type=str, required=True, help="Workflow file", + "-w", + "--workflow", + type=str, + required=True, + help="Workflow file", ) parser.add_argument( - "-t", "--timeout", type=int, default="200", + "-t", + "--timeout", + type=int, + default="200", ) parser.add_argument( - "--host", type=str, default="127.0.0.1", help="The selenium service host", + "--host", + type=str, + default="127.0.0.1", + help="The selenium service host", ) parser.add_argument( - "--port", type=str, default="4444", help="The selenium service port", + "--port", + type=str, + default="4444", + help="The selenium service port", ) parser.add_argument( - "--comfy_port", type=str, default="8188", help="The ComfyUI service port", + "--comfy_port", + type=str, + default="8188", + help="The ComfyUI service port", ) args = parser.parse_args() return args @@ -176,12 +192,12 @@ def launch_prompt(driver): launch_and_wait(driver, timeout=args.timeout) duration = time.time() - start_time - print( - f"{args.workflow} has finished, time elapsed: {duration:.1f}" - ) - + print(f"{args.workflow} has finished, time elapsed: {duration:.1f}") + if duration < 2: - raise ValueError("Execution duration is too short, possible error in workflow execution") + raise ValueError( + "Execution duration is too short, possible error in workflow execution" + ) print(f"check if error occurs...") check_error_occurs(driver) diff --git a/tests/comfyui/workflows/sdxl-control-lora-speedup.json b/tests/comfyui/workflows/sdxl-control-lora-speedup.json index 927bea232..b1ed8bd82 100644 --- a/tests/comfyui/workflows/sdxl-control-lora-speedup.json +++ b/tests/comfyui/workflows/sdxl-control-lora-speedup.json @@ -566,4 +566,4 @@ "config": {}, "extra": {}, "version": 0.4 -} \ No newline at end of file +} diff --git a/tests/comfyui/workflows/sdxl-unet-speedup-graph-saver.json b/tests/comfyui/workflows/sdxl-unet-speedup-graph-saver.json index 1af8e9adb..04ab7bfb0 100644 --- a/tests/comfyui/workflows/sdxl-unet-speedup-graph-saver.json +++ b/tests/comfyui/workflows/sdxl-unet-speedup-graph-saver.json @@ -476,4 +476,4 @@ } }, "version": 0.4 -} \ No newline at end of file +} diff --git a/tests/comfyui/workflows/text-to-video-speedup.json b/tests/comfyui/workflows/text-to-video-speedup.json index dc4e9b7cb..e62a8624b 100644 --- a/tests/comfyui/workflows/text-to-video-speedup.json +++ b/tests/comfyui/workflows/text-to-video-speedup.json @@ -824,4 +824,4 @@ "config": {}, "extra": {}, "version": 0.4 -} \ No newline at end of file +} diff --git a/tests/convert_torch_to_of/test_patch_for_compiling.py b/tests/convert_torch_to_of/test_patch_for_compiling.py index 21844a4fa..1b7a4ad73 100644 --- a/tests/convert_torch_to_of/test_patch_for_compiling.py +++ b/tests/convert_torch_to_of/test_patch_for_compiling.py @@ -4,9 +4,11 @@ Usage: python -m pytest diffusers/tests/torch_to_oflow/test_temp_fix_compile_impl.py """ -import pytest import numpy as np -from 
onediff.infer_compiler.backends.oneflow.import_tools.patch_for_compiler import FakeCuda +import pytest +from onediff.infer_compiler.backends.oneflow.import_tools.patch_for_compiler import ( + FakeCuda, +) @pytest.mark.parametrize("batch_size", [8]) @@ -43,7 +45,7 @@ def torch_flash_attention() -> np.ndarray: return result.cpu().detach().numpy() def oneflow_flash_attention() -> np.ndarray: - import oneflow as flow + import oneflow as flow # usort: skip q = flow.tensor(query, dtype=flow.float16).to("cuda") k = flow.tensor(key, dtype=flow.float16).to("cuda") diff --git a/tests/convert_torch_to_of/test_torch2of_demo.py b/tests/convert_torch_to_of/test_torch2of_demo.py index df4eb5202..d269a2eaa 100644 --- a/tests/convert_torch_to_of/test_torch2of_demo.py +++ b/tests/convert_torch_to_of/test_torch2of_demo.py @@ -5,8 +5,9 @@ python -m pytest test_torch2of_demo.py """ import torch -import oneflow as flow +import oneflow as flow # usort: skip import unittest + import numpy as np from onediff.infer_compiler import oneflow_compile from onediff.infer_compiler.backends.oneflow.transform import transform_mgr @@ -39,7 +40,7 @@ def forward(self, x): def apply_model(self, x): return self.forward(x) - + class TestTorch2ofDemo(unittest.TestCase): def judge_tensor_func(self, y_pt, y_of): @@ -48,7 +49,7 @@ def judge_tensor_func(self, y_pt, y_of): y_pt = y_pt.cpu().detach().numpy() y_of = y_of.cpu().detach().numpy() assert np.allclose(y_pt, y_of, atol=1e-3, rtol=1e-3) - + def test_torch2of_demo(self): # Register PyTorch model to OneDiff cls_key = transform_mgr.get_transformed_entity_name(PyTorchModel) diff --git a/tests/sd-webui/test_api.py b/tests/sd-webui/test_api.py index f4cee6601..74c751b7b 100644 --- a/tests/sd-webui/test_api.py +++ b/tests/sd-webui/test_api.py @@ -1,25 +1,26 @@ import os +from pathlib import Path + import numpy as np import pytest -from pathlib import Path from PIL import Image from utils import ( - IMG2IMG_API_ENDPOINT, - OPTIONS_API_ENDPOINT, - SAVED_GRAPH_NAME, - TXT2IMG_API_ENDPOINT, - WEBUI_SERVER_URL, cal_ssim, check_and_generate_images, + dump_image, get_all_args, get_base_args, get_data_summary, get_image_array_from_response, get_target_image_filename, + get_threshold, + IMG2IMG_API_ENDPOINT, is_txt2img, + OPTIONS_API_ENDPOINT, post_request_and_check, - dump_image, - get_threshold, + SAVED_GRAPH_NAME, + TXT2IMG_API_ENDPOINT, + WEBUI_SERVER_URL, ) @@ -99,8 +100,8 @@ def test_onediff_load_graph(url_txt2img): def test_onediff_refiner(url_txt2img): extra_args = { "sd_model_checkpoint": "sd_xl_base_1.0.safetensors", - "refiner_checkpoint" :"sd_xl_refiner_1.0.safetensors [7440042bbd]", - "refiner_switch_at" : 0.8, + "refiner_checkpoint": "sd_xl_refiner_1.0.safetensors [7440042bbd]", + "refiner_switch_at": 0.8, } data = {**get_base_args(), **extra_args} # loop 3 times for checking model switching between base and refiner diff --git a/tests/sd-webui/utils.py b/tests/sd-webui/utils.py index 3a1bbaedd..b5fc55671 100644 --- a/tests/sd-webui/utils.py +++ b/tests/sd-webui/utils.py @@ -148,9 +148,10 @@ def get_data_summary(data: Dict[str, Any]) -> Dict[str, bool]: def dump_image(src_img: np.ndarray, target_img: np.ndarray, filename: str): - combined_img = np.concatenate((src_img, target_img), axis=1) - image = Image.fromarray(combined_img) - image.save(f'{filename}.png') + combined_img = np.concatenate((src_img, target_img), axis=1) + image = Image.fromarray(combined_img) + image.save(f"{filename}.png") + def get_threshold(data: Dict[str, Any]): if is_quant(data): diff --git 
a/tests/test_dual_module_list.py b/tests/test_dual_module_list.py index 28a711404..22a01ffaf 100644 --- a/tests/test_dual_module_list.py +++ b/tests/test_dual_module_list.py @@ -1,9 +1,9 @@ import numpy as np -from onediff.infer_compiler import oneflow_compile -from onediff.infer_compiler.backends.oneflow.transform import register import torch import torch.nn as nn -import oneflow as flow +from onediff.infer_compiler import oneflow_compile +from onediff.infer_compiler.backends.oneflow.transform import register +import oneflow as flow # usort: skip class MyModule(nn.Module): @@ -39,7 +39,10 @@ def forward(self, x): assert np.allclose(y_torch.detach().cpu(), y_oneflow.detach().cpu(), 1e-03, 1e-03) -from onediff.infer_compiler.backends.oneflow.dual_module import DualModule, DualModuleList +from onediff.infer_compiler.backends.oneflow.dual_module import ( + DualModule, + DualModuleList, +) assert isinstance(m.linears, DualModuleList) diff --git a/tests/test_model_inference.py b/tests/test_model_inference.py index 2996a5b38..fb15996a7 100644 --- a/tests/test_model_inference.py +++ b/tests/test_model_inference.py @@ -3,10 +3,11 @@ import unittest from functools import partial -import torch -from onediff.utils.import_utils import is_oneflow_available, is_nexfort_available import onediff.infer_compiler as infer_compiler +import torch +from onediff.utils.import_utils import is_nexfort_available, is_oneflow_available + class SubModule(torch.nn.Module): def __init__(self): @@ -36,23 +37,29 @@ def compute(x): class TestModelInference(unittest.TestCase): def setUp(self) -> None: self.compilation_functions = [] - + if is_oneflow_available(): oneflow_compile_fn = partial(infer_compiler.compile, backend="oneflow") self.compilation_functions.append(oneflow_compile_fn) - + if is_nexfort_available(): nexfort_compile_options = { "mode": "max-optimize:max-autotune:freezing:benchmark:cudagraphs", "dynamic": True, "fullgraph": True, } - nexfort_compile_fn = partial(infer_compiler.compile, backend="nexfort", options=nexfort_compile_options) + nexfort_compile_fn = partial( + infer_compiler.compile, + backend="nexfort", + options=nexfort_compile_options, + ) self.compilation_functions.append(nexfort_compile_fn) assert len(self.compilation_functions) > 0 - def measure_inference_time(self, model, warmup=3, num_runs=30, input_args=[], input_kwargs={}): + def measure_inference_time( + self, model, warmup=3, num_runs=30, input_args=[], input_kwargs={} + ): for _ in range(warmup): model(*input_args, **input_kwargs) @@ -78,23 +85,37 @@ def generate_models_and_inputs(self): compiled_model_sub = copy.deepcopy(model) compiled_model_sub.sub_module = compile_fn(compiled_model_sub.sub_module) yield model, compiled_model_sub, inputs, {} - - if compile_fn.keywords.get('backend') == "nexfort": + + if compile_fn.keywords.get("backend") == "nexfort": inputs_compute = [torch.randn(10000, 1000).cuda().half()] compiled_compute_fn = compile_fn(compute) yield compute, compiled_compute_fn, inputs_compute, {} @torch.inference_mode() def test_inference_results(self): - for model, compiled_model, input_args, input_kwargs in self.generate_models_and_inputs(): - original_result, _ = self.measure_inference_time(model, input_args=input_args, input_kwargs=input_kwargs) - compiled_result, _ = self.measure_inference_time(compiled_model, input_args=input_args, input_kwargs=input_kwargs) + for ( + model, + compiled_model, + input_args, + input_kwargs, + ) in self.generate_models_and_inputs(): + original_result, _ = self.measure_inference_time( + 
model, input_args=input_args, input_kwargs=input_kwargs + ) + compiled_result, _ = self.measure_inference_time( + compiled_model, input_args=input_args, input_kwargs=input_kwargs + ) + + self.assertTrue( + torch.allclose(original_result, compiled_result, atol=1e-2, rtol=1e-3) + ) - self.assertTrue(torch.allclose(original_result, compiled_result, atol=1e-2, rtol=1e-3)) - if isinstance(model, torch.nn.Module): self.assertIsInstance(compiled_model, MainModule) - self.assertEqual(set(model.state_dict().keys()), set(compiled_model.state_dict().keys())) + self.assertEqual( + set(model.state_dict().keys()), + set(compiled_model.state_dict().keys()), + ) if __name__ == "__main__": diff --git a/tests/test_pipelines_oneflow_img2img.py b/tests/test_pipelines_oneflow_img2img.py index d541ae370..44c1bf5b0 100644 --- a/tests/test_pipelines_oneflow_img2img.py +++ b/tests/test_pipelines_oneflow_img2img.py @@ -20,8 +20,6 @@ import numpy as np import oneflow as torch -from onediff import OneFlowStableDiffusionImg2ImgPipeline - from diffusers import ( AutoencoderKL, LMSDiscreteScheduler, @@ -30,6 +28,8 @@ ) from diffusers.utils import floats_tensor, load_image, torch_device + +from onediff import OneFlowStableDiffusionImg2ImgPipeline from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer @@ -265,7 +265,9 @@ def test_stable_diffusion_img2img_pipeline(self): model_id = "CompVis/stable-diffusion-v1-4" pipe = OneFlowStableDiffusionImg2ImgPipeline.from_pretrained( - model_id, safety_checker=self.dummy_safety_checker, use_auth_token=True, + model_id, + safety_checker=self.dummy_safety_checker, + use_auth_token=True, ) pipe.to(torch_device) pipe.set_progress_bar_config(disable=None) diff --git a/tests/test_quantitative_quality.py b/tests/test_quantitative_quality.py index 9c3b90dde..bb3fca845 100644 --- a/tests/test_quantitative_quality.py +++ b/tests/test_quantitative_quality.py @@ -1,18 +1,25 @@ -from skimage.metrics import structural_similarity as ssim -from PIL import Image -import numpy as np import unittest -class QuantizeQuality(unittest.TestCase): +import numpy as np +from PIL import Image +from skimage.metrics import structural_similarity as ssim + +class QuantizeQuality(unittest.TestCase): def test_validate(self): - image1 = np.array(Image.open('/share_nfs/civitai/20240407-163408.jpg').convert('RGB')) - image2 = np.array(Image.open('/src/onediff/output_enterprise_sd.png').convert('RGB')) + image1 = np.array( + Image.open("/share_nfs/civitai/20240407-163408.jpg").convert("RGB") + ) + image2 = np.array( + Image.open("/src/onediff/output_enterprise_sd.png").convert("RGB") + ) # Calculate SSIM ssim_index = ssim(image1, image2, multichannel=True, win_size=3) print("SSIM:", ssim_index) - self.assertTrue(ssim_index > 0.89, "SSIM Validation fails, and the workflow is aborted") - + self.assertTrue( + ssim_index > 0.89, "SSIM Validation fails, and the workflow is aborted" + ) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_quantize_custom_model.py b/tests/test_quantize_custom_model.py index 00a2fbce5..6e4de6fcf 100644 --- a/tests/test_quantize_custom_model.py +++ b/tests/test_quantize_custom_model.py @@ -2,13 +2,15 @@ import os import unittest -import oneflow as flow +import oneflow as flow # usort: skip import torch -from torch import nn from onediff.infer_compiler import oneflow_compile from onediff.infer_compiler.backends.oneflow.transform import register -from onediff.infer_compiler.backends.oneflow.utils.version_util import is_community_version +from 
onediff.infer_compiler.backends.oneflow.utils.version_util import ( + is_community_version, +) +from torch import nn is_community = is_community_version() onediff_quant_spec = importlib.util.find_spec("onediff_quant") @@ -17,10 +19,10 @@ exit(0) from onediff_quant.quantization import ( + create_quantization_calculator, OfflineQuantModule, OnlineQuantModule, QuantizationConfig, - create_quantization_calculator, )
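The hunks above re-sort the imports of `register` and the `onediff_quant` helpers used by `test_quantize_custom_model.py`. The pattern behind these registry files is: define a OneFlow twin for a custom PyTorch class, hand the mapping to `register`, then compile. Below is a minimal sketch with hypothetical `TorchMLP`/`OneflowMLP` classes, assuming `register` accepts a `torch2oflow_class_map` keyword as the `register_onediff_quant` entry earlier in this diff suggests:
```
import torch
import oneflow as flow  # usort: skip

from onediff.infer_compiler import oneflow_compile
from onediff.infer_compiler.backends.oneflow.transform import register


class TorchMLP(torch.nn.Module):
    """Hypothetical custom PyTorch module that OneDiff cannot convert on its own."""

    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(8, 8)

    def forward(self, x):
        return torch.relu(self.linear(x))


class OneflowMLP(flow.nn.Module):
    """Hand-written OneFlow counterpart with an identical structure."""

    def __init__(self):
        super().__init__()
        self.linear = flow.nn.Linear(8, 8)

    def forward(self, x):
        return flow.relu(self.linear(x))


# Teach the torch2oflow transform how to proxy the custom class.
register(torch2oflow_class_map={TorchMLP: OneflowMLP})

model = TorchMLP().to("cuda")
compiled = oneflow_compile(model)  # returns a deployable module wrapper
out = compiled(torch.randn(4, 8, device="cuda"))
```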