From b97a31b84ae9597cb8626fbf6b1ba2ea9e9a3993 Mon Sep 17 00:00:00 2001 From: ccsuu Date: Tue, 26 Mar 2024 02:01:16 +0000 Subject: [PATCH 01/11] test_quantize_custom_model --- src/onediff/quantization/README.md | 0 src/onediff/quantization/quantize_utils.py | 15 +-- tests/test_quantize_custom_model.py | 117 +++++++++++++++++++++ 3 files changed, 121 insertions(+), 11 deletions(-) create mode 100644 src/onediff/quantization/README.md create mode 100644 tests/test_quantize_custom_model.py diff --git a/src/onediff/quantization/README.md b/src/onediff/quantization/README.md new file mode 100644 index 000000000..e69de29bb diff --git a/src/onediff/quantization/quantize_utils.py b/src/onediff/quantization/quantize_utils.py index 97c7c72d7..678787586 100644 --- a/src/onediff/quantization/quantize_utils.py +++ b/src/onediff/quantization/quantize_utils.py @@ -12,19 +12,12 @@ def setup_onediff_quant(): def load_calibration_and_quantize_pipeline(calibration_path, pipe): - calibrate_info = {} - with open(calibration_path, "r") as f: - for line in f.readlines(): - line = line.strip() - items = line.split(" ") - calibrate_info[items[0]] = [ - float(items[1]), - int(items[2]), - [float(x) for x in items[3].split(",")], - ] - + from onediff_quant.quantization import CalibrationStorage from onediff_quant.utils import replace_sub_module_with_quantizable_module + store = CalibrationStorage() + calibrate_info = store.load_from_file(file_path=calibration_path) + for sub_module_name, sub_calibrate_info in calibrate_info.items(): replace_sub_module_with_quantizable_module( pipe.unet, diff --git a/tests/test_quantize_custom_model.py b/tests/test_quantize_custom_model.py new file mode 100644 index 000000000..7e354fe90 --- /dev/null +++ b/tests/test_quantize_custom_model.py @@ -0,0 +1,117 @@ +import os +import unittest + +import oneflow as flow +import torch +from onediff_quant.quantization import (OfflineQuantModule, OnlineQuantModule, + QuantizationConfig, + QuantizationStatsStorage, + create_quantization_calculator) +from torch import nn + +from onediff.infer_compiler import oneflow_compile +from onediff.infer_compiler.transform import register + + +# Define the model +class SimpleModel(nn.Module): + def __init__(self): + super(SimpleModel, self).__init__() + # Two convolutional layers + self.conv1 = nn.Conv2d( + in_channels=3, out_channels=16, kernel_size=3, stride=1, padding=1 + ) + self.conv2 = nn.Conv2d( + in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=1 + ) + # Fully connected layer + self.fc = nn.Linear( + 32 * 32 * 32, 10 + ) # Input channels are 32*32*32, output 10 classes + + def forward(self, x): + x = torch.relu(self.conv1(x)) + x = torch.relu(self.conv2(x)) + x = x.view(x.size(0), -1) # Flatten the tensor into one dimension + x = self.fc(x) + return x + + +class SimpleModel_OF(flow.nn.Module): + def __init__(self): + super(SimpleModel, self).__init__() + # Two convolutional layers + self.conv1 = flow.nn.Conv2d( + in_channels=3, out_channels=16, kernel_size=3, stride=1, padding=1 + ) + self.conv2 = flow.nn.Conv2d( + in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=1 + ) + # Fully connected layer + self.fc = flow.nn.Linear( + 32 * 32 * 32, 10 + ) # Input channels are 32*32*32, output 10 classes + + def forward(self, x): + x = flow.relu(self.conv1(x)) + x = flow.relu(self.conv2(x)) + x = x.view(x.size(0), -1) # Flatten the tensor into one dimension + x = self.fc(x) + return x + + +register(torch2oflow_class_map={SimpleModel: SimpleModel_OF}) + +# Configure quantization 
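
+# quantize_conv / quantize_linear toggle which layer types are replaced with
+# quantized equivalents; calibration and statistics output lands under cache_dir.
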
+config = QuantizationConfig.from_settings( + quantize_conv=True, + quantize_linear=False, + cache_dir="runs", + plot_calibrate_info=True, +) +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +model = SimpleModel().to(device).to(torch.float16) +input_data = torch.randn(1, 3, 32, 32, dtype=torch.float16).to( + device +) # Input data size is [batch_size, channels, height, width] +seed = 1 +calculator = create_quantization_calculator(model, config) +torch.manual_seed(seed) +standard_output = model(input_data) + + +class TestOnlineQuantModule(unittest.TestCase): + def setUp(self): + self.module = OnlineQuantModule(calculator, inplace=False) + + def test_quantize_with_calibration(self): + quantized_model, info = self.module.quantize_with_calibration(input_data) + status = self.module.collect_quantization_status(quantized_model, info) + store = QuantizationStatsStorage(config.cache_dir) + store.save(data=status, file_name="quant_stats.json") + compiled_model = oneflow_compile(quantized_model) + torch.manual_seed(seed) + quantized_output = compiled_model(input_data) + # print(f'{quantized_output=} \n {standard_output=}') + self.assertTrue(torch.allclose(standard_output, quantized_output, 1e4, 1e4)) + + +class TestOfflineQuantModule(unittest.TestCase): + def setUp(self): + self.module = OfflineQuantModule(calculator, inplace=False) + + def test_quantize_with_calibration(self): + quantized_model = self.module.quantize_with_calibration(input_data)[0] + file_path = os.path.join(config.cache_dir, "quantized_model.pt") + self.module.save(quantized_model, file_path) + + quantized_model = OfflineQuantModule(None).load(file_path=file_path) + compiled_model = oneflow_compile(quantized_model) + + torch.manual_seed(seed) + quantized_output = compiled_model(input_data) + self.assertTrue(torch.allclose(standard_output, quantized_output, 1e4, 1e4)) + + +if __name__ == "__main__": + unittest.main() From b1834b761145616a82ae9481abd7f78eb89776e5 Mon Sep 17 00:00:00 2001 From: ccsuu Date: Tue, 26 Mar 2024 06:20:09 +0000 Subject: [PATCH 02/11] refine --- tests/test_quantize_custom_model.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/tests/test_quantize_custom_model.py b/tests/test_quantize_custom_model.py index 7e354fe90..640ec137d 100644 --- a/tests/test_quantize_custom_model.py +++ b/tests/test_quantize_custom_model.py @@ -26,8 +26,8 @@ def __init__(self): ) # Fully connected layer self.fc = nn.Linear( - 32 * 32 * 32, 10 - ) # Input channels are 32*32*32, output 10 classes + 32 * 32 * 32, 4 + ) # Input channels are 32*32*32, output 4 classes def forward(self, x): x = torch.relu(self.conv1(x)) @@ -49,7 +49,7 @@ def __init__(self): ) # Fully connected layer self.fc = flow.nn.Linear( - 32 * 32 * 32, 10 + 32 * 32 * 32, 4 ) # Input channels are 32*32*32, output 10 classes def forward(self, x): @@ -65,7 +65,7 @@ def forward(self, x): # Configure quantization config = QuantizationConfig.from_settings( quantize_conv=True, - quantize_linear=False, + quantize_linear=True, cache_dir="runs", plot_calibrate_info=True, ) @@ -86,13 +86,14 @@ def setUp(self): def test_quantize_with_calibration(self): quantized_model, info = self.module.quantize_with_calibration(input_data) - status = self.module.collect_quantization_status(quantized_model, info) + status = self.module.collect_quantization_status(model, info) + assert status["quantized_conv_count"] == 2 and status["quantized_linear_count"] == 1 store = QuantizationStatsStorage(config.cache_dir) store.save(data=status, 
file_name="quant_stats.json") compiled_model = oneflow_compile(quantized_model) torch.manual_seed(seed) quantized_output = compiled_model(input_data) - # print(f'{quantized_output=} \n {standard_output=}') + # print(f'{quantized_output=} \n{standard_output=}') self.assertTrue(torch.allclose(standard_output, quantized_output, 1e4, 1e4)) From 95312a45ca7bb6c04647804bbe21a2690f6f5f11 Mon Sep 17 00:00:00 2001 From: ccsuu Date: Tue, 26 Mar 2024 06:42:55 +0000 Subject: [PATCH 03/11] refine --- .github/workflows/examples.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/workflows/examples.yml b/.github/workflows/examples.yml index abc6efa5c..b0bca4910 100644 --- a/.github/workflows/examples.yml +++ b/.github/workflows/examples.yml @@ -207,6 +207,8 @@ jobs: if: matrix.test-suite == 'diffusers_examples' run: | docker exec -w /src/onediff/onediff_diffusers_extensions ${{ env.CONTAINER_NAME }} python3 -m pip install -e . + - if: matrix.test-suite == 'diffusers_examples' && startsWith(matrix.image, 'onediff-pro') + run: docker exec -w /src/onediff ${{ env.CONTAINER_NAME }} python3 tests/test_quantize_custom_model.py - if: matrix.test-suite == 'diffusers_examples' && startsWith(matrix.image, 'onediff-pro') run: docker exec -w /src/onediff/onediff_diffusers_extensions ${{ env.CONTAINER_NAME }} python3 examples/text_to_image_sdxl_enterprise.py --model /share_nfs/hf_models/stable-diffusion-xl-base-1.0-int8 --width 512 --height 512 --saved_image output_enterprise_sdxl.png - if: matrix.test-suite == 'diffusers_examples' && startsWith(matrix.image, 'onediff-pro') From b9e8819990917dfa33eb54f5f627a93c650a6885 Mon Sep 17 00:00:00 2001 From: ccsuu Date: Thu, 28 Mar 2024 05:17:43 +0000 Subject: [PATCH 04/11] refine --- .../utils/online_quantization_utils.py | 51 +++++++++++++++++++ .../infer_compiler/with_oneflow_compile.py | 30 ++++++++++- 2 files changed, 79 insertions(+), 2 deletions(-) create mode 100644 src/onediff/infer_compiler/utils/online_quantization_utils.py diff --git a/src/onediff/infer_compiler/utils/online_quantization_utils.py b/src/onediff/infer_compiler/utils/online_quantization_utils.py new file mode 100644 index 000000000..8f15b9f83 --- /dev/null +++ b/src/onediff/infer_compiler/utils/online_quantization_utils.py @@ -0,0 +1,51 @@ +def online_quantize_model( + model, input_args, input_kwargs, + seed=1, inplace=True, + module_selector=lambda x: x, + quant_config = None, + calibration_info=None, +): + """Optimize the quantization pipeline. + + Returns: + tuple: A tuple containing the quantized model and the quantization + status. 
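
+
+    Example (an illustrative sketch; assumes a quant_config built with
+    QuantizationConfig.from_settings, as elsewhere in this patch):
+
+        >>> quantized_model, status = online_quantize_model(
+        ...     model, (input_data,), {}, quant_config=quant_config)
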
+ """ + + from onediff_quant.quantization import ( + OnlineQuantModule, + create_quantization_calculator, + ) + + calculator = create_quantization_calculator( + model, quant_config, module_selector, seed, + calibration_info=calibration_info, + ) + + module = OnlineQuantModule(calculator, False, inplace=inplace) + + quantized_model, info = module.quantize_with_calibration( + *input_args, **input_kwargs + ) + status = module.collect_quantization_status(model, info) + + return quantized_model, status + + +def quantize_and_deploy_wrapper(func): + def wrapper(self: "DeployableModule", *args, **kwargs): + torch_model = self._torch_module + quant_config = self._deployable_module_quant_config + if quant_config: + torch_model, _ = online_quantize_model( + torch_model, args, kwargs, + module_selector=lambda x: x, + quant_config=quant_config, + inplace=True, + ) + self._deployable_module_quant_config = None + output = func(self, *args, **kwargs) + return output + return wrapper + + \ No newline at end of file diff --git a/src/onediff/infer_compiler/with_oneflow_compile.py b/src/onediff/infer_compiler/with_oneflow_compile.py index b3e4d2906..8f546f1d1 100644 --- a/src/onediff/infer_compiler/with_oneflow_compile.py +++ b/src/onediff/infer_compiler/with_oneflow_compile.py @@ -15,7 +15,7 @@ from .utils.cost_util import cost_cnt from .utils.param_utils import parse_device, check_device from .utils.graph_management_utils import graph_file_management - +from .utils.online_quantization_utils import quantize_and_deploy_wrapper class DualModule(torch.nn.Module): def __init__(self, torch_module, oneflow_module): @@ -210,6 +210,7 @@ def __init__( self._deployable_module_enable_dynamic = dynamic self._deployable_module_options = options self._deployable_module_dpl_graph = None + self._deployable_module_quant_config = None self._is_raw_deployable_module = True self._load_graph_first_run = True @@ -249,6 +250,7 @@ def get_graph(self): ) return self._deployable_module_dpl_graph + @input_output_processor @handle_deployable_exception @graph_file_management @@ -263,7 +265,8 @@ def apply_model(self, *args, **kwargs): *args, **kwargs ) return output - + + @quantize_and_deploy_wrapper @input_output_processor @handle_deployable_exception @graph_file_management @@ -352,6 +355,29 @@ def _clear_old_graph(self): def get_graph_file(self): return self._deployable_module_options.get("graph_file", None) + + + def apply_online_quant(self, quant_config): + """ + Applies the provided quantization configuration for online use. + + Args: + quant_config (QuantizationConfig): The quantization configuration to apply. + + Example: + >>> from onediff_quant.quantization import QuantizationConfig + >>> quant_config = QuantizationConfig.from_settings( + ... quantize_conv=True, + ... quantize_linear=True, + ... conv_mae_threshold=0.005, + ... linear_mae_threshold=0.005, + ... conv_compute_density_threshold=300, + ... linear_compute_density_threshold=100, + ... 
cache_dir=args.cache_dir) + >>> model.apply_online_quant(quant_config) + """ + self._deployable_module_quant_config = quant_config + class OneflowGraph(flow.nn.Graph): From bca183c5512ae5596a6ea3a4eef5d95abb12de29 Mon Sep 17 00:00:00 2001 From: ccsuu Date: Tue, 2 Apr 2024 07:23:52 +0000 Subject: [PATCH 05/11] add quantization/README.md --- .../examples/text_to_image_online_quant.py | 123 ++++++++++++++++++ src/onediff/quantization/README.md | 7 + tests/test_quantize_custom_model.py | 30 +++-- 3 files changed, 151 insertions(+), 9 deletions(-) create mode 100644 onediff_diffusers_extensions/examples/text_to_image_online_quant.py diff --git a/onediff_diffusers_extensions/examples/text_to_image_online_quant.py b/onediff_diffusers_extensions/examples/text_to_image_online_quant.py new file mode 100644 index 000000000..ffc1747a5 --- /dev/null +++ b/onediff_diffusers_extensions/examples/text_to_image_online_quant.py @@ -0,0 +1,123 @@ +"""[Stable Diffusion V1.5 - Hugging Face Model Hub](https://huggingface.co/runwayml/stable-diffusion-v1-5) + +## Performance Comparison + +Updated on Tue 02 Apr 2024 + +Timings for 50 steps at 512x512 +| Accelerator | Baseline (non-optimized) | OneDiff Quant(optimized) | Percentage improvement | +| ----------------------- | ------------------------ | ------------------------ | ---------------------- | +| NVIDIA GeForce RTX 3090 | 2.51 s | 0.92 s | ~63 % | + +## Install + +1. [OneDiff Installation Guide](https://github.com/siliconflow/onediff/tree/main?tab=readme-ov-file#installation) +2. [OneDiffx Installation Guide](https://github.com/siliconflow/onediff/tree/main/onediff_diffusers_extensions#install-and-setup) + +## Usage: + +```shell +# Baseline (non-optimized) +$ python text_to_image_online_quant.py \ + --model_id /share_nfs/hf_models/stable-diffusion-v1-5 \ + --seed 1 \ + --backend torch --height 512 --width 512 --output_file sd-v1-5_torch.png +``` +```shell +# OneDiff Quant(optimized) +$ python text_to_image_online_quant.py \ + --model_id /share_nfs/hf_models/stable-diffusion-v1-5 \ + --seed 1 \ + --backend onediff \ + --cache_dir ./run_sd-v1-5_quant \ + --height 512 \ + --width 512 \ + --output_file sd-v1-5_quant.png \ + --quantize \ + --conv_mae_threshold 0.1 \ + --linear_mae_threshold 0.2 \ + --conv_compute_density_threshold 900 \ + --linear_compute_density_threshold 300 +``` + +| Option | Range | Default | Description | +| -------------------------------------- | ------ | ------- | ---------------------------------------------------------------------------- | +| --conv_mae_threshold 0.9 | [0, 1] | 0.1 | MAE threshold for quantizing convolutional modules to 0.1. | +| --linear_mae_threshold 1 | [0, 1] | 0.2 | MAE threshold for quantizing linear modules to 0.2. | +| --conv_compute_density_threshold 900 | [1, ∞) | 900 | Computational density threshold for quantizing convolutional modules to 900. | +| --linear_compute_density_threshold 300 | [1, ∞) | 300 | Computational density threshold for quantizing linear modules to 300. | + +Notes: + +1. Set CUDA device using export CUDA_VISIBLE_DEVICES=7. + +2. The log *.pt file is cached. Quantization result information can be found in `cache_dir`/quantization_stats.json. 
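
+
+3. A quick way to inspect those results (an illustrative sketch; the path
+   matches the --cache_dir used in the OneDiff Quant example above):
+
+```python
+import json
+
+with open("./run_sd-v1-5_quant/quantization_stats.json") as f:
+    print(json.dumps(json.load(f), indent=2))
+```
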
+ +""" +import argparse +import time +import torch +from diffusers import AutoPipelineForText2Image +from onediffx import compile_pipe +from onediff_quant.quantization import QuantizationConfig + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--model_id", default="runwayml/stable-diffusion-v1-5") + parser.add_argument("--prompt", default="a photo of an astronaut riding a horse on mars") + parser.add_argument("--output_file", default="astronaut_rides_horse_onediff_quant.png") + parser.add_argument("--seed", type=int, default=1) + parser.add_argument("--backend", default="onediff", choices=["onediff", "torch"]) + parser.add_argument("--quantize", action="store_true") + parser.add_argument("--cache_dir", default="./run_sd-v1-5") + parser.add_argument("--height", type=int, default=1024) + parser.add_argument("--width", type=int, default=1024) + parser.add_argument("--num_inference_steps", type=int, default=50) + parser.add_argument("--conv_mae_threshold", type=float, default=0.2) + parser.add_argument("--linear_mae_threshold", type=float, default=0.4) + parser.add_argument("--conv_compute_density_threshold", type=int, default=900) + parser.add_argument("--linear_compute_density_threshold", type=int, default=300) + return parser.parse_args() + +def load_model(model_id): + pipe = AutoPipelineForText2Image.from_pretrained(model_id, torch_dtype=torch.float16, variant="fp16") + pipe.to(f"cuda") + return pipe + +def compile_and_quantize_model(pipe, cache_dir, quantize, quant_params): + pipe = compile_pipe(pipe) + if quantize: + config = QuantizationConfig.from_settings(**quant_params, cache_dir=cache_dir, plot_calibrate_info=True) + pipe.unet.apply_online_quant(quant_config=config) + return pipe + +def save_image(image, output_file): + image.save(output_file) + print(f"Image saved to: {output_file}") + +def main(): + args = parse_args() + pipe = load_model(args.model_id) + if args.backend == "onediff": + compile_and_quantize_model(pipe, args.cache_dir, args.quantize, + {"conv_mae_threshold": args.conv_mae_threshold, + "linear_mae_threshold": args.linear_mae_threshold, + "conv_compute_density_threshold": args.conv_compute_density_threshold, + "linear_compute_density_threshold": args.linear_compute_density_threshold}) + + # Warm-up + pipe(prompt=args.prompt, num_inference_steps=1) + + # Run_inference + for _ in range(4): + start_time = time.time() + torch.manual_seed(args.seed) + image = pipe(prompt=args.prompt, height=args.height, width=args.width, num_inference_steps=args.num_inference_steps).images[0] + end_time = time.time() + print(f"Inference time: {end_time - start_time:.2f} seconds") + + # [onediff_quant.png](https://github.com/siliconflow/onediff/assets/109639975/75cd9407-c9bb-423f-9e70-c15df76ff2b1) + save_image(image, args.output_file) + +if __name__ == "__main__": + main() diff --git a/src/onediff/quantization/README.md b/src/onediff/quantization/README.md index e69de29bb..72b6d8cc8 100644 --- a/src/onediff/quantization/README.md +++ b/src/onediff/quantization/README.md @@ -0,0 +1,7 @@ +##
 OneDiff Quant 🚀 NEW
+## Documentation
+
+- [Installation Guide](https://github.com/siliconflow/onediff/blob/main/README_ENTERPRISE.md#install-onediff-enterprise) +- [How to use Online Quant](../../../onediff_diffusers_extensions/examples/text_to_image_online_quant.py) +- [How to use Offline Quant](./quantize_pipeline.py) +- [How to Quant a custom model](../../../tests/test_quantize_custom_model.py) +- [Community and Support](https://github.com/siliconflow/onediff?tab=readme-ov-file#community-and-support) \ No newline at end of file diff --git a/tests/test_quantize_custom_model.py b/tests/test_quantize_custom_model.py index 640ec137d..36d92f9fc 100644 --- a/tests/test_quantize_custom_model.py +++ b/tests/test_quantize_custom_model.py @@ -1,16 +1,27 @@ +import importlib import os import unittest import oneflow as flow import torch -from onediff_quant.quantization import (OfflineQuantModule, OnlineQuantModule, - QuantizationConfig, - QuantizationStatsStorage, - create_quantization_calculator) from torch import nn from onediff.infer_compiler import oneflow_compile from onediff.infer_compiler.transform import register +from onediff.infer_compiler.utils import is_community_version + +is_community = is_community_version() +onediff_quant_spec = importlib.util.find_spec("onediff_quant") +if is_community or onediff_quant_spec is None: + print(f"{is_community=} {onediff_quant_spec=}") + exit(0) + +from onediff_quant.quantization import ( + OfflineQuantModule, + OnlineQuantModule, + QuantizationConfig, + create_quantization_calculator, +) # Define the model @@ -66,8 +77,8 @@ def forward(self, x): config = QuantizationConfig.from_settings( quantize_conv=True, quantize_linear=True, - cache_dir="runs", - plot_calibrate_info=True, + cache_dir="cache_dir", + plot_calibrate_info=False, ) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model = SimpleModel().to(device).to(torch.float16) @@ -87,9 +98,10 @@ def setUp(self): def test_quantize_with_calibration(self): quantized_model, info = self.module.quantize_with_calibration(input_data) status = self.module.collect_quantization_status(model, info) - assert status["quantized_conv_count"] == 2 and status["quantized_linear_count"] == 1 - store = QuantizationStatsStorage(config.cache_dir) - store.save(data=status, file_name="quant_stats.json") + assert ( + status["quantized_conv_count"] == 2 + and status["quantized_linear_count"] == 1 + ) compiled_model = oneflow_compile(quantized_model) torch.manual_seed(seed) quantized_output = compiled_model(input_data) From e8287ab95a178b25543d9ab5adc30de11cb6b61f Mon Sep 17 00:00:00 2001 From: FengWen <109639975+ccssu@users.noreply.github.com> Date: Tue, 2 Apr 2024 15:39:33 +0800 Subject: [PATCH 06/11] Update text_to_image_online_quant.py --- .../examples/text_to_image_online_quant.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onediff_diffusers_extensions/examples/text_to_image_online_quant.py b/onediff_diffusers_extensions/examples/text_to_image_online_quant.py index ffc1747a5..6af04e13a 100644 --- a/onediff_diffusers_extensions/examples/text_to_image_online_quant.py +++ b/onediff_diffusers_extensions/examples/text_to_image_online_quant.py @@ -44,8 +44,8 @@ | -------------------------------------- | ------ | ------- | ---------------------------------------------------------------------------- | | --conv_mae_threshold 0.9 | [0, 1] | 0.1 | MAE threshold for quantizing convolutional modules to 0.1. | | --linear_mae_threshold 1 | [0, 1] | 0.2 | MAE threshold for quantizing linear modules to 0.2. 
| -| --conv_compute_density_threshold 900 | [1, ∞) | 900 | Computational density threshold for quantizing convolutional modules to 900. | -| --linear_compute_density_threshold 300 | [1, ∞) | 300 | Computational density threshold for quantizing linear modules to 300. | +| --conv_compute_density_threshold 900 | [0, ∞) | 900 | Computational density threshold for quantizing convolutional modules to 900. | +| --linear_compute_density_threshold 300 | [0, ∞) | 300 | Computational density threshold for quantizing linear modules to 300. | Notes: From e2d853e68e3eca932fadc5e38c629b70cc9c8d93 Mon Sep 17 00:00:00 2001 From: ccsuu Date: Tue, 2 Apr 2024 09:34:24 +0000 Subject: [PATCH 07/11] refine --- onediff_comfy_nodes/_nodes.py | 51 +++++++--------- .../modules/optimizer_strategy/__init__.py | 2 +- .../optimizer_strategy/deepcache_optimizer.py | 17 ++++-- .../optimizer_strategy/optimizer_strategy.py | 13 +++- .../quantization_optimizer.py | 60 +++++++++++++++---- 5 files changed, 96 insertions(+), 47 deletions(-) diff --git a/onediff_comfy_nodes/_nodes.py b/onediff_comfy_nodes/_nodes.py index 26476a0b1..3af392fe8 100644 --- a/onediff_comfy_nodes/_nodes.py +++ b/onediff_comfy_nodes/_nodes.py @@ -388,10 +388,18 @@ def deep_cache_convert( from .modules.optimizer_strategy import OptimizerStrategy, DeepcacheOptimizerExecutor, OnelineQuantizationOptimizerExecutor class OneDiffControlNetLoader(ControlNetLoader): + @classmethod + def INPUT_TYPES(s): + ret = super().INPUT_TYPES() + ret.update({"optional": { + "model_optimizer": ("MODEL_OPTIMIZER",),} + }) + return ret + CATEGORY = "OneDiff/Loaders" FUNCTION = "onediff_load_controlnet" - def onediff_load_controlnet(self, control_net_name): + def onediff_load_controlnet(self, control_net_name, model_optimizer=None): controlnet = super().load_controlnet(control_net_name)[0] load_device = model_management.get_torch_device() @@ -408,12 +416,15 @@ def gen_compile_options(model): ) return (controlnet,) elif isinstance(controlnet, ControlNet): - control_model = controlnet.control_model - compile_options = gen_compile_options(control_model) - control_model = control_model.to(load_device) - controlnet.control_model = oneflow_compile( - control_model, options=compile_options - ) + if model_optimizer: + controlnet = model_optimizer(controlnet, ckpt_name=control_net_name) + else: + control_model = controlnet.control_model + compile_options = gen_compile_options(control_model) + control_model = control_model.to(load_device) + controlnet.control_model = oneflow_compile( + control_model, options=compile_options + ) return (controlnet,) else: print( @@ -448,7 +459,7 @@ def onediff_load_checkpoint( unet_graph_file = generate_graph_path(ckpt_name, modelpatcher.model) if model_optimizer is not None: - modelpatcher = model_optimizer(modelpatcher) + modelpatcher = model_optimizer(modelpatcher, ckpt_name=ckpt_name) modelpatcher.model.diffusion_model = compoile_unet( modelpatcher.model.diffusion_model, unet_graph_file) @@ -617,29 +628,13 @@ def INPUT_TYPES(s): def optimize_model(self, quantization_optimizer:OptimizerStrategy =None, deeppcache_optimizer:OptimizerStrategy=None): """Apply the optimization technique to the model.""" - def apply_optimizer(model, ckpt_name=""): - def set_compiled_options(module: DeployableModule, graph_file="unet"): - assert isinstance(module, DeployableModule) - compile_options = { - "graph_file": graph_file, - "graph_file_device": model_management.get_torch_device(), - } - module._deployable_module_options.update(compile_options) - + def 
apply_optimizer(model,*,ckpt_name=""): + if deeppcache_optimizer is not None: - model = deeppcache_optimizer.apply(model) - graph_file = generate_graph_path(ckpt_name, model.fast_deep_cache_unet._torch_module) - set_compiled_options(model.fast_deep_cache_unet, graph_file) - graph_file = generate_graph_path(ckpt_name, model.deep_cache_unet._torch_module) - set_compiled_options(model.deep_cache_unet, graph_file) + model = deeppcache_optimizer.apply(model, ckpt_name) if quantization_optimizer is not None: - model = quantization_optimizer.apply(model) - diff_model: DeployableModule = model.model.diffusion_model - graph_file = generate_graph_path(ckpt_name, model.model) - set_compiled_options(diff_model, graph_file) - quant_config = diff_model._deployable_module_quant_config - quant_config.cache_dir = os.path.dirname(graph_file) + model = quantization_optimizer.apply(model, ckpt_name) return model diff --git a/onediff_comfy_nodes/modules/optimizer_strategy/__init__.py b/onediff_comfy_nodes/modules/optimizer_strategy/__init__.py index 4f32980b4..7865bc81b 100644 --- a/onediff_comfy_nodes/modules/optimizer_strategy/__init__.py +++ b/onediff_comfy_nodes/modules/optimizer_strategy/__init__.py @@ -1,3 +1,3 @@ from .deepcache_optimizer import DeepcacheOptimizerExecutor +from .optimizer_strategy import OptimizerStrategy from .quantization_optimizer import OnelineQuantizationOptimizerExecutor -from .optimizer_strategy import OptimizerStrategy \ No newline at end of file diff --git a/onediff_comfy_nodes/modules/optimizer_strategy/deepcache_optimizer.py b/onediff_comfy_nodes/modules/optimizer_strategy/deepcache_optimizer.py index 2dec5e12a..9fcc8ef7a 100644 --- a/onediff_comfy_nodes/modules/optimizer_strategy/deepcache_optimizer.py +++ b/onediff_comfy_nodes/modules/optimizer_strategy/deepcache_optimizer.py @@ -4,7 +4,8 @@ from comfy.model_patcher import ModelPatcher from ...utils.deep_cache_speedup import deep_cache_speedup -from .optimizer_strategy import OptimizerStrategy +from ...utils.graph_path import generate_graph_path +from .optimizer_strategy import OptimizerStrategy, set_compiled_options @dataclass @@ -24,8 +25,8 @@ def apply(self, model): return model @apply.register(ModelPatcher) - def _(self, model): - return deep_cache_speedup( + def _(self, model, ckpt_name=""): + model = deep_cache_speedup( model=model, use_graph=True, cache_interval=self.cache_interval, @@ -35,6 +36,10 @@ def _(self, model): end_step=self.end_step, use_oneflow_deepcache_speedup_modelpatcher=False, )[0] - - - + graph_file = generate_graph_path( + ckpt_name, model.fast_deep_cache_unet._torch_module + ) + set_compiled_options(model.fast_deep_cache_unet, graph_file) + graph_file = generate_graph_path(ckpt_name, model.deep_cache_unet._torch_module) + set_compiled_options(model.deep_cache_unet, graph_file) + return model diff --git a/onediff_comfy_nodes/modules/optimizer_strategy/optimizer_strategy.py b/onediff_comfy_nodes/modules/optimizer_strategy/optimizer_strategy.py index e060841a3..805fb1ba1 100644 --- a/onediff_comfy_nodes/modules/optimizer_strategy/optimizer_strategy.py +++ b/onediff_comfy_nodes/modules/optimizer_strategy/optimizer_strategy.py @@ -1,13 +1,24 @@ from abc import ABC, abstractmethod +from comfy import model_management from comfy.model_patcher import ModelPatcher +from onediff.infer_compiler.with_oneflow_compile import DeployableModule + class OptimizerStrategy(ABC): """Interface for optimization strategies.""" @abstractmethod - def apply(self, model: ModelPatcher): + def apply(self, model: ModelPatcher, 
ckpt_name=""): """Apply the optimization strategy to the model.""" pass + +def set_compiled_options(module: DeployableModule, graph_file="unet"): + assert isinstance(module, DeployableModule) + compile_options = { + "graph_file": graph_file, + "graph_file_device": model_management.get_torch_device(), + } + module._deployable_module_options.update(compile_options) diff --git a/onediff_comfy_nodes/modules/optimizer_strategy/quantization_optimizer.py b/onediff_comfy_nodes/modules/optimizer_strategy/quantization_optimizer.py index 155e64042..c5ec4daaa 100644 --- a/onediff_comfy_nodes/modules/optimizer_strategy/quantization_optimizer.py +++ b/onediff_comfy_nodes/modules/optimizer_strategy/quantization_optimizer.py @@ -1,21 +1,23 @@ +import os from dataclasses import dataclass from functools import singledispatchmethod from typing import Any, Dict import torch import torch.nn as nn +from comfy.controlnet import ControlNet from comfy.model_patcher import ModelPatcher from onediff_quant.quantization import QuantizationConfig from onediff_quant.quantization.module_operations import get_sub_module -from onediff_quant.quantization.quantize_calibrators import ( - QuantizationMetricsCalculator, -) +from onediff_quant.quantization.quantize_calibrators import \ + QuantizationMetricsCalculator from onediff_quant.quantization.quantize_config import Metric from onediff.infer_compiler import oneflow_compile from onediff.infer_compiler.with_oneflow_compile import DeployableModule -from .optimizer_strategy import OptimizerStrategy +from ...utils.graph_path import generate_graph_path +from .optimizer_strategy import OptimizerStrategy, set_compiled_options def get_torch_model(diff_model): @@ -47,15 +49,19 @@ def __init__( def calibrate(self, *args: Any, **kwargs: Any) -> Dict[str, Dict[str, float]]: if self.conv_percentage == 1.0 and self.linear_percentage == 1.0: # only_use_compute_density - costs_calibrate_info = self.compute_quantization_costs(args, kwargs, module_selector=self.module_selector) + costs_calibrate_info = self.compute_quantization_costs( + args, kwargs, module_selector=self.module_selector + ) costs_calibrate_info = self.apply_filter(costs_calibrate_info) - self.save_quantization_status(costs_calibrate_info, "quantization_stats.json") + self.save_quantization_status( + costs_calibrate_info, "quantization_stats.json" + ) return costs_calibrate_info - + calibrate_info = self.calibrate_all_layers( args, kwargs, module_selector=self.module_selector ) - + selected_model = self.module_selector(self.model) # Initialize max and min values, as well as lists for linear and convolutional layer data @@ -101,12 +107,12 @@ class OnelineQuantizationOptimizerExecutor(OptimizerStrategy): linear_compute_density_threshold: int = 300 @singledispatchmethod - def apply(self, model): + def apply(self, model, *args, **kwargs): print(f"{type(self).__name__}.apply() not implemented for {type(model)}") return model @apply.register(ModelPatcher) - def _(self, model: ModelPatcher): + def _(self, model: ModelPatcher, ckpt_name=""): quant_config = QuantizationConfig.from_settings( quantize_conv=True, quantize_linear=True, @@ -130,6 +136,38 @@ def _(self, model: ModelPatcher): diff_model = oneflow_compile(diff_model) diff_model.apply_online_quant(quant_config) model.model.diffusion_model = diff_model - return model + graph_file = generate_graph_path(ckpt_name, model.model) + quant_config.cache_dir = os.path.dirname(graph_file) + set_compiled_options(diff_model, graph_file) + quant_config = 
diff_model._deployable_module_quant_config + return model + @apply.register(ControlNet) + def _(self, model, ckpt_name=""): + quant_config = QuantizationConfig.from_settings( + quantize_conv=True, + quantize_linear=True, + bits=8, + conv_mae_threshold=0.9, + linear_mae_threshold=0.9, + plot_calibrate_info=True, + conv_compute_density_threshold=self.conv_compute_density_threshold, + linear_compute_density_threshold=self.linear_compute_density_threshold, + ) + control_model = model.control_model + quant_config.quantization_calculator = SubQuantizationPercentileCalculator( + control_model, + quant_config, + cache_key="ControlNet", + conv_percentage=self.conv_percentage / 100, + linear_percentage=self.linear_percentage / 100, + ) + graph_file = generate_graph_path(ckpt_name, control_model) + quant_config.cache_dir = os.path.dirname(graph_file) + if not isinstance(control_model, DeployableModule): + control_model = oneflow_compile(control_model) + control_model.apply_online_quant(quant_config) + set_compiled_options(control_model, graph_file) + model.control_model = control_model + return model From f347f5f503aac98741158a41781281fbe5113f2e Mon Sep 17 00:00:00 2001 From: ccsuu Date: Sun, 7 Apr 2024 15:05:00 +0000 Subject: [PATCH 08/11] refine --- .../oneflow/deployable_module.py | 27 +- .../infer_compiler/with_oneflow_compile.py | 524 ------------------ 2 files changed, 26 insertions(+), 525 deletions(-) delete mode 100644 src/onediff/infer_compiler/with_oneflow_compile.py diff --git a/src/onediff/infer_compiler/oneflow/deployable_module.py b/src/onediff/infer_compiler/oneflow/deployable_module.py index 73119060b..fd5ea75e7 100644 --- a/src/onediff/infer_compiler/oneflow/deployable_module.py +++ b/src/onediff/infer_compiler/oneflow/deployable_module.py @@ -8,6 +8,7 @@ from ..utils.log_utils import logger from ..utils.param_utils import parse_device, check_device from ..utils.graph_management_utils import graph_file_management +from ..utils.online_quantization_utils import quantize_and_deploy_wrapper from ..deployable_module import DeployableModule from .utils import handle_deployable_exception, get_mixed_dual_module, get_oneflow_graph @@ -28,6 +29,7 @@ def __init__( self._deployable_module_use_graph = use_graph self._deployable_module_enable_dynamic = dynamic self._deployable_module_options = options + self._deployable_module_quant_config = None self._deployable_module_dpl_graph = None self._is_raw_deployable_module = True self._load_graph_first_run = True @@ -44,6 +46,7 @@ def from_existing(cls, existing_module, use_graph=None, dynamic=None, options=No instance._deployable_module_input_count = ( existing_module._deployable_module_input_count ) + instance._deployable_module_quant_config = existing_module._deployable_module_quant_config return instance @@ -82,7 +85,8 @@ def apply_model(self, *args, **kwargs): *args, **kwargs ) return output - + + @quantize_and_deploy_wrapper @input_output_processor @handle_deployable_exception @graph_file_management @@ -171,3 +175,24 @@ def _clear_old_graph(self): def get_graph_file(self): return self._deployable_module_options.get("graph_file", None) + + def apply_online_quant(self, quant_config): + """ + Applies the provided quantization configuration for online use. + + Args: + quant_config (QuantizationConfig): The quantization configuration to apply. + + Example: + >>> from onediff_quant.quantization import QuantizationConfig + >>> quant_config = QuantizationConfig.from_settings( + ... quantize_conv=True, + ... quantize_linear=True, + ... 
conv_mae_threshold=0.005, + ... linear_mae_threshold=0.005, + ... conv_compute_density_threshold=300, + ... linear_compute_density_threshold=100, + ... cache_dir=args.cache_dir) + >>> model.apply_online_quant(quant_config) + """ + self._deployable_module_quant_config = quant_config \ No newline at end of file diff --git a/src/onediff/infer_compiler/with_oneflow_compile.py b/src/onediff/infer_compiler/with_oneflow_compile.py deleted file mode 100644 index 645b08e87..000000000 --- a/src/onediff/infer_compiler/with_oneflow_compile.py +++ /dev/null @@ -1,524 +0,0 @@ -import os -import types -import torch -import oneflow as flow -from oneflow.utils.tensor import to_torch -from typing import Any -from functools import wraps -from itertools import chain -from .transform.manager import transform_mgr -from .transform.custom_transform import set_default_registry -from .transform.builtin_transform import torch2oflow, reverse_proxy_class -from .utils.oneflow_exec_mode import oneflow_exec_mode, oneflow_exec_mode_enabled -from .utils.args_tree_util import input_output_processor -from .utils.log_utils import logger -from .utils.cost_util import cost_cnt -from .utils.param_utils import parse_device, check_device -from .utils.graph_management_utils import graph_file_management -from .utils.online_quantization_utils import quantize_and_deploy_wrapper - -class DualModule(torch.nn.Module): - def __init__(self, torch_module, oneflow_module): - torch.nn.Module.__init__(self) - object.__setattr__(self, "_torch_module", torch_module) - object.__setattr__(self, "_oneflow_module", oneflow_module) - object.__setattr__(self, "_modules", torch_module._modules) - object.__setattr__(self, "_parameters", torch_module._parameters) - object.__setattr__(self, "_buffers", torch_module._buffers) - - @property - def oneflow_module(self): - if self._oneflow_module is not None: - return self._oneflow_module - - logger.debug(f"Convert {type(self._torch_module)} ...") - self._oneflow_module = torch2oflow(self._torch_module) - logger.debug(f"Convert {type(self._torch_module)} done!") - return self._oneflow_module - - @oneflow_module.deleter - def oneflow_module(self): - if self._oneflow_module: - del self._oneflow_module - setattr(self, "_oneflow_module", None) - - def to(self, *args, **kwargs): - if oneflow_exec_mode_enabled(): - self._oneflow_module.to(*args, **kwargs) - else: - if self._oneflow_module is not None: - of_args = [torch2oflow(v) for v in args] - of_kwargs = {k: torch2oflow(v) for k, v in kwargs.items()} - self._oneflow_module.to(*of_args, **of_kwargs) - self._torch_module_to_with_check(*args, **kwargs) - else: - self._torch_module.to(*args, **kwargs) - - def _torch_module_to_with_check(self, *args, **kwargs): - def _align_tensor(torch_module, oneflow_module): - oneflow_tensor_list = set( - [x for x, _ in oneflow_module.named_parameters()] - + [x for x, _ in oneflow_module.named_buffers()] - ) - for name, tensor in chain.from_iterable( - [torch_module.named_parameters(), torch_module.named_buffers(),] - ): - if name not in oneflow_tensor_list: - tensor.data = tensor.to(*args, **kwargs) - else: - oneflow_tensor = oneflow_module.get_parameter(name) - if oneflow_tensor is None: - tensor.data = tensor.to(*args, **kwargs) - elif tensor.data_ptr() != oneflow_tensor.data_ptr(): - tensor.data = to_torch(oneflow_tensor.data) - - oneflow_module_list = set([x for x, _ in self._oneflow_module.named_modules()]) - for name, module in self._torch_module.named_modules(): - if name not in oneflow_module_list: - module.to(*args, 
**kwargs) - else: - _align_tensor(module, self._oneflow_module.get_submodule(name)) - - def __getattr__(self, name): - if name == "_torch_module" or name == "_oneflow_module": - return super().__getattribute__(name) - - torch_attr = getattr(self._torch_module, name) - oneflow_attr = ( - None - if self._oneflow_module is None - else getattr(self._oneflow_module, name) - ) - - if isinstance(torch_attr, torch.nn.ModuleList): - if oneflow_attr is None: - oneflow_attr = flow.nn.ModuleList([None] * len(torch_attr)) - return DualModuleList(torch_attr, oneflow_attr) - - elif isinstance(torch_attr, torch.nn.Module): - return get_mixed_dual_module(torch_attr.__class__)(torch_attr, oneflow_attr) - else: - return oneflow_attr if oneflow_exec_mode_enabled() else torch_attr - - def __setattr__(self, name: str, value: Any) -> None: - if name in ["_torch_module", "_oneflow_module"]: - super().__setattr__(name, value) - else: # TODO: aviod memory up when set attr - if self._oneflow_module is not None: - v = torch2oflow(value) - if isinstance(v, flow.Tensor): - obj = getattr(self._oneflow_module, name) - obj.copy_(v) - else: - setattr(self._oneflow_module, name, v) - setattr(self._torch_module, name, value) - - def extra_repr(self) -> str: - return self._torch_module.extra_repr() - - -class DualModuleList(torch.nn.ModuleList): - def __init__(self, torch_modules, oneflow_modules): - super().__init__() - assert len(torch_modules) == len(oneflow_modules) - self._torch_modules = torch_modules - self._oneflow_modules = oneflow_modules - dual_modules = [] - for torch_module, oneflow_module in zip( - self._torch_modules, self._oneflow_modules - ): - dual_modules.append( - get_mixed_dual_module(torch_module.__class__)( - torch_module, oneflow_module - ) - ) - # clear self._modules since `self._torch_modules = torch_modules` will append a module to self._modules - self._modules.clear() - self += dual_modules - - def __setitem__(self, idx: int, module: DualModule): - idx = self._get_abs_string_index(idx) - setattr(self._torch_modules, str(idx), module._torch_module) - setattr(self._oneflow_modules, str(idx), module._oneflow_module) - return setattr(self, str(idx), module) - - def __setattr__(self, key, value): - if key in ("_torch_modules", "_oneflow_modules"): - return object.__setattr__(self, key, value) - if isinstance(value, DualModule): - setattr(self._torch_modules, key, value._torch_module) - setattr(self._oneflow_modules, key, value._oneflow_module) - else: - setattr(self._torch_modules, key, value) - value = torch2oflow(value) - setattr(self._oneflow_modules, key, value) - return object.__setattr__(self, key, value) - - -def get_mixed_dual_module(module_cls): - if issubclass(module_cls, DualModule) and "MixedDualModule" in module_cls.__name__: - return module_cls - - class MixedDualModule(DualModule, module_cls): - def __init__(self, torch_module, oneflow_module): - while isinstance(torch_module, DualModule): - torch_module = torch_module._torch_module - DualModule.__init__(self, torch_module, oneflow_module) - - def _get_name(self) -> str: - return f"{self.__class__.__name__}(of {module_cls.__name__})" - - return MixedDualModule - - -@torch2oflow.register -def _(mod: DualModule, verbose=False): - return torch2oflow(mod._torch_module, verbose) - - -def handle_deployable_exception(func): - @wraps(func) - def wrapper(self, *args, **kwargs): - if transform_mgr.debug_mode: - return func(self, *args, **kwargs) - else: - try: - return func(self, *args, **kwargs) - except Exception as e: - logger.error(f"Exception 
in {func.__name__}: {e=}") - logger.warning("Recompile oneflow module ...") - del self._deployable_module_model.oneflow_module - self._deployable_module_dpl_graph = None - return func(self, *args, **kwargs) - - return wrapper - - -class DeployableModule(torch.nn.Module): - def __init__( - self, torch_module, oneflow_module, use_graph=True, dynamic=True, options={}, - ): - torch.nn.Module.__init__(self) - object.__setattr__( - self, - "_deployable_module_model", - get_mixed_dual_module(torch_module.__class__)(torch_module, oneflow_module), - ) - object.__setattr__(self, "_modules", torch_module._modules) - object.__setattr__(self, "_torch_module", torch_module) - self._deployable_module_use_graph = use_graph - self._deployable_module_enable_dynamic = dynamic - self._deployable_module_options = options - self._deployable_module_dpl_graph = None - self._deployable_module_quant_config = None - self._is_raw_deployable_module = True - self._load_graph_first_run = True - - @classmethod - def from_existing(cls, existing_module, use_graph=None, dynamic=None, options=None): - torch_module = existing_module._deployable_module_model._torch_module - oneflow_module = existing_module._deployable_module_model._oneflow_module - instance = cls(torch_module, oneflow_module, use_graph, dynamic, options) - instance._deployable_module_dpl_graph = ( - existing_module._deployable_module_dpl_graph if use_graph else None - ) - instance._load_graph_first_run = existing_module._load_graph_first_run - instance._deployable_module_input_count = ( - existing_module._deployable_module_input_count - ) - - instance._deployable_module_quant_config = existing_module._deployable_module_quant_config - return instance - - def get_graph(self): - if self._deployable_module_dpl_graph is not None: - return self._deployable_module_dpl_graph - if "size" in self._deployable_module_options: - size = self._deployable_module_options["size"] - else: - size = 9 - self._deployable_module_dpl_graph = get_oneflow_graph( - self._deployable_module_model.oneflow_module, - size, - self._deployable_module_enable_dynamic, - ) - # Enabel debug mode - if transform_mgr.debug_mode: - self._deployable_module_dpl_graph.debug(0) - if "debug" in self._deployable_module_options: - self._deployable_module_dpl_graph.debug( - self._deployable_module_options["debug"] - ) - return self._deployable_module_dpl_graph - - - @input_output_processor - @handle_deployable_exception - @graph_file_management - def apply_model(self, *args, **kwargs): - if self._deployable_module_use_graph: - dpl_graph = self.get_graph() - with oneflow_exec_mode(): - output = dpl_graph(*args, **kwargs) - else: - with oneflow_exec_mode(): - output = self._deployable_module_model.oneflow_module.apply_model( - *args, **kwargs - ) - return output - - @quantize_and_deploy_wrapper - @input_output_processor - @handle_deployable_exception - @graph_file_management - def __call__(self, *args, **kwargs): - if self._deployable_module_use_graph: - dpl_graph = self.get_graph() - with oneflow_exec_mode(): - output = dpl_graph(*args, **kwargs) - else: - with oneflow_exec_mode(): - output = self._deployable_module_model.oneflow_module(*args, **kwargs) - return output - - def to(self, *args, **kwargs): - if self._deployable_module_dpl_graph is None: - self._deployable_module_model.to(*args, **kwargs) - return self - - # assert the target device is same as graph device - target_device = parse_device(args, kwargs) - if ( - target_device is not None - and len(self._deployable_module_dpl_graph._blocks) > 0 - ): - 
current_device = next(self._deployable_module_dpl_graph._state()).device - if not check_device(current_device, target_device): - raise RuntimeError( - f"After graph built, the device of graph can't be modified, current device: {current_device}, target device: {target_device}" - ) - self._deployable_module_model.to(*args, **kwargs) - return self - - # TODO(): Just for transformers VAE decoder - @input_output_processor - @handle_deployable_exception - @graph_file_management - def decode(self, *args, **kwargs): - if self._deployable_module_use_graph: - - def _build(graph, *args, **kwargs): - return graph.model.decode(*args, **kwargs) - - dpl_graph = self.get_graph() - dpl_graph.build = types.MethodType(_build, dpl_graph) - with oneflow_exec_mode(): - output = dpl_graph(*args, **kwargs) - else: - with oneflow_exec_mode(): - output = self._deployable_module_model.oneflow_module.decode( - *args, **kwargs - ) - return output - - def __getattr__(self, name): - return getattr(self._deployable_module_model, name) - - def load_graph(self, file_path, device=None, run_warmup=True): - self.get_graph().load_graph(file_path, device, run_warmup) - - def save_graph(self, file_path): - self.get_graph().save_graph(file_path) - - def extra_repr(self) -> str: - return self._deployable_module_model.extra_repr() - - def set_graph_file(self, file_path: str) -> None: - """ Sets the path of the graph file. - - If the new file path is different from the old one, clears old graph data. - - Args: - `file_path` (str): The path of the graph file. - """ - old_file_path = self.get_graph_file() - if file_path and old_file_path == file_path: - return - compiled_options = self._deployable_module_options - compiled_options["graph_file"] = file_path - - self._clear_old_graph() - - def _clear_old_graph(self): - self._load_graph_first_run = True - self._deployable_module_dpl_graph = None - del self._deployable_module_model.oneflow_module - - def get_graph_file(self): - return self._deployable_module_options.get("graph_file", None) - - - def apply_online_quant(self, quant_config): - """ - Applies the provided quantization configuration for online use. - - Args: - quant_config (QuantizationConfig): The quantization configuration to apply. - - Example: - >>> from onediff_quant.quantization import QuantizationConfig - >>> quant_config = QuantizationConfig.from_settings( - ... quantize_conv=True, - ... quantize_linear=True, - ... conv_mae_threshold=0.005, - ... linear_mae_threshold=0.005, - ... conv_compute_density_threshold=300, - ... linear_compute_density_threshold=100, - ... 
cache_dir=args.cache_dir) - >>> model.apply_online_quant(quant_config) - """ - self._deployable_module_quant_config = quant_config - - - -class OneflowGraph(flow.nn.Graph): - @flow.nn.Graph.with_dynamic_input_shape() - def __init__(self, model): - super().__init__(enable_get_runtime_state_dict=True) - self.model = model - logger.info(f"Building a graph for {model.__class__.__name__} ...") - # self.config.enable_cudnn_conv_heuristic_search_algo(False) - self.config.allow_fuse_add_to_output(True) - - def build(self, *args, **kwargs): - return self.model(*args, **kwargs) - - @cost_cnt(transform_mgr.debug_mode) - def load_graph(self, file_path, device=None, run_warmup=True): - state_dict = flow.load(file_path) - self.graph_state_dict = state_dict # used for OneflowGraph.save_graph - if device is not None: - state_dict = flow.nn.Graph.runtime_state_dict_to(state_dict, device) - self.load_runtime_state_dict(state_dict, warmup_with_run=run_warmup) - - @cost_cnt(transform_mgr.debug_mode) - def save_graph(self, file_path): - if hasattr(self, "graph_state_dict"): - flow.save(self.graph_state_dict, file_path) - return - - state_dict = self.runtime_state_dict() - - import oneflow.framework.args_tree as args_tree - - def disabled_dataclass(value): - return False - - original_is_dataclass = args_tree._is_dataclass - args_tree._is_dataclass = disabled_dataclass - - import dataclasses - - def reverse_dataclass(value): - if dataclasses.is_dataclass(value): - return reverse_proxy_class(type(value))(**value) - else: - return value - - for name, rsd in state_dict.items(): - output = state_dict[name]["outputs_original"] - out_tree = args_tree.ArgsTree((output, None), False) - # dataclass type needs to be reversed to torch type to avoid saving error. - out = out_tree.map_leaf(reverse_dataclass) - state_dict[name]["outputs_original"] = out[0] - - args_tree._is_dataclass = original_is_dataclass - - flow.save(state_dict, file_path) - - -def get_oneflow_graph(model, size=9, dynamic_graph=True): - g = OneflowGraph(model) - g._dynamic_input_graph_cache.set_cache_size(size) - g._dynamic_input_graph_cache.enable_shared(dynamic_graph) - return g - - -def state_dict_hook(module, state_dict, prefix, local_metadata): - pytorch_key_prefix = "_deployable_module_model._torch_module." - new_state_dict = type(state_dict)() - for k, v in state_dict.items(): - # _deployable_module_model._torch_module.out.2.weight => out.2.weight - if k.startswith(pytorch_key_prefix): - new_k = k[len(pytorch_key_prefix) :] - new_state_dict[new_k] = v - else: - new_state_dict[k] = v - return new_state_dict - - -# Return a DeployableModule that using module_cls as it's parent class. 
-def get_mixed_deployable_module(module_cls): - class MixedDeployableModule(DeployableModule, module_cls): - def __init__( - self, torch_module, oneflow_module, use_graph=True, dynamic=True, options={} - ): - DeployableModule.__init__( - self, torch_module, oneflow_module, use_graph, dynamic, options - ) - self._is_raw_deployable_module = False - - @classmethod - def from_existing( - cls, existing_module, use_graph=None, dynamic=True, options=None - ): - torch_module = existing_module._deployable_module_model._torch_module - oneflow_module = existing_module._deployable_module_model._oneflow_module - instance = cls(torch_module, oneflow_module, use_graph, dynamic, options) - instance._deployable_module_dpl_graph = ( - existing_module._deployable_module_dpl_graph if use_graph else None - ) - return instance - - def _get_name(self): - return f"{self.__class__.__name__}(of {module_cls.__name__})" - - return MixedDeployableModule - - -def oneflow_compile( - torch_module: torch.nn.Module, *, use_graph=True, dynamic=True, options={}, -) -> DeployableModule: - """ - Transform a torch nn.Module to oneflow.nn.Module, then optimize it with oneflow.nn.Graph. - Args: - model (torch.nn.Module): Module to optimize - use_graph (bool): Whether to optimize with oneflow.nn.Graph - dynamic (bool): When this is True, we will generate one graph and reuse it to avoid recompilations when - input shape change. This may not always work as some operations/optimizations break the contition of - reusing. When this is False, we will generate a graph for each new input shape, and will always specialize. - By default (True). - options (dict): A dictionary of options to pass to the compiler: - - 'debug' which config the nn.Graph debug level, default -1(no debug info), max 3(max debug info); - - 'size' which config the cache size when cache is enabled. Note that after onediff v0.12, cache is default disabled. - - 'graph_file' (None) generates a compilation cache file. If the file exists, loading occurs; if not, the compilation result is saved after the first run. - - 'graph_file_device' (None) sets the device for the graph file, default None. If set, the compilation result will be converted to the specified device. 
- """ - - set_default_registry() - - def wrap_module(module): - if isinstance(module, DeployableModule): - assert not module._is_raw_deployable_module - return module.__class__.from_existing(module, use_graph, dynamic, options) - else: - return get_mixed_deployable_module(module.__class__)( - module, None, use_graph, dynamic, options - ) - - model = wrap_module(torch_module) - assert isinstance(model, DeployableModule) - assert isinstance(model, torch_module.__class__) - model._register_state_dict_hook(state_dict_hook) - - return model \ No newline at end of file From b0db1bb2aa2bddafc0669b8b38a52d2e14a64d08 Mon Sep 17 00:00:00 2001 From: ccsuu Date: Mon, 8 Apr 2024 04:29:38 +0000 Subject: [PATCH 09/11] Delete redundant changes --- onediff_comfy_nodes/__init__.py | 8 +- onediff_comfy_nodes/_nodes.py | 200 +----------------- .../modules/optimizer_strategy/__init__.py | 3 - .../optimizer_strategy/deepcache_optimizer.py | 45 ---- .../optimizer_strategy/optimizer_strategy.py | 24 --- .../quantization_optimizer.py | 173 --------------- .../utils/loader_sample_tools.py | 5 +- .../examples/text_to_image_online_quant.py | 39 ++-- 8 files changed, 34 insertions(+), 463 deletions(-) delete mode 100644 onediff_comfy_nodes/modules/optimizer_strategy/__init__.py delete mode 100644 onediff_comfy_nodes/modules/optimizer_strategy/deepcache_optimizer.py delete mode 100644 onediff_comfy_nodes/modules/optimizer_strategy/optimizer_strategy.py delete mode 100644 onediff_comfy_nodes/modules/optimizer_strategy/quantization_optimizer.py diff --git a/onediff_comfy_nodes/__init__.py b/onediff_comfy_nodes/__init__.py index cdc8d74d5..f0d112d5f 100644 --- a/onediff_comfy_nodes/__init__.py +++ b/onediff_comfy_nodes/__init__.py @@ -15,7 +15,7 @@ BatchSizePatcher, ) from ._compare_node import CompareModel, ShowImageDiff -from ._nodes import OneDiffModelOptimizer, OneDiffDeepcacheOptimizer, OneDiffOnlineQuantizationOptimizer + NODE_CLASS_MAPPINGS = { "ModelSpeedup": ModelSpeedup, @@ -32,9 +32,6 @@ "OneDiffControlNetLoader": OneDiffControlNetLoader, "OneDiffDeepCacheCheckpointLoaderSimple": OneDiffDeepCacheCheckpointLoaderSimple, "BatchSizePatcher": BatchSizePatcher, - "OneDiffModelOptimizer": OneDiffModelOptimizer, - "OneDiffDeepcacheOptimizer": OneDiffDeepcacheOptimizer, - "OneDiffOnlineQuantizationOptimizer": OneDiffOnlineQuantizationOptimizer, } NODE_DISPLAY_NAME_MAPPINGS = { @@ -52,9 +49,6 @@ "OneDiffControlNetLoader": "Load ControlNet Model - OneDiff", "OneDiffDeepCacheCheckpointLoaderSimple": "Load Checkpoint - OneDiff DeepCache", "BatchSizePatcher": "Batch Size Patcher", - "OneDiffModelOptimizer": "Model Optimizer - OneDiff", - "OneDiffDeepcacheOptimizer": "DeepCache Optimizer - OneDiff", - "OneDiffOnlineQuantizationOptimizer": "Online Quantization Optimizer - OneDiff", } diff --git a/onediff_comfy_nodes/_nodes.py b/onediff_comfy_nodes/_nodes.py index a1d774ecb..6a11e8340 100644 --- a/onediff_comfy_nodes/_nodes.py +++ b/onediff_comfy_nodes/_nodes.py @@ -375,21 +375,13 @@ def deep_cache_convert( from nodes import CheckpointLoaderSimple, ControlNetLoader from comfy.controlnet import ControlLora, ControlNet from .modules.onediff_controlnet import OneDiffControlLora -from .modules.optimizer_strategy import OptimizerStrategy, DeepcacheOptimizerExecutor, OnelineQuantizationOptimizerExecutor -class OneDiffControlNetLoader(ControlNetLoader): - @classmethod - def INPUT_TYPES(s): - ret = super().INPUT_TYPES() - ret.update({"optional": { - "model_optimizer": ("MODEL_OPTIMIZER",),} - }) - return ret +class 
From b0db1bb2aa2bddafc0669b8b38a52d2e14a64d08 Mon Sep 17 00:00:00 2001
From: ccsuu
Date: Mon, 8 Apr 2024 04:29:38 +0000
Subject: [PATCH 09/11] Delete redundant changes

---
 onediff_comfy_nodes/__init__.py               |   8 +-
 onediff_comfy_nodes/_nodes.py                 | 200 +-----------------
 .../modules/optimizer_strategy/__init__.py    |   3 -
 .../optimizer_strategy/deepcache_optimizer.py |  45 ----
 .../optimizer_strategy/optimizer_strategy.py  |  24 ---
 .../quantization_optimizer.py                 | 173 ---------------
 .../utils/loader_sample_tools.py              |   5 +-
 .../examples/text_to_image_online_quant.py    |  39 ++--
 8 files changed, 34 insertions(+), 463 deletions(-)
 delete mode 100644 onediff_comfy_nodes/modules/optimizer_strategy/__init__.py
 delete mode 100644 onediff_comfy_nodes/modules/optimizer_strategy/deepcache_optimizer.py
 delete mode 100644 onediff_comfy_nodes/modules/optimizer_strategy/optimizer_strategy.py
 delete mode 100644 onediff_comfy_nodes/modules/optimizer_strategy/quantization_optimizer.py

diff --git a/onediff_comfy_nodes/__init__.py b/onediff_comfy_nodes/__init__.py
index cdc8d74d5..f0d112d5f 100644
--- a/onediff_comfy_nodes/__init__.py
+++ b/onediff_comfy_nodes/__init__.py
@@ -15,7 +15,7 @@
     BatchSizePatcher,
 )
 from ._compare_node import CompareModel, ShowImageDiff
-from ._nodes import OneDiffModelOptimizer, OneDiffDeepcacheOptimizer, OneDiffOnlineQuantizationOptimizer
+

 NODE_CLASS_MAPPINGS = {
     "ModelSpeedup": ModelSpeedup,
@@ -32,9 +32,6 @@
     "OneDiffControlNetLoader": OneDiffControlNetLoader,
     "OneDiffDeepCacheCheckpointLoaderSimple": OneDiffDeepCacheCheckpointLoaderSimple,
     "BatchSizePatcher": BatchSizePatcher,
-    "OneDiffModelOptimizer": OneDiffModelOptimizer,
-    "OneDiffDeepcacheOptimizer": OneDiffDeepcacheOptimizer,
-    "OneDiffOnlineQuantizationOptimizer": OneDiffOnlineQuantizationOptimizer,
 }

 NODE_DISPLAY_NAME_MAPPINGS = {
@@ -52,9 +49,6 @@
     "OneDiffControlNetLoader": "Load ControlNet Model - OneDiff",
     "OneDiffDeepCacheCheckpointLoaderSimple": "Load Checkpoint - OneDiff DeepCache",
     "BatchSizePatcher": "Batch Size Patcher",
-    "OneDiffModelOptimizer": "Model Optimizer - OneDiff",
-    "OneDiffDeepcacheOptimizer": "DeepCache Optimizer - OneDiff",
-    "OneDiffOnlineQuantizationOptimizer": "Online Quantization Optimizer - OneDiff",
 }


diff --git a/onediff_comfy_nodes/_nodes.py b/onediff_comfy_nodes/_nodes.py
index a1d774ecb..6a11e8340 100644
--- a/onediff_comfy_nodes/_nodes.py
+++ b/onediff_comfy_nodes/_nodes.py
@@ -375,21 +375,13 @@ def deep_cache_convert(
 from nodes import CheckpointLoaderSimple, ControlNetLoader
 from comfy.controlnet import ControlLora, ControlNet
 from .modules.onediff_controlnet import OneDiffControlLora
-from .modules.optimizer_strategy import OptimizerStrategy, DeepcacheOptimizerExecutor, OnelineQuantizationOptimizerExecutor


-class OneDiffControlNetLoader(ControlNetLoader):
-    @classmethod
-    def INPUT_TYPES(s):
-        ret = super().INPUT_TYPES()
-        ret.update({"optional": {
-            "model_optimizer": ("MODEL_OPTIMIZER",),}
-        })
-        return ret
+class OneDiffControlNetLoader(ControlNetLoader):
     CATEGORY = "OneDiff/Loaders"
     FUNCTION = "onediff_load_controlnet"

-    def onediff_load_controlnet(self, control_net_name, model_optimizer=None):
+    def onediff_load_controlnet(self, control_net_name):
         controlnet = super().load_controlnet(control_net_name)[0]
         load_device = model_management.get_torch_device()

@@ -406,15 +398,12 @@ def gen_compile_options(model):
             )
             return (controlnet,)
         elif isinstance(controlnet, ControlNet):
-            if model_optimizer:
-                controlnet = model_optimizer(controlnet, ckpt_name=control_net_name)
-            else:
-                control_model = controlnet.control_model
-                compile_options = gen_compile_options(control_model)
-                control_model = control_model.to(load_device)
-                controlnet.control_model = oneflow_compile(
-                    control_model, options=compile_options
-                )
+            control_model = controlnet.control_model
+            compile_options = gen_compile_options(control_model)
+            control_model = control_model.to(load_device)
+            controlnet.control_model = oneflow_compile(
+                control_model, options=compile_options
+            )
             return (controlnet,)
         else:
             print(
@@ -430,9 +419,6 @@ def INPUT_TYPES(s):
             "required": {
                 "ckpt_name": (folder_paths.get_filename_list("checkpoints"),),
                 "vae_speedup": (["disable", "enable"],),
-            },
-            "optional": {
-                "model_optimizer": ("MODEL_OPTIMIZER",),
             }
         }

@@ -440,22 +426,17 @@
     FUNCTION = "onediff_load_checkpoint"

     def onediff_load_checkpoint(
-        self, ckpt_name, vae_speedup="disable", output_vae=True, output_clip=True, model_optimizer: callable=None,
+        self, ckpt_name, vae_speedup, output_vae=True, output_clip=True
     ):
         # CheckpointLoaderSimple.load_checkpoint
         modelpatcher, clip, vae = self.load_checkpoint(
             ckpt_name, output_vae, output_clip
         )
-
         unet_graph_file = generate_graph_path(ckpt_name, modelpatcher.model)
-        if model_optimizer is not None:
-            modelpatcher = model_optimizer(modelpatcher, ckpt_name=ckpt_name)
-
         modelpatcher.model.diffusion_model = compoile_unet(
-            modelpatcher.model.diffusion_model, unet_graph_file)
-
+            modelpatcher.model.diffusion_model, unet_graph_file
+        )
         modelpatcher.model._register_state_dict_hook(state_dict_hook)
-
         if vae_speedup == "enable":
             file_path = generate_graph_path(ckpt_name, vae.first_stage_model)
             vae.first_stage_model = oneflow_compile(
@@ -472,165 +453,6 @@ def onediff_load_checkpoint(
         return modelpatcher, clip, vae


-class OneDiffDeepcacheOptimizer:
-    @classmethod
-    def INPUT_TYPES(s):
-        return {
-            "required": {
-                "cache_interval": (
-                    "INT",
-                    {
-                        "default": 3,
-                        "min": 1,
-                        "max": 1000,
-                        "step": 1,
-                        "display": "number",
-                    },
-                ),
-                "cache_layer_id": (
-                    "INT",
-                    {"default": 0, "min": 0, "max": 12, "step": 1, "display": "number"},
-                ),
-                "cache_block_id": (
-                    "INT",
-                    {"default": 1, "min": 0, "max": 12, "step": 1, "display": "number"},
-                ),
-                "start_step": (
-                    "INT",
-                    {
-                        "default": 0,
-                        "min": 0,
-                        "max": 1000,
-                        "step": 1,
-                        "display": "number",
-                    },
-                ),
-                "end_step": (
-                    "INT",
-                    {"default": 1000, "min": 0, "max": 1000, "step": 0.1},
-                ),
-            },
-        }
-
-    CATEGORY = "OneDiff/Optimizer"
-    RETURN_TYPES = ("DeepPCacheOptimizer",)
-    FUNCTION = "apply"
-
-    def apply(
-        self,
-        cache_interval=3,
-        cache_layer_id=0,
-        cache_block_id=1,
-        start_step=0,
-        end_step=1000,
-    ):
-        return (
-            DeepcacheOptimizerExecutor(
-                cache_interval=cache_interval,
-                cache_layer_id=cache_layer_id,
-                cache_block_id=cache_block_id,
-                start_step=start_step,
-                end_step=end_step,
-            ),
-        )
-
-
-class OneDiffOnlineQuantizationOptimizer:
-    @classmethod
-    def INPUT_TYPES(s):
-        return {
-            "required": {
-                "quantized_conv_percentage": (
-                    "INT",
-                    {
-                        "default": 70,
-                        "min": 0,  # Minimum value
-                        "max": 100,  # Maximum value
-                        "step": 1,  # Slider's step
-                        "display": "slider",  # Cosmetic only: display as "number" or "slider"
-                    },
-                ),
-                "quantized_linear_percentage": (
-                    "INT",
-                    {
-                        "default": 80,
-                        "min": 0,  # Minimum value
-                        "max": 100,  # Maximum value
-                        "step": 1,  # Slider's step
-                        "display": "slider",  # Cosmetic only: display as "number" or "slider"
-                    },
-                ),
-                "conv_compute_density_threshold": (
-                    "INT",
-                    {
-                        "default": 100,
-                        "min": 0,  # Minimum value
-                        "max": 2000,  # Maximum value
-                        "step": 10,  # Slider's step
-                        "display": "number",  # Cosmetic only: display as "number" or "slider"
-                    },
-                ),
-                "linear_compute_density_threshold": (
-                    "INT",
-                    {
-                        "default": 300,
-                        "min": 0,  # Minimum value
-                        "max": 2000,  # Maximum value
-                        "step": 10,  # Slider's step
-                        "display": "number",  # Cosmetic only: display as "number" or "slider"
-                    },
-                )
-            },
-        }
-
-    CATEGORY = "OneDiff/Optimizer"
-    RETURN_TYPES = ("QuantizationOptimizer",)
-    FUNCTION = "apply"
-
-    def apply(self, quantized_conv_percentage=0, quantized_linear_percentage=0, conv_compute_density_threshold=0, linear_compute_density_threshold=0):
-        return (
-            OnelineQuantizationOptimizerExecutor(
-                conv_percentage=quantized_conv_percentage,
-                linear_percentage=quantized_linear_percentage,
-                conv_compute_density_threshold = conv_compute_density_threshold,
-                linear_compute_density_threshold = linear_compute_density_threshold,
-            ),
-        )
-
-
-class OneDiffModelOptimizer:
-    """Main class responsible for optimizing models."""
-
-    @classmethod
-    def INPUT_TYPES(s):
-        return {
-            "required": {},
-            "optional": {
-                "quantization_optimizer": ("QuantizationOptimizer",),
-                "deeppcache_optimizer": ("DeepPCacheOptimizer",),
-            },
-        }
-
-    CATEGORY = "OneDiff/Optimization"
-    RETURN_TYPES = ("MODEL_OPTIMIZER",)
-    FUNCTION = "optimize_model"
-
-    def optimize_model(self, quantization_optimizer:OptimizerStrategy =None, deeppcache_optimizer:OptimizerStrategy=None):
-        """Apply the optimization technique to the model."""
-
-        def apply_optimizer(model,*,ckpt_name=""):
-
-            if deeppcache_optimizer is not None:
-                model = deeppcache_optimizer.apply(model, ckpt_name)
-
-            if quantization_optimizer is not None:
-                model = quantization_optimizer.apply(model, ckpt_name)
-
-            return model
-
-        return (apply_optimizer,)
-
-
 class OneDiffDeepCacheCheckpointLoaderSimple(CheckpointLoaderSimple):
     @classmethod
     def INPUT_TYPES(s):
diff --git a/onediff_comfy_nodes/modules/optimizer_strategy/__init__.py b/onediff_comfy_nodes/modules/optimizer_strategy/__init__.py
deleted file mode 100644
index 7865bc81b..000000000
--- a/onediff_comfy_nodes/modules/optimizer_strategy/__init__.py
+++ /dev/null
@@ -1,3 +0,0 @@
-from .deepcache_optimizer import DeepcacheOptimizerExecutor
-from .optimizer_strategy import OptimizerStrategy
-from .quantization_optimizer import OnelineQuantizationOptimizerExecutor
diff --git a/onediff_comfy_nodes/modules/optimizer_strategy/deepcache_optimizer.py b/onediff_comfy_nodes/modules/optimizer_strategy/deepcache_optimizer.py
deleted file mode 100644
index 9fcc8ef7a..000000000
--- a/onediff_comfy_nodes/modules/optimizer_strategy/deepcache_optimizer.py
+++ /dev/null
@@ -1,45 +0,0 @@
-from dataclasses import dataclass
-from functools import singledispatchmethod
-
-from comfy.model_patcher import ModelPatcher
-
-from ...utils.deep_cache_speedup import deep_cache_speedup
-from ...utils.graph_path import generate_graph_path
-from .optimizer_strategy import OptimizerStrategy, set_compiled_options
-
-
-@dataclass
-class DeepcacheOptimizerExecutor(OptimizerStrategy):
-    cache_interval: int = 3
-    cache_layer_id: int = 0
-    cache_block_id: int = 1
-    start_step: int = 0
-    end_step: int = 1000
-
-    @singledispatchmethod
-    def apply(self, model):
-        print(
-            "DeepcacheOptimizerExecutor.apply: not implemented for model type:",
-            type(model),
-        )
-        return model
-
-    @apply.register(ModelPatcher)
-    def _(self, model, ckpt_name=""):
-        model = deep_cache_speedup(
-            model=model,
-            use_graph=True,
-            cache_interval=self.cache_interval,
-            cache_layer_id=self.cache_layer_id,
-            cache_block_id=self.cache_block_id,
-            start_step=self.start_step,
-            end_step=self.end_step,
-            use_oneflow_deepcache_speedup_modelpatcher=False,
-        )[0]
-        graph_file = generate_graph_path(
-            ckpt_name, model.fast_deep_cache_unet._torch_module
-        )
-        set_compiled_options(model.fast_deep_cache_unet, graph_file)
-        graph_file = generate_graph_path(ckpt_name, model.deep_cache_unet._torch_module)
-        set_compiled_options(model.deep_cache_unet, graph_file)
-        return model
diff --git a/onediff_comfy_nodes/modules/optimizer_strategy/optimizer_strategy.py b/onediff_comfy_nodes/modules/optimizer_strategy/optimizer_strategy.py
deleted file mode 100644
index 805fb1ba1..000000000
--- a/onediff_comfy_nodes/modules/optimizer_strategy/optimizer_strategy.py
+++ /dev/null
@@ -1,24 +0,0 @@
-from abc import ABC, abstractmethod
-
-from comfy import model_management
-from comfy.model_patcher import ModelPatcher
-
-from onediff.infer_compiler.with_oneflow_compile import DeployableModule
-
-
-class OptimizerStrategy(ABC):
-    """Interface for optimization strategies."""
-
-    @abstractmethod
-    def apply(self, model: ModelPatcher, ckpt_name=""):
-        """Apply the optimization strategy to the model."""
-        pass
-
-
-def set_compiled_options(module: DeployableModule, graph_file="unet"):
-    assert isinstance(module, DeployableModule)
-    compile_options = {
-        "graph_file": graph_file,
-        "graph_file_device": model_management.get_torch_device(),
-    }
-    module._deployable_module_options.update(compile_options)
diff --git a/onediff_comfy_nodes/modules/optimizer_strategy/quantization_optimizer.py b/onediff_comfy_nodes/modules/optimizer_strategy/quantization_optimizer.py
deleted file mode 100644
index c5ec4daaa..000000000
--- a/onediff_comfy_nodes/modules/optimizer_strategy/quantization_optimizer.py
+++ /dev/null
@@ -1,173 +0,0 @@
-import os
-from dataclasses import dataclass
-from functools import singledispatchmethod
-from typing import Any, Dict
-
-import torch
-import torch.nn as nn
-from comfy.controlnet import ControlNet
-from comfy.model_patcher import ModelPatcher
-from onediff_quant.quantization import QuantizationConfig
-from onediff_quant.quantization.module_operations import get_sub_module
-from onediff_quant.quantization.quantize_calibrators import \
-    QuantizationMetricsCalculator
-from onediff_quant.quantization.quantize_config import Metric
-
-from onediff.infer_compiler import oneflow_compile
-from onediff.infer_compiler.with_oneflow_compile import DeployableModule
-
-from ...utils.graph_path import generate_graph_path
-from .optimizer_strategy import OptimizerStrategy, set_compiled_options
-
-
-def get_torch_model(diff_model):
-    if isinstance(diff_model, DeployableModule):
-        return diff_model._torch_module
-    else:
-        return diff_model
-
-
-class SubQuantizationPercentileCalculator(QuantizationMetricsCalculator):
-    def __init__(
-        self,
-        model: torch.nn.Module,
-        config: QuantizationConfig,
-        cache_key: str = "",
-        module_selector: callable = lambda x: x,
-        *,
-        seed=1,
-        select=Metric.MAE.value,
-        conv_percentage=0.9,
-        linear_percentage=0.9,
-    ):
-        super().__init__(model, config, cache_key, module_selector, seed=seed)
-        self.select = select
-        self.conv_percentage = conv_percentage
-        self.linear_percentage = linear_percentage
-
-    @torch.no_grad()
-    def calibrate(self, *args: Any, **kwargs: Any) -> Dict[str, Dict[str, float]]:
-        if self.conv_percentage == 1.0 and self.linear_percentage == 1.0:
-            # only_use_compute_density
-            costs_calibrate_info = self.compute_quantization_costs(
-                args, kwargs, module_selector=self.module_selector
-            )
-            costs_calibrate_info = self.apply_filter(costs_calibrate_info)
-            self.save_quantization_status(
-                costs_calibrate_info, "quantization_stats.json"
-            )
-            return costs_calibrate_info
-
-        calibrate_info = self.calibrate_all_layers(
-            args, kwargs, module_selector=self.module_selector
-        )
-
-        selected_model = self.module_selector(self.model)
-
-        # Initialize max and min values, as well as lists for linear and convolutional layer data
-        max_value, min_value = 2, -1
-        linear_values, conv_values = [max_value, min_value], [max_value, min_value]
-        # Iterate through quantization information for each layer, extracting quantization values for linear and convolutional layers
-        for module_name, value_info in calibrate_info.items():
-            module = get_sub_module(selected_model, module_name)
-            values_list = (
-                linear_values if isinstance(module, nn.Linear) else conv_values
-            )
-            values_list.append(float(value_info[self.select]))
-
-        # Sort quantization values for linear and convolutional layers based on the selected evaluation metric
-        linear_values.sort()
-        conv_values.sort()
-
-        # Calculate linear and convolutional thresholds
-        conv_threshold = conv_values[int((len(conv_values) - 1) * self.conv_percentage)]
-        linear_threshold = linear_values[
-            int((len(linear_values) - 1) * self.linear_percentage)
-        ]
-        # print(f"Conv threshold: {conv_threshold}, Linear threshold: {linear_threshold}")
-
-        if nn.Conv2d in self.config.module_settings:
-            self.config.module_settings[nn.Conv2d][self.select] = conv_threshold
-
-        if nn.Linear in self.config.module_settings:
-            self.config.module_settings[nn.Linear][self.select] = linear_threshold
-
-        # Apply filters and save quantization status information
-        calibrate_info = self.apply_filter(calibrate_info)
-        self.save_quantization_status(calibrate_info, "quantization_stats.json")
-
-        return calibrate_info
-
-
-@dataclass
-class OnelineQuantizationOptimizerExecutor(OptimizerStrategy):
-    conv_percentage: int = 60
-    linear_percentage: int = 70
-    conv_compute_density_threshold: int = 100
-    linear_compute_density_threshold: int = 300
-
-    @singledispatchmethod
-    def apply(self, model, *args, **kwargs):
-        print(f"{type(self).__name__}.apply() not implemented for {type(model)}")
-        return model
-
-    @apply.register(ModelPatcher)
-    def _(self, model: ModelPatcher, ckpt_name=""):
-        quant_config = QuantizationConfig.from_settings(
-            quantize_conv=True,
-            quantize_linear=True,
-            bits=8,
-            conv_mae_threshold=0.9,
-            linear_mae_threshold=0.9,
-            plot_calibrate_info=True,
-            conv_compute_density_threshold=self.conv_compute_density_threshold,
-            linear_compute_density_threshold=self.linear_compute_density_threshold,
-        )
-        diff_model = model.model.diffusion_model
-        torch_model = get_torch_model(diff_model)
-        quant_config.quantization_calculator = SubQuantizationPercentileCalculator(
-            torch_model,
-            quant_config,
-            cache_key="unet",
-            conv_percentage=self.conv_percentage / 100,
-            linear_percentage=self.linear_percentage / 100,
-        )
-        if not isinstance(diff_model, DeployableModule):
-            diff_model = oneflow_compile(diff_model)
-        diff_model.apply_online_quant(quant_config)
-        model.model.diffusion_model = diff_model
-
-        graph_file = generate_graph_path(ckpt_name, model.model)
-        quant_config.cache_dir = os.path.dirname(graph_file)
-        set_compiled_options(diff_model, graph_file)
-        quant_config = diff_model._deployable_module_quant_config
-        return model
-
-    @apply.register(ControlNet)
-    def _(self, model, ckpt_name=""):
-        quant_config = QuantizationConfig.from_settings(
-            quantize_conv=True,
-            quantize_linear=True,
-            bits=8,
-            conv_mae_threshold=0.9,
-            linear_mae_threshold=0.9,
-            plot_calibrate_info=True,
-            conv_compute_density_threshold=self.conv_compute_density_threshold,
-            linear_compute_density_threshold=self.linear_compute_density_threshold,
-        )
-        control_model = model.control_model
-        quant_config.quantization_calculator = SubQuantizationPercentileCalculator(
-            control_model,
-            quant_config,
-            cache_key="ControlNet",
-            conv_percentage=self.conv_percentage / 100,
-            linear_percentage=self.linear_percentage / 100,
-        )
-        graph_file = generate_graph_path(ckpt_name, control_model)
-        quant_config.cache_dir = os.path.dirname(graph_file)
-        if not isinstance(control_model, DeployableModule):
-            control_model = oneflow_compile(control_model)
-        control_model.apply_online_quant(quant_config)
-        set_compiled_options(control_model, graph_file)
-        model.control_model = control_model
-        return model
diff --git a/onediff_comfy_nodes/utils/loader_sample_tools.py b/onediff_comfy_nodes/utils/loader_sample_tools.py
index a08f256c6..caf4fcb82 100644
--- a/onediff_comfy_nodes/utils/loader_sample_tools.py
+++ b/onediff_comfy_nodes/utils/loader_sample_tools.py
@@ -11,7 +11,6 @@
 # onediff_comfy_nodes
 from .model_patcher import state_dict_hook
-from onediff.infer_compiler.with_oneflow_compile import DeployableModule


 def compoile_unet(diffusion_model, graph_file):
@@ -25,9 +24,7 @@ def compoile_unet(diffusion_model, graph_file):
         "graph_file": graph_file,
         "graph_file_device": load_device,
     }
-    if isinstance(diffusion_model, DeployableModule):
-        return diffusion_model
-
+
     diffusion_model = oneflow_compile(
         diffusion_model, use_graph=use_graph, options=compile_options,
     )
diff --git a/onediff_diffusers_extensions/examples/text_to_image_online_quant.py b/onediff_diffusers_extensions/examples/text_to_image_online_quant.py
index 6af04e13a..ee2f84556 100644
--- a/onediff_diffusers_extensions/examples/text_to_image_online_quant.py
+++ b/onediff_diffusers_extensions/examples/text_to_image_online_quant.py
@@ -1,38 +1,42 @@
-"""[Stable Diffusion V1.5 - Hugging Face Model Hub](https://huggingface.co/runwayml/stable-diffusion-v1-5)
+"""[SDXL](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)

 ## Performance Comparison

-Updated on Tue 02 Apr 2024
+Updated on Mon 08 Apr 2024

-Timings for 50 steps at 512x512
-| Accelerator             | Baseline (non-optimized) | OneDiff Quant(optimized) | Percentage improvement |
-| ----------------------- | ------------------------ | ------------------------ | ---------------------- |
-| NVIDIA GeForce RTX 3090 | 2.51 s                   | 0.92 s                   | ~63 %                  |
+Timings for 30 steps at 1024x1024
+| Accelerator             | Baseline (non-optimized) | OneDiff(optimized) | OneDiff Quant(optimized) |
+| ----------------------- | ------------------------ | ------------------ | ------------------------ |
+| NVIDIA GeForce RTX 3090 | 8.03 s                   | 4.44 s ( ~44.7%)   | 3.34 s ( ~58.4%)         |
+
+- torch {version: 2.2.1+cu121}
+- oneflow {git_commit: 710818c, version: 0.9.1.dev20240406+cu121, enterprise: True}

 ## Install

-1. [OneDiff Installation Guide](https://github.com/siliconflow/onediff/tree/main?tab=readme-ov-file#installation)
+1. [OneDiff Installation Guide](https://github.com/siliconflow/onediff/blob/main/README_ENTERPRISE.md#install-onediff-enterprise)
 2. [OneDiffx Installation Guide](https://github.com/siliconflow/onediff/tree/main/onediff_diffusers_extensions#install-and-setup)

 ## Usage:
+> onediff/onediff_diffusers_extensions/examples/text_to_image_online_quant.py

 ```shell
 # Baseline (non-optimized)
 $ python text_to_image_online_quant.py \
-    --model_id /share_nfs/hf_models/stable-diffusion-v1-5 \
+    --model_id /share_nfs/hf_models/stable-diffusion-xl-base-1.0 \
     --seed 1 \
-    --backend torch --height 512 --width 512 --output_file sd-v1-5_torch.png
+    --backend torch --height 1024 --width 1024 --output_file sdxl_torch.png
 ```

 ```shell
 # OneDiff Quant(optimized)
 $ python text_to_image_online_quant.py \
-    --model_id /share_nfs/hf_models/stable-diffusion-v1-5 \
+    --model_id /share_nfs/hf_models/stable-diffusion-xl-base-1.0 \
     --seed 1 \
     --backend onediff \
-    --cache_dir ./run_sd-v1-5_quant \
-    --height 512 \
-    --width 512 \
-    --output_file sd-v1-5_quant.png \
+    --cache_dir ./run_sdxl_quant \
+    --height 1024 \
+    --width 1024 \
+    --output_file sdxl_quant.png \
     --quantize \
     --conv_mae_threshold 0.1 \
     --linear_mae_threshold 0.2 \
@@ -43,7 +47,7 @@
 | Option                                  | Range  | Default | Description                                                                   |
 | --------------------------------------- | ------ | ------- | ----------------------------------------------------------------------------- |
 | --conv_mae_threshold 0.9                | [0, 1] | 0.1     | MAE threshold for quantizing convolutional modules to 0.1.                    |
-| --linear_mae_threshold 1                | [0, 1] | 0.2     | MAE threshold for quantizing linear modules to 0.2.                           |
+| --linear_mae_threshold 1                | [0, 1] | 0.2     | MAE threshold for quantizing linear modules to 0.2.                          |
 | --conv_compute_density_threshold 900    | [0, ∞) | 900     | Computational density threshold for quantizing convolutional modules to 900. |
 | --linear_compute_density_threshold 300  | [0, ∞) | 300     | Computational density threshold for quantizing linear modules to 300.        |
@@ -72,7 +76,7 @@ def parse_args():
     parser.add_argument("--cache_dir", default="./run_sd-v1-5")
     parser.add_argument("--height", type=int, default=1024)
     parser.add_argument("--width", type=int, default=1024)
-    parser.add_argument("--num_inference_steps", type=int, default=50)
+    parser.add_argument("--num_inference_steps", type=int, default=30)
     parser.add_argument("--conv_mae_threshold", type=float, default=0.2)
     parser.add_argument("--linear_mae_threshold", type=float, default=0.4)
     parser.add_argument("--conv_compute_density_threshold", type=int, default=900)
@@ -109,14 +113,13 @@ def main():
     pipe(prompt=args.prompt, num_inference_steps=1)

     # Run_inference
-    for _ in range(4):
+    for _ in range(5):
         start_time = time.time()
         torch.manual_seed(args.seed)
         image = pipe(prompt=args.prompt, height=args.height, width=args.width, num_inference_steps=args.num_inference_steps).images[0]
         end_time = time.time()
         print(f"Inference time: {end_time - start_time:.2f} seconds")
-    # [onediff_quant.png](https://github.com/siliconflow/onediff/assets/109639975/75cd9407-c9bb-423f-9e70-c15df76ff2b1)
     save_image(image, args.output_file)

 if __name__ == "__main__":
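The patch above removes `SubQuantizationPercentileCalculator`, whose core idea is to turn a user-facing "quantize X% of layers" knob into a per-metric threshold by indexing into the sorted per-layer calibration errors. A minimal standalone illustration of that percentile rule follows; the MAE values are invented for the example:

```python
# Standalone illustration of the percentile rule used by the removed
# SubQuantizationPercentileCalculator: convert a percentage of layers to
# quantize into an MAE threshold over the sorted per-layer errors.
def percentile_threshold(values, percentage):
    values = sorted(values)
    # Same indexing as the removed code: values[int((len(values) - 1) * p)]
    return values[int((len(values) - 1) * percentage)]

# Made-up per-layer calibration MAEs for five conv layers.
conv_mae = [0.01, 0.02, 0.05, 0.08, 0.30]
threshold = percentile_threshold(conv_mae, 0.6)
print(threshold)  # 0.05 -> conv layers with MAE <= 0.05 get quantized (~60%)
```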
From c48b777d2338791df2247fe7a4a978955f73f231 Mon Sep 17 00:00:00 2001
From: FengWen <109639975+ccssu@users.noreply.github.com>
Date: Mon, 8 Apr 2024 16:44:52 +0800
Subject: [PATCH 10/11] Update deep_cache_speedup.py

---
 .../utils/deep_cache_speedup.py | 42 +++++--------------
 1 file changed, 10 insertions(+), 32 deletions(-)

diff --git a/onediff_comfy_nodes/utils/deep_cache_speedup.py b/onediff_comfy_nodes/utils/deep_cache_speedup.py
index 8da185741..2580da9b0 100644
--- a/onediff_comfy_nodes/utils/deep_cache_speedup.py
+++ b/onediff_comfy_nodes/utils/deep_cache_speedup.py
@@ -4,8 +4,7 @@
 from onediff.infer_compiler.utils import set_boolean_env_var
 from .model_patcher import OneFlowDeepCacheSpeedUpModelPatcher
-from register_comfy import DeepCacheUNet, FastDeepCacheUNet
-from onediff.infer_compiler import oneflow_compile
+

 def deep_cache_speedup(
     model,
     use_graph,
     cache_interval,
     cache_layer_id,
     cache_block_id,
     start_step,
     end_step,
     *,
     gen_compile_options=None,
-    use_oneflow_deepcache_speedup_modelpatcher = True,
 ):
     offload_device = model_management.unet_offload_device()
-    if use_oneflow_deepcache_speedup_modelpatcher:
-        model_patcher = OneFlowDeepCacheSpeedUpModelPatcher(
-            model.model,
-            load_device=model_management.get_torch_device(),
-            offload_device=offload_device,
-            cache_layer_id=cache_layer_id,
-            cache_block_id=cache_block_id,
-            use_graph=use_graph,
-            gen_compile_options=gen_compile_options,
-        )
-    else:
-        model_patcher = model
-        model_patcher.deep_cache_unet = DeepCacheUNet(
-            model_patcher.model.diffusion_model, cache_layer_id, cache_block_id
-        )
-
-        model_patcher.fast_deep_cache_unet = FastDeepCacheUNet(
-            model_patcher.model.diffusion_model, cache_layer_id, cache_block_id
-        )
-        if use_graph:
-            gen_compile_options = gen_compile_options or (lambda x: {})
-            compile_options = gen_compile_options(model_patcher.deep_cache_unet)
-            model_patcher.deep_cache_unet = oneflow_compile(
-                model_patcher.deep_cache_unet, use_graph=use_graph, options=compile_options,
-            )
-            compile_options = gen_compile_options(model_patcher.fast_deep_cache_unet)
-            model_patcher.fast_deep_cache_unet = oneflow_compile(
-                model_patcher.fast_deep_cache_unet, use_graph=use_graph, options=compile_options,
-            )
+    model_patcher = OneFlowDeepCacheSpeedUpModelPatcher(
+        model.model,
+        load_device=model_management.get_torch_device(),
+        offload_device=offload_device,
+        cache_layer_id=cache_layer_id,
+        cache_block_id=cache_block_id,
+        use_graph=use_graph,
+        gen_compile_options=gen_compile_options,
+    )

     current_t = -1
     current_step = -1

From 67c5fb41171d6f8b2eb51c9f5bf444a0f0ff936a Mon Sep 17 00:00:00 2001
From: FengWen <109639975+ccssu@users.noreply.github.com>
Date: Mon, 8 Apr 2024 22:20:11 +0800
Subject: [PATCH 11/11] Update text_to_image_online_quant.py

---
 .../examples/text_to_image_online_quant.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/onediff_diffusers_extensions/examples/text_to_image_online_quant.py b/onediff_diffusers_extensions/examples/text_to_image_online_quant.py
index ee2f84556..c426551f7 100644
--- a/onediff_diffusers_extensions/examples/text_to_image_online_quant.py
+++ b/onediff_diffusers_extensions/examples/text_to_image_online_quant.py
@@ -46,8 +46,8 @@
 | Option                                  | Range  | Default | Description                                                                   |
 | --------------------------------------- | ------ | ------- | ----------------------------------------------------------------------------- |
-| --conv_mae_threshold 0.9                | [0, 1] | 0.1     | MAE threshold for quantizing convolutional modules to 0.1.                    |
-| --linear_mae_threshold 1                | [0, 1] | 0.2     | MAE threshold for quantizing linear modules to 0.2.                          |
+| --conv_mae_threshold 0.1                | [0, 1] | 0.1     | MAE threshold for quantizing convolutional modules to 0.1.                    |
+| --linear_mae_threshold 0.2              | [0, 1] | 0.2     | MAE threshold for quantizing linear modules to 0.2.                          |
 | --conv_compute_density_threshold 900    | [0, ∞) | 900     | Computational density threshold for quantizing convolutional modules to 900. |
 | --linear_compute_density_threshold 300  | [0, ∞) | 300     | Computational density threshold for quantizing linear modules to 300.        |
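For reference, the online-quantization flow that the example script above wraps behind argparse reduces to a few calls. The following is a hedged sketch, not the script itself: loading the pipeline via diffusers' AutoPipelineForText2Image and the prompt are assumptions, while `QuantizationConfig.from_settings`, the `cache_dir` attribute, `oneflow_compile`, and `apply_online_quant` are the calls shown in the patches above, with thresholds matching the documented defaults.

```python
# A sketch (assuming a diffusers pipeline) of the online-quant flow behind
# text_to_image_online_quant.py; threshold values mirror the option table above.
import torch
from diffusers import AutoPipelineForText2Image

from onediff.infer_compiler import oneflow_compile
from onediff_quant.quantization import QuantizationConfig

pipe = AutoPipelineForText2Image.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")

quant_config = QuantizationConfig.from_settings(
    quantize_conv=True,
    quantize_linear=True,
    bits=8,
    conv_mae_threshold=0.1,
    linear_mae_threshold=0.2,
    conv_compute_density_threshold=900,
    linear_compute_density_threshold=300,
)
quant_config.cache_dir = "./run_sdxl_quant"  # where calibration results are cached

# Compile the UNet, then attach the quantization config before the first run;
# calibration and quantization happen online during that first inference.
pipe.unet = oneflow_compile(pipe.unet)
pipe.unet.apply_online_quant(quant_config)

image = pipe(
    "a photo of an astronaut riding a horse",  # illustrative prompt
    height=1024,
    width=1024,
    num_inference_steps=30,
).images[0]
image.save("sdxl_quant.png")
```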