test_quantize_custom_model #756

Merged (20 commits, Apr 8, 2024)
2 changes: 2 additions & 0 deletions .github/workflows/examples.yml
@@ -209,6 +209,8 @@ jobs:
        if: matrix.test-suite == 'diffusers_examples'
        run: |
          docker exec -w /src/onediff/onediff_diffusers_extensions ${{ env.CONTAINER_NAME }} python3 -m pip install -e .
      - if: matrix.test-suite == 'diffusers_examples' && startsWith(matrix.image, 'onediff-pro')
        run: docker exec -w /src/onediff ${{ env.CONTAINER_NAME }} python3 tests/test_quantize_custom_model.py
      - if: matrix.test-suite == 'diffusers_examples' && startsWith(matrix.image, 'onediff-pro')
        run: docker exec -w /src/onediff/onediff_diffusers_extensions ${{ env.CONTAINER_NAME }} python3 examples/text_to_image_sdxl_enterprise.py --model /share_nfs/hf_models/stable-diffusion-xl-base-1.0-int8 --width 512 --height 512 --saved_image output_enterprise_sdxl.png
      - if: matrix.test-suite == 'diffusers_examples' && startsWith(matrix.image, 'onediff-pro')
126 changes: 126 additions & 0 deletions onediff_diffusers_extensions/examples/text_to_image_online_quant.py
@@ -0,0 +1,126 @@
"""[SDXL](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)

## Performance Comparison

Updated on Mon 08 Apr 2024

Timings for 30 steps at 1024x1024
| Accelerator             | Baseline (non-optimized) | OneDiff (optimized) | OneDiff Quant (optimized) |
| ----------------------- | ------------------------ | ------------------- | ------------------------- |
| NVIDIA GeForce RTX 3090 | 8.03 s                   | 4.44 s (~44.7%)     | 3.34 s (~58.4%)           |

- torch {version: 2.2.1+cu121}
- oneflow {git_commit: 710818c, version: 0.9.1.dev20240406+cu121, enterprise: True}

## Install

1. [OneDiff Installation Guide](https://github.com/siliconflow/onediff/blob/main/README_ENTERPRISE.md#install-onediff-enterprise)
2. [OneDiffx Installation Guide](https://github.com/siliconflow/onediff/tree/main/onediff_diffusers_extensions#install-and-setup)

## Usage:
> onediff/onediff_diffusers_extensions/examples/text_to_image_online_quant.py

```shell
# Baseline (non-optimized)
$ python text_to_image_online_quant.py \
--model_id /share_nfs/hf_models/stable-diffusion-xl-base-1.0 \
--seed 1 \
--backend torch --height 1024 --width 1024 --output_file sdxl_torch.png
```
```shell
# OneDiff Quant(optimized)
$ python text_to_image_online_quant.py \
--model_id /share_nfs/hf_models/stable-diffusion-xl-base-1.0 \
--seed 1 \
--backend onediff \
--cache_dir ./run_sdxl_quant \
--height 1024 \
--width 1024 \
--output_file sdxl_quant.png \
--quantize \
--conv_mae_threshold 0.1 \
--linear_mae_threshold 0.2 \
--conv_compute_density_threshold 900 \
--linear_compute_density_threshold 300
```

| Option                              | Range  | Default | Description                                                           |
| ----------------------------------- | ------ | ------- | --------------------------------------------------------------------- |
| --conv_mae_threshold                | [0, 1] | 0.1     | MAE threshold for quantizing convolutional modules.                   |
| --linear_mae_threshold              | [0, 1] | 0.2     | MAE threshold for quantizing linear modules.                          |
| --conv_compute_density_threshold    | [0, ∞) | 900     | Computational density threshold for quantizing convolutional modules. |
| --linear_compute_density_threshold  | [0, ∞) | 300     | Computational density threshold for quantizing linear modules.        |

Notes:

1. Select the CUDA device with `export CUDA_VISIBLE_DEVICES=7`.

2. The `*.pt` log files are cached; quantization results can be found in `<cache_dir>/quantization_stats.json`.

"""
import argparse
import time

import torch
from diffusers import AutoPipelineForText2Image
from onediffx import compile_pipe
from onediff_quant.quantization import QuantizationConfig


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_id", default="runwayml/stable-diffusion-v1-5")
    parser.add_argument("--prompt", default="a photo of an astronaut riding a horse on mars")
    parser.add_argument("--output_file", default="astronaut_rides_horse_onediff_quant.png")
    parser.add_argument("--seed", type=int, default=1)
    parser.add_argument("--backend", default="onediff", choices=["onediff", "torch"])
    parser.add_argument("--quantize", action="store_true")
    parser.add_argument("--cache_dir", default="./run_sd-v1-5")
    parser.add_argument("--height", type=int, default=1024)
    parser.add_argument("--width", type=int, default=1024)
    parser.add_argument("--num_inference_steps", type=int, default=30)
    parser.add_argument("--conv_mae_threshold", type=float, default=0.2)
    parser.add_argument("--linear_mae_threshold", type=float, default=0.4)
    parser.add_argument("--conv_compute_density_threshold", type=int, default=900)
    parser.add_argument("--linear_compute_density_threshold", type=int, default=300)
    return parser.parse_args()


def load_model(model_id):
    pipe = AutoPipelineForText2Image.from_pretrained(model_id, torch_dtype=torch.float16, variant="fp16")
    pipe.to("cuda")
    return pipe


def compile_and_quantize_model(pipe, cache_dir, quantize, quant_params):
    pipe = compile_pipe(pipe)
    if quantize:
        config = QuantizationConfig.from_settings(**quant_params, cache_dir=cache_dir, plot_calibrate_info=True)
        pipe.unet.apply_online_quant(quant_config=config)
    return pipe


def save_image(image, output_file):
    image.save(output_file)
    print(f"Image saved to: {output_file}")


def main():
    args = parse_args()
    pipe = load_model(args.model_id)
    if args.backend == "onediff":
        pipe = compile_and_quantize_model(pipe, args.cache_dir, args.quantize,
                                          {"conv_mae_threshold": args.conv_mae_threshold,
                                           "linear_mae_threshold": args.linear_mae_threshold,
                                           "conv_compute_density_threshold": args.conv_compute_density_threshold,
                                           "linear_compute_density_threshold": args.linear_compute_density_threshold})

    # Warm-up
    pipe(prompt=args.prompt, num_inference_steps=1)

    # Run inference
    for _ in range(5):
        start_time = time.time()
        torch.manual_seed(args.seed)
        image = pipe(prompt=args.prompt, height=args.height, width=args.width, num_inference_steps=args.num_inference_steps).images[0]
        end_time = time.time()
        print(f"Inference time: {end_time - start_time:.2f} seconds")

    save_image(image, args.output_file)


if __name__ == "__main__":
    main()
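The notes in the module docstring point to `<cache_dir>/quantization_stats.json` for the quantization results. As a quick post-run check, a small script can print what the quantizer recorded; this is only a sketch, and since the JSON schema is not documented in this PR it just pretty-prints the file as-is (the `./run_sdxl_quant` path matches the usage example above):

```python
# Sketch: inspect the quantization stats cached by the example above.
# Assumes only that quantization_stats.json is valid JSON; its schema is
# not documented in this PR, so the contents are printed verbatim.
import json
from pathlib import Path

stats_file = Path("./run_sdxl_quant") / "quantization_stats.json"
if stats_file.exists():
    stats = json.loads(stats_file.read_text())
    print(json.dumps(stats, indent=2))
else:
    print(f"No stats at {stats_file}; run the example with --quantize first.")
```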
28 changes: 27 additions & 1 deletion src/onediff/infer_compiler/oneflow/deployable_module.py
@@ -8,6 +8,7 @@
from ..utils.log_utils import logger
from ..utils.param_utils import parse_device, check_device
from ..utils.graph_management_utils import graph_file_management
from ..utils.online_quantization_utils import quantize_and_deploy_wrapper
from ..utils.options import OneflowCompileOptions
from ..deployable_module import DeployableModule

@@ -27,6 +28,7 @@ def __init__(
        object.__setattr__(self, "_modules", torch_module._modules)
        object.__setattr__(self, "_torch_module", torch_module)
        self._deployable_module_enable_dynamic = dynamic
        self._deployable_module_quant_config = None
        self._deployable_module_options = (
            options if options is not None else OneflowCompileOptions()
        )
@@ -48,6 +50,7 @@ def from_existing(cls, existing_module, dynamic=True, options=None):
        instance._deployable_module_input_count = (
            existing_module._deployable_module_input_count
        )
        instance._deployable_module_quant_config = (
            existing_module._deployable_module_quant_config
        )

        return instance

@@ -82,7 +85,8 @@ def apply_model(self, *args, **kwargs):
            *args, **kwargs
        )
        return output

    @quantize_and_deploy_wrapper
    @input_output_processor
    @handle_deployable_exception
    @graph_file_management
@@ -168,4 +172,26 @@ def _clear_old_graph(self):
        del self._deployable_module_model.oneflow_module

    def get_graph_file(self):
        return self._deployable_module_options.graph_file

    def apply_online_quant(self, quant_config):
        """
        Applies the provided quantization configuration for online use.

        Args:
            quant_config (QuantizationConfig): The quantization configuration to apply.

        Example:
            >>> from onediff_quant.quantization import QuantizationConfig
            >>> quant_config = QuantizationConfig.from_settings(
            ...     quantize_conv=True,
            ...     quantize_linear=True,
            ...     conv_mae_threshold=0.005,
            ...     linear_mae_threshold=0.005,
            ...     conv_compute_density_threshold=300,
            ...     linear_compute_density_threshold=100,
            ...     cache_dir=args.cache_dir)
            >>> model.apply_online_quant(quant_config)
        """
        self._deployable_module_quant_config = quant_config
52 changes: 52 additions & 0 deletions src/onediff/infer_compiler/utils/online_quantization_utils.py
@@ -0,0 +1,52 @@
def online_quantize_model(
    model, input_args, input_kwargs,
    seed=1, inplace=True,
    module_selector=lambda x: x,
    quant_config=None,
    calibration_info=None,
):
    """Quantize the model online using the given calibration inputs.

    Returns:
        tuple: A tuple containing the quantized model and the quantization
        status.
    """
    from onediff_quant.quantization import (
        OnlineQuantModule,
        create_quantization_calculator,
    )

    if getattr(quant_config, "quantization_calculator", None):
        calculator = quant_config.quantization_calculator
    else:
        calculator = create_quantization_calculator(
            model, quant_config, module_selector, seed,
            calibration_info=calibration_info,
        )
    module = OnlineQuantModule(calculator, False, inplace=inplace)

    quantized_model, info = module.quantize_with_calibration(
        *input_args, **input_kwargs
    )
    status = module.collect_quantization_status(model, info)

    return quantized_model, status


def quantize_and_deploy_wrapper(func):
    def wrapper(self: "DeployableModule", *args, **kwargs):
        torch_model = self._torch_module
        quant_config = self._deployable_module_quant_config
        if quant_config:
            torch_model, _ = online_quantize_model(
                torch_model, args, kwargs,
                module_selector=lambda x: x,
                quant_config=quant_config,
                inplace=True,
            )
            # Calibrate and quantize only once; clear the config so later
            # calls reuse the already-quantized module.
            self._deployable_module_quant_config = None
        output = func(self, *args, **kwargs)
        return output

    return wrapper
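Read together with `deployable_module.py` above, the intended flow appears to be: compile the pipe, attach a config via `apply_online_quant`, and let the first real call calibrate and quantize through `quantize_and_deploy_wrapper`. A minimal sketch, assuming only the APIs used elsewhere in this PR; the model id is a placeholder and the thresholds mirror the example script's flags:

```python
# Sketch of the online-quant flow wired up by this PR (model id is a
# placeholder; thresholds mirror the usage example in the new script).
import torch
from diffusers import AutoPipelineForText2Image
from onediffx import compile_pipe
from onediff_quant.quantization import QuantizationConfig

pipe = AutoPipelineForText2Image.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,
    variant="fp16",
).to("cuda")

pipe = compile_pipe(pipe)
config = QuantizationConfig.from_settings(
    conv_mae_threshold=0.1,
    linear_mae_threshold=0.2,
    conv_compute_density_threshold=900,
    linear_compute_density_threshold=300,
    cache_dir="./run_sdxl_quant",
    plot_calibrate_info=True,
)
pipe.unet.apply_online_quant(quant_config=config)

# The first call triggers calibration + quantization via the wrapper,
# which then clears the config so later calls reuse the quantized module.
image = pipe(prompt="a photo of a cat", num_inference_steps=30).images[0]
```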


7 changes: 7 additions & 0 deletions src/onediff/quantization/README.md
@@ -0,0 +1,7 @@
## <div align="center">OneDiff Quant 🚀 NEW</div>
## <div align="center">Documentation</div>
- [Installation Guide](https://github.com/siliconflow/onediff/blob/main/README_ENTERPRISE.md#install-onediff-enterprise)
- [How to use Online Quant](../../../onediff_diffusers_extensions/examples/text_to_image_online_quant.py)
- [How to use Offline Quant](./quantize_pipeline.py)
- [How to Quant a custom model](../../../tests/test_quantize_custom_model.py)
- [Community and Support](https://github.com/siliconflow/onediff?tab=readme-ov-file#community-and-support)
15 changes: 4 additions & 11 deletions src/onediff/quantization/quantize_utils.py
@@ -12,19 +12,12 @@ def setup_onediff_quant():


def load_calibration_and_quantize_pipeline(calibration_path, pipe):
    from onediff_quant.quantization import CalibrationStorage
    from onediff_quant.utils import replace_sub_module_with_quantizable_module

    store = CalibrationStorage()
    calibrate_info = store.load_from_file(file_path=calibration_path)

    for sub_module_name, sub_calibrate_info in calibrate_info.items():
        replace_sub_module_with_quantizable_module(
            pipe.unet,
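For the offline path, the refactored helper can be used roughly as follows. This is a sketch only: the import path, calibration-file name, and model id are assumptions, and `setup_onediff_quant()` is the setup hook defined earlier in this file:

```python
# Sketch: offline quantization from a previously saved calibration file.
# The calibration file path and model id below are placeholders.
import torch
from diffusers import AutoPipelineForText2Image
from onediff.quantization.quantize_utils import (
    setup_onediff_quant,
    load_calibration_and_quantize_pipeline,
)

setup_onediff_quant()

pipe = AutoPipelineForText2Image.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,
    variant="fp16",
).to("cuda")

# Swap quantizable sub-modules of pipe.unet per the saved calibration info.
load_calibration_and_quantize_pipeline("calibrate_info.txt", pipe)

image = pipe(prompt="a photo of a cat").images[0]
```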