diff --git a/.github/workflows/examples.yml b/.github/workflows/examples.yml index 2ff92fb08..83f509cd0 100644 --- a/.github/workflows/examples.yml +++ b/.github/workflows/examples.yml @@ -123,8 +123,6 @@ jobs: run: docker exec ${{ env.CONTAINER_NAME }} python3 examples/text_to_image_sdxl.py --base /share_nfs/hf_models/stable-diffusion-xl-base-1.0 --compile_unet False --height 512 --width 512 --use_multiple_resolutions True - if: matrix.test-suite == 'diffusers_examples' run: docker exec ${{ env.CONTAINER_NAME }} python3 examples/text_to_image_controlnet.py --base=/share_nfs/hf_models/stable-diffusion-v1-5 --controlnet=/share_nfs/hf_models/sd-controlnet-canny --input_image=/share_nfs/hf_models/input_image_vermeer.png - - if: matrix.test-suite == 'diffusers_examples' - run: docker exec ${{ env.CONTAINER_NAME }} python3 benchmarks/stable_diffusion_2_unet.py --model_id=/share_nfs/hf_models/stable-diffusion-2-1 - if: matrix.test-suite == 'diffusers_examples' run: docker exec ${{ env.CONTAINER_NAME }} bash examples/unet_save_and_load.sh --model_id=/share_nfs/hf_models/stable-diffusion-2-1 diff --git a/benchmarks/README.md b/benchmarks/README.md new file mode 100644 index 000000000..eedccf157 --- /dev/null +++ b/benchmarks/README.md @@ -0,0 +1,69 @@ +# Bench OneDiff + +## Build docker image + +```bash +python3 -m pip install -r requirements.txt +python3 docker/main.py --yaml ./docker/config/community-default.yaml +``` + +## Prepare models + +Download models from [here](#About-the-models). If you have downloaded the models before, please skip it. + +## Run OneDiff Benchmark + +Start docker container and run the benchmark by the following command. + +```bash +export BENCHMARK_MODEL_PATH=./benchmark_model +docker compose -f ./docker-compose.onediff:benchmark-community-default.yaml up +``` + +Wait for a while, you will see the following logs, + +```bash +onediff-benchmark-community-default | Run SD1.5(FP16) 1024x1024... +onediff-benchmark-community-default | + python3 ./text_to_image.py --model /benchmark_model/stable-diffusion-v1-5 --warmup 5 --height 1024 --width 1024 +Loading pipeline components...: 43% 3/7 [00:00<00:00, 20.94it/s]`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config["id2label"]` will be overriden. +Loading pipeline components...: 100% 7/7 [00:00<00:00, 12.79it/s] +100% 30/30 [00:43<00:00, 1.45s/it] | +100% 30/30 [00:03<00:00, 7.76it/s] | +100% 30/30 [00:03<00:00, 7.74it/s] | +100% 30/30 [00:03<00:00, 7.74it/s] | +100% 30/30 [00:03<00:00, 7.72it/s] | +100% 30/30 [00:03<00:00, 7.72it/s] | +onediff-benchmark-community-default | e2e (30 steps) elapsed: 4.1393163204193115 s, cuda memory usage: 7226.875 MiB +...... +``` + +## About the models + +The structure of `/benchmark_model` should follow this hierarchy: + +```text +benchmark_model +├── stable-diffusion-2-1 +├── stable-diffusion-v1-5 +├── stable-diffusion-xl-base-1.0 +├── stable-diffusion-xl-base-1.0-int8 +``` + +You can obtain the models from [HuggingFace](https://huggingface.co) (excluding the int8 model) or download them from OSS (including the int8 model). The OSS download method is as follows: + +- Obtain ossutil by executing the following command: + + ```bash + wget http://gosspublic.alicdn.com/ossutil/1.7.3/ossutil64 && chmod u+x ossutil64 + ``` + +- Configure ossutil by referring to [the official example](https://www.alibabacloud.com/help/en/oss/developer-reference/configure-ossutil?spm=a2c63.p38356.0.0.337f374a4pcwa4) + ```bash + ossutil64 config + ``` + +- Download the benchmark models finally + + ```bash + ./ossutil64 cp -r oss://oneflow-pro/onediff_benchmark_model/ benchmark_model --update + ``` diff --git a/benchmarks/docker/.gitignore b/benchmarks/docker/.gitignore new file mode 100644 index 000000000..bf4c017b8 --- /dev/null +++ b/benchmarks/docker/.gitignore @@ -0,0 +1,6 @@ +config/ +Dockerfile-* +docker-compose.* +onediff/ +ComfyUI/ +diffusers_quant/ diff --git a/benchmarks/docker/_logger.py b/benchmarks/docker/_logger.py new file mode 100644 index 000000000..4d77abe9f --- /dev/null +++ b/benchmarks/docker/_logger.py @@ -0,0 +1,10 @@ +import logging + +logger = logging.getLogger("onediff-benchmark") +logger.setLevel(logging.INFO) +formatter = logging.Formatter( + "%(asctime)s - %(levelname)s - %(message)s", "%Y-%m-%d %H:%M:%S" +) +console_handler = logging.StreamHandler() +console_handler.setFormatter(formatter) +logger.addHandler(console_handler) diff --git a/benchmarks/docker/_utils.py b/benchmarks/docker/_utils.py new file mode 100644 index 000000000..ec0317159 --- /dev/null +++ b/benchmarks/docker/_utils.py @@ -0,0 +1,184 @@ +import hashlib +import os +import subprocess +import yaml + +from git import Repo + +from _logger import logger + + +def load_yaml(*, file): + if not os.path.exists(file): + raise RuntimeError(f"config file not existed: {file}") + + with open(file, "r") as file: + yaml_content = yaml.safe_load(file) + + return yaml_content + + +def calculate_sha256(file_path): + sha256_hash = hashlib.sha256() + + with open(file_path, "rb") as file: + for chunk in iter(lambda: file.read(4096), b""): + sha256_hash.update(chunk) + return sha256_hash.hexdigest() + + +def setup_repo(repo_item: dict): + if len(repo_item.keys()) != 1: + raise RuntimeError(f"Only one key required, but got {repo_item.keys()}") + else: + for key in repo_item.keys(): + repo_name = key + repo_item = repo_item.pop(repo_name) + break + + repo_url = repo_item.pop("repo_url") + branch = repo_item.pop("branch") + commit = repo_item.pop("commit", None) + cmds = repo_item.pop("cmds", None) + + repo_path = os.path.join(".", repo_name) + if not os.path.exists(repo_path): + logger.info(f"git clone {repo_url} to {repo_name}, branch: {branch}") + git_repo = Repo.clone_from(repo_url, repo_path, branch=branch) + else: + logger.info(f"git repository {repo_name} has existed, use it") + git_repo = Repo(repo_path) + if commit is not None: + git_repo.git.checkout(commit) + logger.info(f"checkout {repo_name} to {commit}") + docker_commands = f"COPY {repo_name} /app/{repo_name}" + extra_cmds = "" + if cmds is not None: + extra_cmds = ["RUN ", " && \\\n".join(cmds)] + extra_cmds = " ".join(extra_cmds) + extra_cmds = "\n".join([f"WORKDIR /app/{repo_name}", extra_cmds]) + docker_commands = "\n".join([docker_commands, extra_cmds]) + + return docker_commands + + +def generate_docker_file(yaml_file, file_hash, output_dir, **kwargs): + image_config = kwargs + + base_image = image_config.pop("base_image", None) + context_path = image_config.pop("context_path", None) + oneflow_pip_index = image_config.pop("oneflow_pip_index", None) + repos = image_config.pop("repos", None) + proxy = image_config.pop("proxy", None) + set_pip_mirror = image_config.pop("set_pip_mirror", None) + + origin_file_info = f"""#==== Generated from {yaml_file} ==== +# yaml file SHA256: {file_hash} +""" + + dockerfile_head = f""" +#==== Docker Base Image ==== +FROM {base_image} +""" + + if set_pip_mirror is not None: + dockerfile_set_pip_mirror = f"RUN {set_pip_mirror}" + else: + dockerfile_set_pip_mirror = "" + + dockerfile_oneflow = f""" +#==== Install the OneFlow ==== +RUN pip install -f {oneflow_pip_index} oneflow +""" + + repos_cmds = [] + for repo in repos: + repo = setup_repo(repo) + repos_cmds.append(repo) + repos_cmds = "\n\n".join(repos_cmds) + dockerfile_repos = f""" +#==== Download and set up the repos ==== +{repos_cmds} +""" + + docker_post_cmds = f""" +#==== Post setting +WORKDIR /app +""" + + dockerfile_content = "\n".join( + [ + origin_file_info, + dockerfile_head, + dockerfile_set_pip_mirror, + dockerfile_oneflow, + dockerfile_repos, + docker_post_cmds, + ] + ) + + if not os.path.exists(output_dir): + os.makedirs(output_dir) + dockerfile_name = os.path.join(output_dir, f"Dockerfile-{file_hash[0:8]}") + logger.info(f"write Dockerfile to {dockerfile_name}") + with open(dockerfile_name, "w") as f: + f.write(dockerfile_content) + return dockerfile_name + + +def build_image(docker_file, imagename, context): + command = ["docker", "build", "-f", docker_file, "-t", imagename, context] + try: + process = subprocess.Popen(command, stdout=subprocess.PIPE, text=True) + for line in iter(process.stdout.readline, ""): + print(line, end="") # Print each line + process.wait() + except subprocess.CalledProcessError as e: + print(f"Command execution failed: {e}") + + +def gen_docker_compose_yaml(container_name, image, envs, volumes, output_dir): + from collections import OrderedDict + + onediff_benchmark_service = { + "container_name": None, + "image": None, + # "entrypoint": "sleep infinity", + "tty": True, + "stdin_open": True, + "privileged": True, + "shm_size": "8g", + "network_mode": "host", + "pids_limit": 2000, + "cap_add": ["SYS_PTRACE"], + "security_opt": ["seccomp=unconfined"], + "environment": ["HF_HUB_OFFLINE=1"], + "volumes": [], + "working_dir": "/app", + "restart": "no", + "command": "/bin/bash -c \"cd /app/onediff/benchmarks && bash run_benchmark.sh /benchmark_model\"", + } + onediff_benchmark_service["container_name"] = container_name + onediff_benchmark_service["image"] = image + onediff_benchmark_service["environment"].extend(envs) + onediff_benchmark_service["volumes"].extend(volumes) + docker_compose_dict = { + "version": "3.8", + "services": {"onediff-benchmark": onediff_benchmark_service}, + } + + yaml_string = yaml.dump(docker_compose_dict) + + docker_compose_file = os.path.join(output_dir, f"docker-compose.{image}.yaml") + + docker_compose_readme = f"""#======== +# run the OneDiff benchmark container by: +# docker compose -f {docker_compose_file} up +#======== + +""" + run_command = f"docker compose -f {docker_compose_file} up" + with open(docker_compose_file, "w") as f: + content = [docker_compose_readme, yaml_string] + f.write("\n".join(content)) + return docker_compose_file, run_command diff --git a/benchmarks/docker/config/community-default.yaml b/benchmarks/docker/config/community-default.yaml new file mode 100644 index 000000000..18f61e0fc --- /dev/null +++ b/benchmarks/docker/config/community-default.yaml @@ -0,0 +1,20 @@ +base_image: nvcr.io/nvidia/pytorch:23.08-py3 +set_pip_mirror: "pip config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple" +oneflow_pip_index: "https://oneflow-pro.oss-cn-beijing.aliyuncs.com/branch/community/cu122" +repos: + - onediff: + repo_url: https://github.com/Oneflow-Inc/onediff.git + branch: dev_update_benchmark_scripts + cmds: + - "python3 -m pip install transformers==4.27.1 diffusers[torch]==0.19.3" + - "python3 -m pip install -e ." + # - ComfyUI: + # repo_url: "https://github.com/comfyanonymous/ComfyUI.git" + # branch: master + # commit: "6c5990f7dba2d5d0ad04c7ed5a702b067926cbe2" + # cmds: + # - "python3 -m pip install -r requirements.txt" +proxy: "" +volumes: + - '$BENCHMARK_MODEL_PATH:/benchmark_model:ro' +envs: [] diff --git a/benchmarks/docker/main.py b/benchmarks/docker/main.py new file mode 100644 index 000000000..2f831b653 --- /dev/null +++ b/benchmarks/docker/main.py @@ -0,0 +1,107 @@ +import argparse +from datetime import datetime +import os +import sys +from pathlib import Path + +ONEDIFFBOX_ROOT = Path(os.path.abspath(__file__)).parents[0] +sys.path.insert(0, str(ONEDIFFBOX_ROOT)) + +from _utils import ( + calculate_sha256, + setup_repo, + load_yaml, + generate_docker_file, + build_image, + gen_docker_compose_yaml, +) +from _logger import logger + + +def parse_args(): + parser = argparse.ArgumentParser(description="Build OneDiff Box") + formatted_datetime = datetime.now().strftime("%Y%m%d-%H%M") + + parser.add_argument( + "-y", + "--yaml", + type=str, + default="config/community-default.yaml", + ) + parser.add_argument( + "-i", + "--image", + type=str, + default="onediff", + ) + parser.add_argument( + "-t", + "--tag", + type=str, + default=f"benchmark", + ) + parser.add_argument( + "-o", + "--output", + type=str, + default=".", + help="the output directory of Dockerfile and Docker-compose file", + ) + parser.add_argument( + "-c", + "--context", + type=str, + default=".", + help="the path to build context", + ) + parser.add_argument( + "-q", + "--quiet", + action="store_true", + help="the path to build context", + ) + args = parser.parse_args() + return args + + +args = parse_args() + + +if __name__ == "__main__": + image_config = load_yaml(file=args.yaml) + file_hash = calculate_sha256(args.yaml) + + docker_file = generate_docker_file( + args.yaml, file_hash, args.output, **image_config + ) + version = os.path.splitext(os.path.basename(args.yaml))[0] + image_name = f"{args.image}:{args.tag}-{version}" + if not args.quiet: + build_cmd = ( + f"docker build -f {docker_file} -t {args.image}:{args.tag} {args.context}" + ) + print("Ready to build image by:") + r = input(" " + build_cmd + " [y]/n ") + if r == "" or r == "y" or r == "Y": + logger.info(f"building image {image_name}") + build_image(docker_file, image_name, args.context) + else: + print("building cancled") + else: + logger.info(f"building image {image_name}") + build_image(docker_file, image_name, args.context) + + envs = image_config.pop("envs", []) + volumes = image_config.pop( + "volumes", + [ + "$BENCHMARK_MODEL_PATH:/benchmark_model:ro", + ], + ) + compose_file, run_command = gen_docker_compose_yaml( + f"onediff-benchmark-{version}", image_name, envs, volumes, args.output + ) + logger.info(f"write docker-compose file to {compose_file}") + logger.info( + f"run container by:\n {run_command}\n and see {compose_file} for more" + ) diff --git a/benchmarks/requirements.txt b/benchmarks/requirements.txt new file mode 100644 index 000000000..59348f98e --- /dev/null +++ b/benchmarks/requirements.txt @@ -0,0 +1 @@ +gitpython diff --git a/benchmarks/run_benchmark.sh b/benchmarks/run_benchmark.sh new file mode 100644 index 000000000..d47bca7c4 --- /dev/null +++ b/benchmarks/run_benchmark.sh @@ -0,0 +1,59 @@ +#!/bin/bash +set -x + +if [ $# != 1 ]; then + echo "Usage: bash run_benchmark.sh /path/model" && exit 1 +fi +BENCHMARK_MODEL_PATH=$1 + +###################################################################################### +echo "Run SD1.5(FP16) 1024x1024..." +python3 ./text_to_image.py --model ${BENCHMARK_MODEL_PATH}/stable-diffusion-v1-5 --warmup 5 --height 1024 --width 1024 + +echo "Run SD1.5(FP16) 720x1280..." +python3 ./text_to_image.py --model ${BENCHMARK_MODEL_PATH}/stable-diffusion-v1-5 --warmup 5 --height 720 --width 1280 + +echo "Run SD1.5(FP16) 768x768..." +python3 ./text_to_image.py --model ${BENCHMARK_MODEL_PATH}/stable-diffusion-v1-5 --warmup 5 --height 768 --width 768 + +echo "Run SD1.5(FP16) 512x512..." +python3 ./text_to_image.py --model ${BENCHMARK_MODEL_PATH}/stable-diffusion-v1-5 --warmup 5 --height 512 --width 512 + +###################################################################################### +echo "Run SD2.1(FP16) 1024x1024..." +python3 ./text_to_image.py --model ${BENCHMARK_MODEL_PATH}/stable-diffusion-2-1 --warmup 5 --height 1024 --width 1024 + +echo "Run SD2.1(FP16) 720x1280..." +python3 ./text_to_image.py --model ${BENCHMARK_MODEL_PATH}/stable-diffusion-2-1 --warmup 5 --height 720 --width 1280 + +echo "Run SD2.1(FP16) 768x768..." +python3 ./text_to_image.py --model ${BENCHMARK_MODEL_PATH}/stable-diffusion-2-1 --warmup 5 --height 768 --width 768 + +echo "Run SD2.1(FP16) 512x512..." +python3 ./text_to_image.py --model ${BENCHMARK_MODEL_PATH}/stable-diffusion-2-1 --warmup 5 --height 512 --width 512 + +###################################################################################### +echo "Run SDXL(FP16) 1024x1024..." +python3 ./text_to_image_sdxl_fp16.py --model ${BENCHMARK_MODEL_PATH}/stable-diffusion-xl-base-1.0 --warmup 5 --height 1024 --width 1024 + +echo "Run SDXL(FP16) 720x1280..." +python3 ./text_to_image_sdxl_fp16.py --model ${BENCHMARK_MODEL_PATH}/stable-diffusion-xl-base-1.0 --warmup 5 --height 720 --width 1280 + +echo "Run SDXL(FP16) 768x768..." +python3 ./text_to_image_sdxl_fp16.py --model ${BENCHMARK_MODEL_PATH}/stable-diffusion-xl-base-1.0 --warmup 5 --height 768 --width 768 + +echo "Run SDXL(FP16) 512x512..." +python3 ./text_to_image_sdxl_fp16.py --model ${BENCHMARK_MODEL_PATH}/stable-diffusion-xl-base-1.0 --warmup 5 --height 512 --width 512 + +###################################################################################### +echo "Run SDXL(INT8) 1024x1024..." +python3 ./text_to_image_sdxl_quant.py --model ${BENCHMARK_MODEL_PATH}/stable-diffusion-xl-base-1.0-int8 --warmup 5 --height 1024 --width 1024 + +echo "Run SDXL(INT8) 720x1280..." +python3 ./text_to_image_sdxl_quant.py --model ${BENCHMARK_MODEL_PATH}/stable-diffusion-xl-base-1.0-int8 --warmup 5 --height 720 --width 1280 + +echo "Run SDXL(INT8) 768x768..." +python3 ./text_to_image_sdxl_quant.py --model ${BENCHMARK_MODEL_PATH}/stable-diffusion-xl-base-1.0-int8 --warmup 5 --height 768 --width 768 + +echo "Run SDXL(INT8) 512x512..." +python3 ./text_to_image_sdxl_quant.py --model ${BENCHMARK_MODEL_PATH}/stable-diffusion-xl-base-1.0-int8 --warmup 5 --height 512 --width 512 diff --git a/benchmarks/stable_diffusion_2_unet.py b/benchmarks/stable_diffusion_2_unet.py deleted file mode 100644 index db9d02eb2..000000000 --- a/benchmarks/stable_diffusion_2_unet.py +++ /dev/null @@ -1,79 +0,0 @@ -import os -import cv2 -from onediff.infer_compiler import oneflow_compile - -os.environ["ONEFLOW_MLIR_CSE"] = "1" -os.environ["ONEFLOW_MLIR_ENABLE_INFERENCE_OPTIMIZATION"] = "1" -os.environ["ONEFLOW_MLIR_ENABLE_ROUND_TRIP"] = "1" -os.environ["ONEFLOW_MLIR_FUSE_FORWARD_OPS"] = "1" -os.environ["ONEFLOW_MLIR_FUSE_OPS_WITH_BACKWARD_IMPL"] = "1" -os.environ["ONEFLOW_MLIR_GROUP_MATMUL"] = "1" -os.environ["ONEFLOW_MLIR_PREFER_NHWC"] = "1" - -os.environ["ONEFLOW_KERNEL_ENABLE_FUSED_CONV_BIAS"] = "1" -os.environ["ONEFLOW_KERNEL_ENABLE_FUSED_LINEAR"] = "1" - -os.environ["ONEFLOW_KERNEL_CONV_CUTLASS_IMPL_ENABLE_TUNING_WARMUP"] = "1" -os.environ["ONEFLOW_KERNEL_CONV_ENABLE_CUTLASS_IMPL"] = "1" - -os.environ["ONEFLOW_CONV_ALLOW_HALF_PRECISION_ACCUMULATION"] = "1" -os.environ["ONEFLOW_MATMUL_ALLOW_HALF_PRECISION_ACCUMULATION"] = "1" - -import click -import oneflow as flow -from diffusers import UNet2DConditionModel -from diffusers.utils import floats_tensor -from tqdm import tqdm -import torch - - -@click.command() -@click.option("--token") -@click.option("--height", default=768) -@click.option("--width", default=768) -@click.option("--repeat", default=1000) -@click.option("--sync_interval", default=50) -@click.option("--model_id", default="stabilityai/stable-diffusion-2") -def benchmark(token, height, width, repeat, sync_interval, model_id): - with torch.no_grad(): - unet = UNet2DConditionModel.from_pretrained( - model_id, - use_auth_token=token, - revision="fp16", - torch_dtype=torch.float16, - subfolder="unet", - ) - unet = unet.to("cuda") - unet_graph = oneflow_compile(unet) - - batch_size = 2 - num_channels = 4 - sizes = (height // 8, width // 8) - noise = ( - floats_tensor((batch_size, num_channels) + sizes) - .to("cuda") - .to(torch.float16) - ) - time_step = torch.tensor([10]).to("cuda") - encoder_hidden_states = ( - floats_tensor((batch_size, 77, 1024)).to("cuda").to(torch.float16) - ) - unet_graph(noise, time_step, encoder_hidden_states) - flow._oneflow_internal.eager.Sync() - import time - - t0 = time.time() - for r in tqdm(range(repeat)): - out = unet_graph(noise, time_step, encoder_hidden_states) - if r == repeat - 1 or r % sync_interval == 0: - flow._oneflow_internal.eager.Sync() - t1 = time.time() - duration = t1 - t0 - throughput = repeat / duration - print( - f"Finish {repeat} steps in {duration:.3f} seconds, average {throughput:.2f}it/s" - ) - - -if __name__ == "__main__": - benchmark() diff --git a/benchmarks/stable_diffusion_v1_5_unet.py b/benchmarks/stable_diffusion_v1_5_unet.py deleted file mode 100644 index 604c4e710..000000000 --- a/benchmarks/stable_diffusion_v1_5_unet.py +++ /dev/null @@ -1,93 +0,0 @@ -import os -import cv2 -from onediff.infer_compiler import oneflow_compile - -os.environ["ONEFLOW_MLIR_CSE"] = "1" -os.environ["ONEFLOW_MLIR_ENABLE_INFERENCE_OPTIMIZATION"] = "1" -os.environ["ONEFLOW_MLIR_ENABLE_ROUND_TRIP"] = "1" -os.environ["ONEFLOW_MLIR_FUSE_FORWARD_OPS"] = "1" -os.environ["ONEFLOW_MLIR_FUSE_OPS_WITH_BACKWARD_IMPL"] = "1" -os.environ["ONEFLOW_MLIR_GROUP_MATMUL"] = "1" -os.environ["ONEFLOW_MLIR_PREFER_NHWC"] = "1" - -os.environ["ONEFLOW_KERNEL_ENABLE_FUSED_CONV_BIAS"] = "1" -os.environ["ONEFLOW_KERNEL_ENABLE_FUSED_LINEAR"] = "1" - -os.environ["ONEFLOW_KERNEL_CONV_CUTLASS_IMPL_ENABLE_TUNING_WARMUP"] = "1" -os.environ["ONEFLOW_KERNEL_CONV_ENABLE_CUTLASS_IMPL"] = "1" - -os.environ["ONEFLOW_CONV_ALLOW_HALF_PRECISION_ACCUMULATION"] = "1" -os.environ["ONEFLOW_MATMUL_ALLOW_HALF_PRECISION_ACCUMULATION"] = "1" - -import click -import oneflow as flow -from diffusers import UNet2DConditionModel -from diffusers.utils import floats_tensor -from tqdm import tqdm -import torch - - -class UNetGraph(flow.nn.Graph): - def __init__(self, unet): - super().__init__() - self.unet = unet - self.config.enable_cudnn_conv_heuristic_search_algo(False) - self.config.allow_fuse_add_to_output(True) - - def build(self, latent_model_input, t, text_embeddings): - text_embeddings = flow._C.amp_white_identity(text_embeddings) - return self.unet( - latent_model_input, t, encoder_hidden_states=text_embeddings - ).sample - - -@click.command() -@click.option("--token") -@click.option("--repeat", default=1000) -@click.option("--sync_interval", default=50) -@click.option("--model_id", default="runwayml/stable-diffusion-v1-5") -def benchmark(token, repeat, sync_interval, model_id): - with flow.no_grad(): - unet = UNet2DConditionModel.from_pretrained( - model_id, - use_auth_token=token, - revision="fp16", - torch_dtype=torch.float16, - subfolder="unet", - ) - unet = unet.to("cuda") - unet_graph = oneflow_compile(unet) - - batch_size = 2 - num_channels = 4 - sizes = (64, 64) - noise = ( - floats_tensor((batch_size, num_channels) + sizes) - .to("cuda") - .to(torch.float16) - ) - time_step = flow.tensor([10]).to("cuda") - encoder_hidden_states = ( - floats_tensor((batch_size, 77, 768)).to("cuda").to(torch.float16) - ) - unet_graph(noise, time_step, encoder_hidden_states) - flow._oneflow_internal.eager.Sync() - import time - - t0 = time.time() - for r in tqdm(range(repeat)): - out = unet_graph(noise, time_step, encoder_hidden_states) - if r == repeat - 1 or r % sync_interval == 0: - flow._oneflow_internal.eager.Sync() - t1 = time.time() - duration = t1 - t0 - throughput = repeat / duration - print( - f"Finish {repeat} steps in {duration:.3f} seconds, average {throughput:.2f}it/s" - ) - - -if __name__ == "__main__": - print(f"{flow.__path__=}") - print(f"{flow.__version__=}") - benchmark() diff --git a/benchmarks/text_to_image.py b/benchmarks/text_to_image.py new file mode 100644 index 000000000..9a75f48e1 --- /dev/null +++ b/benchmarks/text_to_image.py @@ -0,0 +1,68 @@ +""" +example: python examples/text_to_image.py --height 512 --width 512 --warmup 10 --model xx +""" +import argparse +import time +import torch +import oneflow as flow +from onediff.infer_compiler import oneflow_compile +from onediff.schedulers import EulerDiscreteScheduler +from diffusers import StableDiffusionPipeline + + +def parse_args(): + parser = argparse.ArgumentParser(description="Simple demo of image generation.") + parser.add_argument( + "--prompt", type=str, default="a photo of an astronaut riding a horse on mars" + ) + parser.add_argument( + "--model", type=str, default="runwayml/stable-diffusion-v1-5", + ) + parser.add_argument("--height", type=int, default=512) + parser.add_argument("--width", type=int, default=512) + parser.add_argument("--steps", type=int, default=30) + parser.add_argument("--warmup", type=int, default=1) + parser.add_argument("--seed", type=int, default=1) + cmd_args = parser.parse_args() + return cmd_args + + +args = parse_args() + +scheduler = EulerDiscreteScheduler.from_pretrained(args.model, subfolder="scheduler") +pipe = StableDiffusionPipeline.from_pretrained( + args.model, + scheduler=scheduler, + revision="fp16", + variant="fp16", + torch_dtype=torch.float16, +) +pipe = pipe.to("cuda") + +pipe.unet = oneflow_compile(pipe.unet) +pipe.vae = oneflow_compile(pipe.vae) + +with flow.autocast("cuda"): + for _ in range(args.warmup): + images = pipe( + args.prompt, + height=args.height, + width=args.width, + num_inference_steps=args.steps, + ).images + + torch.manual_seed(args.seed) + + start_t = time.time() + images = pipe( + args.prompt, + height=args.height, + width=args.width, + num_inference_steps=args.steps, + ).images + end_t = time.time() + + cuda_memory_usage = flow._oneflow_internal.GetCUDAMemoryUsed() + print( + f"e2e ({args.steps} steps) elapsed: {end_t - start_t} s, cuda memory usage: {cuda_memory_usage} MiB" + ) diff --git a/benchmarks/text_to_image_sdxl_fp16.py b/benchmarks/text_to_image_sdxl_fp16.py new file mode 100644 index 000000000..6338b5fb4 --- /dev/null +++ b/benchmarks/text_to_image_sdxl_fp16.py @@ -0,0 +1,75 @@ +import argparse +import os +import time +import torch +import torch.nn as nn +import oneflow as flow + +# oneflow_compile should be imported before importing any diffusers +from onediff.infer_compiler import oneflow_compile +from onediff.schedulers import EulerDiscreteScheduler +from onediff.optimization import rewrite_self_attention + +from diffusers import StableDiffusionXLPipeline + + +parser = argparse.ArgumentParser() +parser.add_argument("--model", type=str, required=True) +parser.add_argument( + "--prompt", + type=str, + default="street style, detailed, raw photo, woman, face, shot on CineStill 800T", +) +parser.add_argument("--height", type=int, default=1024) +parser.add_argument("--width", type=int, default=1024) +parser.add_argument("--steps", type=int, default=30) +parser.add_argument("--variant", type=str, default="fp16") +parser.add_argument("--seed", type=int, default=1) +parser.add_argument("--warmup", type=int, default=1) +parser.add_argument( + "--graph", default=True, type=(lambda x: str(x).lower() in ["true", "1", "yes"]), +) +args = parser.parse_args() + +scheduler = EulerDiscreteScheduler.from_pretrained(args.model, subfolder="scheduler") +pipe = StableDiffusionXLPipeline.from_pretrained( + args.model, + scheduler=scheduler, + torch_dtype=torch.float16, + variant=args.variant, + use_safetensors=True, +) +pipe.to("cuda") + +# if pipe.text_encoder is not None: +# pipe.text_encoder = oneflow_compile(pipe.text_encoder, use_graph=args.graph) +# pipe.text_encoder_2 = oneflow_compile(pipe.text_encoder_2, use_graph=args.graph) +if args.graph: + rewrite_self_attention(pipe.unet) +pipe.unet = oneflow_compile(pipe.unet, use_graph=args.graph) +pipe.vae = oneflow_compile(pipe.vae, use_graph=args.graph) + +for _ in range(args.warmup): + image = pipe( + prompt=args.prompt, + height=args.height, + width=args.width, + num_inference_steps=args.steps, + ).images[0] + +torch.manual_seed(args.seed) + +start_t = time.time() + +image = pipe( + prompt=args.prompt, + height=args.height, + width=args.width, + num_inference_steps=args.steps, +).images[0] + +end_t = time.time() +cuda_memory_usage = flow._oneflow_internal.GetCUDAMemoryUsed() +print( + f"e2e ({args.steps} steps) elapsed: {end_t - start_t} s, cuda memory usage: {cuda_memory_usage} MiB" +) diff --git a/benchmarks/text_to_image_sdxl_quant.py b/benchmarks/text_to_image_sdxl_quant.py new file mode 100644 index 000000000..3e8b1ce0c --- /dev/null +++ b/benchmarks/text_to_image_sdxl_quant.py @@ -0,0 +1,112 @@ +import argparse +import os +import time +import torch +import torch.nn as nn +from torch._dynamo import allow_in_graph as maybe_allow_in_graph +import oneflow as flow + +# oneflow_compile should be imported before importing any diffusers +from onediff.infer_compiler import oneflow_compile +from onediff.schedulers import EulerDiscreteScheduler +from onediff.optimization import rewrite_self_attention + +from diffusers import StableDiffusionXLPipeline + +try: + import diffusers_quant +except: + print("Skip quantized SDXL since diffusers_quant is not installed.") + exit() +from diffusers_quant.utils import replace_sub_module_with_quantizable_module + +diffusers_quant.enable_load_quantized_model() + + +parser = argparse.ArgumentParser() +parser.add_argument("--model", type=str, required=True) +parser.add_argument( + "--prompt", + type=str, + default="street style, detailed, raw photo, woman, face, shot on CineStill 800T", +) +parser.add_argument("--height", type=int, default=1024) +parser.add_argument("--width", type=int, default=1024) +parser.add_argument("--steps", type=int, default=30) +parser.add_argument("--bits", type=int, default=8) +parser.add_argument( + "--static", default=False, type=(lambda x: str(x).lower() in ["true", "1", "yes"]) +) +parser.add_argument( + "--fake_quant", + default=False, + type=(lambda x: str(x).lower() in ["true", "1", "yes"]), +) +parser.add_argument( + "--graph", default=True, type=(lambda x: str(x).lower() in ["true", "1", "yes"]), +) +parser.add_argument("--seed", type=int, default=1) +parser.add_argument("--warmup", type=int, default=1) + +args = parser.parse_args() + +calibrate_info = {} +with open(os.path.join(args.model, "calibrate_info.txt"), "r") as f: + for line in f.readlines(): + line = line.strip() + items = line.split(" ") + calibrate_info[items[0]] = [ + float(items[1]), + int(items[2]), + [float(x) for x in items[3].split(",")], + ] + +scheduler = EulerDiscreteScheduler.from_pretrained(args.model, subfolder="scheduler") +pipe = StableDiffusionXLPipeline.from_pretrained( + args.model, scheduler=scheduler, torch_dtype=torch.float16, use_safetensors=True +) +pipe.to("cuda") + +for sub_module_name, sub_calibrate_info in calibrate_info.items(): + replace_sub_module_with_quantizable_module( + pipe.unet, + sub_module_name, + sub_calibrate_info, + args.fake_quant, + args.static, + args.bits, + maybe_allow_in_graph, + ) + +# if pipe.text_encoder is not None: +# pipe.text_encoder = oneflow_compile(pipe.text_encoder, use_graph=args.graph) +# pipe.text_encoder_2 = oneflow_compile(pipe.text_encoder_2, use_graph=args.graph) +if args.graph: + rewrite_self_attention(pipe.unet) +pipe.unet = oneflow_compile(pipe.unet, use_graph=args.graph) +pipe.vae = oneflow_compile(pipe.vae, use_graph=args.graph) + +for _ in range(args.warmup): + image = pipe( + prompt=args.prompt, + height=args.height, + width=args.width, + num_inference_steps=args.steps, + ).images[0] + +torch.manual_seed(args.seed) + +start_t = time.time() + +image = pipe( + prompt=args.prompt, + height=args.height, + width=args.width, + num_inference_steps=args.steps, +).images[0] + +end_t = time.time() +cuda_memory_usage = flow._oneflow_internal.GetCUDAMemoryUsed() +print( + f"e2e ({args.steps} steps) elapsed: {end_t - start_t} s, cuda memory usage: {cuda_memory_usage} MiB" +)