diff --git a/.gitignore b/.gitignore index 1b4b7915b..6474e6a8c 100644 --- a/.gitignore +++ b/.gitignore @@ -8,12 +8,17 @@ *.suo *.user +# Common IREE source/build paths +iree/ +iree-build/ + # macOS files .DS_Store # CMake artifacts build/ build-*/ +_build/ # Python __pycache__ @@ -30,6 +35,12 @@ wheelhouse *.safetensors *.gguf *.vmfb +genfiles/ +*.zip +tmp/ # Known inference result blobs *output*.png + +# Log files. +*.log diff --git a/shortfin/python/shortfin_apps/sd/README.md b/shortfin/python/shortfin_apps/sd/README.md index 4e8c25dda..9c4ee3b3e 100644 --- a/shortfin/python/shortfin_apps/sd/README.md +++ b/shortfin/python/shortfin_apps/sd/README.md @@ -16,34 +16,33 @@ pip install pillow python -m shortfin_apps.sd.server --help ``` +## Run tests + + - From SHARK-Platform/shortfin: + ``` + pytest --system=amdgpu -k "sd" + ``` + The tests run with splat weights. + + ## Run on MI300x - Follow quick start - - Download runtime artifacts (vmfbs, weights): - + - Navigate to shortfin/ (only necessary if you are using the following CLI commands exactly as written.) 
``` -mkdir vmfbs -mkdir weights -wget https://sharkpublic.blob.core.windows.net/sharkpublic/sdxl/vmfbs/stable_diffusion_xl_base_1_0_bs1_64_1024x1024_i8_punet_gfx942.vmfb -O vmfbs/stable_diffusion_xl_base_1_0_bs1_64_1024x1024_i8_punet_gfx942.vmfb -wget https://sharkpublic.blob.core.windows.net/sharkpublic/sdxl/vmfbs/stable_diffusion_xl_base_1_0_bs1_64_fp16_text_encoder_gfx942.vmfb -O vmfbs/stable_diffusion_xl_base_1_0_bs1_64_fp16_text_encoder_gfx942.vmfb -wget https://sharkpublic.blob.core.windows.net/sharkpublic/sdxl/vmfbs/stable_diffusion_xl_base_1_0_EulerDiscreteScheduler_bs1_1024x1024_fp16_gfx942.vmfb -O vmfbs/stable_diffusion_xl_base_1_0_EulerDiscreteScheduler_bs1_1024x1024_fp16_gfx942.vmfb -wget https://sharkpublic.blob.core.windows.net/sharkpublic/sdxl/vmfbs/stable_diffusion_xl_base_1_0_bs1_1024x1024_fp16_vae_gfx942.vmfb -O vmfbs/stable_diffusion_xl_base_1_0_bs1_1024x1024_fp16_vae_gfx942.vmfb - -# You can download real weights with: -wget https://sharkpublic.blob.core.windows.net/sharkpublic/sdxl/weights/sfsd_weights_1023.zip -# Splat weights: -wget https://sharkpublic.blob.core.windows.net/sharkpublic/sdxl/weights/sfsd_splat_1023.zip +cd shortfin/ ``` - - Unzip the downloaded weights archive to /weights - Run CLI server interface (you can find `sdxl_config_i8.json` in shortfin_apps/sd/examples): +The server will prepare runtime artifacts for you. 
+ ``` -python -m shortfin_apps.sd.server --model_config=./sdxl_config_i8.json --clip_vmfb=./vmfbs/stable_diffusion_xl_base_1_0_bs1_64_fp16_text_encoder_gfx942.vmfb --unet_vmfb=./vmfbs/stable_diffusion_xl_base_1_0_bs1_64_1024x1024_i8_punet_gfx942.vmfb --scheduler_vmfb=./vmfbs/stable_diffusion_xl_base_1_0_EulerDiscreteScheduler_bs1_1024x1024_fp16_gfx942.vmfb --vae_vmfb=./vmfbs/stable_diffusion_xl_base_1_0_bs1_1024x1024_fp16_vae_gfx942.vmfb --clip_params=./weights/stable_diffusion_xl_base_1_0_text_encoder_fp16.safetensors --unet_params=./weights/stable_diffusion_xl_base_1_0_punet_dataset_i8.irpa --vae_params=./weights/stable_diffusion_xl_base_1_0_vae_fp16.safetensors --device=amdgpu --device_ids=0 +python -m shortfin_apps.sd.server --model_config=./python/shortfin_apps/sd/examples/sdxl_config_i8.json --device=amdgpu --device_ids=0 ``` -with splat: + - Run with splat(empty) weights: ``` -python -m shortfin_apps.sd.server --model_config=./sdxl_config_i8.json --clip_vmfb=./vmfbs/stable_diffusion_xl_base_1_0_bs1_64_fp16_text_encoder_gfx942.vmfb --unet_vmfb=./vmfbs/stable_diffusion_xl_base_1_0_bs1_64_1024x1024_i8_punet_gfx942.vmfb --scheduler_vmfb=./vmfbs/stable_diffusion_xl_base_1_0_EulerDiscreteScheduler_bs1_1024x1024_fp16_gfx942.vmfb --vae_vmfb=./vmfbs/stable_diffusion_xl_base_1_0_bs1_1024x1024_fp16_vae_gfx942.vmfb --clip_params=./weights/clip_splat.irpa --unet_params=./weights/punet_splat_18.irpa --vae_params=./weights/vae_splat.irpa --device=amdgpu --device_ids=0 +python -m shortfin_apps.sd.server --model_config=./python/shortfin_apps/sd/examples/sdxl_config_i8.json --device=amdgpu --device_ids=0 --splat ``` - Run a request in a separate shell: ``` diff --git a/shortfin/python/shortfin_apps/sd/components/builders.py b/shortfin/python/shortfin_apps/sd/components/builders.py new file mode 100644 index 000000000..a83b63e48 --- /dev/null +++ b/shortfin/python/shortfin_apps/sd/components/builders.py @@ -0,0 +1,202 @@ +from iree.build import * +import itertools +import os 
+import shortfin.array as sfnp
+
+from shortfin_apps.sd.components.config_struct import ModelParams
+
+this_dir = os.path.dirname(os.path.abspath(__file__))
+parent = os.path.dirname(this_dir)
+default_config_json = os.path.join(parent, "examples", "sdxl_config_i8.json")
+
+# Maps a shortfin dtype to the precision tag embedded in artifact filenames.
+dtype_to_filetag = {
+    sfnp.float16: "fp16",
+    sfnp.float32: "fp32",
+    sfnp.int8: "i8",
+    sfnp.bfloat16: "bf16",
+}
+
+# Artifact version: interpolated into SDXL_BUCKET and written to version.txt
+# by needs_update().
+ARTIFACT_VERSION = "11022024"
+SDXL_BUCKET = (
+    f"https://sharkpublic.blob.core.windows.net/sharkpublic/sdxl/{ARTIFACT_VERSION}/"
+)
+SDXL_WEIGHTS_BUCKET = (
+    "https://sharkpublic.blob.core.windows.net/sharkpublic/sdxl/weights/"
+)
+
+
+def get_mlir_filenames(model_params: ModelParams):
+    """Return the ``<stem>.mlir`` filename for every submodel file stem."""
+    return [stem + ".mlir" for stem in get_file_stems(model_params)]
+
+
+def get_vmfb_filenames(model_params: ModelParams, target: str = "gfx942"):
+    """Return the ``<stem>_<target>.vmfb`` filename for every submodel file stem."""
+    return [f"{stem}_{target}.vmfb" for stem in get_file_stems(model_params)]
+
+
+def get_params_filenames(model_params: ModelParams, splat: bool):
+    """Return the ``.irpa`` parameter filenames for clip, vae and (p)unet.
+
+    Args:
+        model_params: Model configuration supplying the base name and dtypes.
+        splat: Select the splat (empty, test-only) archives. Accepts a real
+            bool as well as the strings "True"/"False", because the ``splat``
+            argument of the ``sdxl`` entrypoint is declared with ``type=str``.
+
+    Returns:
+        Filenames in the order clip, vae, then punet/unet.
+    """
+    base = (
+        "stable_diffusion_xl_base_1_0"
+        if model_params.base_model_name.lower() == "sdxl"
+        else model_params.base_model_name
+    )
+    modnames = ["clip", "vae"]
+    mod_precs = [
+        dtype_to_filetag[model_params.clip_dtype],
+        # Previously unet_dtype: "vae" is now paired with its own dtype, which
+        # is also what get_file_stems() uses for the vae artifacts.
+        dtype_to_filetag[model_params.vae_dtype],
+    ]
+    if model_params.use_i8_punet:
+        modnames.append("punet")
+        mod_precs.append("i8")
+    else:
+        modnames.append("unet")
+        mod_precs.append(dtype_to_filetag[model_params.unet_dtype])
+    # Normalise the bool and CLI-string forms. The former `splat == "True"`
+    # test silently treated the bool True as "not splat".
+    use_splat = str(splat).lower() == "true"
+    if use_splat:
+        return [f"{mod}_splat_{prec}.irpa" for mod, prec in zip(modnames, mod_precs)]
+    return [
+        f"{base}_{mod}_dataset_{prec}.irpa" for mod, prec in zip(modnames, mod_precs)
+    ]
+
+
+def get_file_stems(model_params: ModelParams):
+    """Return the artifact file stems for every submodel configuration.
+
+    A stem has the form
+    ``<base>_<module>_bs<N>[_<max_seq_len>][_<H>x<W>]_<dtype>``; one stem is
+    produced per combination of batch size and (where applicable) dims.
+    The sequence length is included for clip and unet, the dims for unet,
+    vae and scheduler.
+    """
+    base = (
+        ["stable_diffusion_xl_base_1_0"]
+        if model_params.base_model_name.lower() == "sdxl"
+        else [model_params.base_model_name]
+    )
+    mod_names = {
+        "clip": "clip",
+        "unet": "punet" if model_params.use_i8_punet else "unet",
+        "scheduler": model_params.scheduler_id + "Scheduler",
+        "vae": "vae",
+    }
+    # Loop-invariant: "HxW" tag for each configured dimension pair.
+    dims = ["x".join(str(d) for d in dim_pair) for dim_pair in model_params.dims]
+
+    file_stems = []
+    for mod, modname in mod_names.items():
+        batch_sizes = getattr(model_params, f"{mod}_batch_sizes", [1])
+        ord_params = [base, [modname], [f"bs{bs}" for bs in batch_sizes]]
+        if mod in ("unet", "clip"):
+            ord_params.append([str(model_params.max_seq_len)])
+        if mod in ("unet", "vae", "scheduler"):
+            ord_params.append(dims)
+        if mod == "unet":
+            dtype_str = (
+                "i8"
+                if model_params.use_i8_punet
+                else dtype_to_filetag[model_params.unet_dtype]
+            )
+        elif mod == "scheduler":
+            # The scheduler artifacts are tagged with the unet dtype.
+            dtype_str = dtype_to_filetag[model_params.unet_dtype]
+        else:
+            dtype_str = dtype_to_filetag[
+                getattr(model_params, f"{mod}_dtype", sfnp.float16)
+            ]
+        ord_params.append([dtype_str])
+        file_stems.extend("_".join(combo) for combo in itertools.product(*ord_params))
+    return file_stems
+
+
+def get_url_map(filenames: list[str], bucket: str):
+    """Map each filename to its download URL (``bucket`` + filename)."""
+    return {filename: f"{bucket}{filename}" for filename in filenames}
+
+
+def needs_update(ctx):
+    """Return True when the local artifacts must be re-fetched.
+
+    The artifact version is recorded in ``version.txt`` allocated from the
+    build context. A missing or mismatching stamp yields True, and in both
+    cases the stamp is rewritten with ARTIFACT_VERSION so that only this run
+    forces a refresh.
+    """
+    stamp_path = ctx.allocate_file("version.txt").get_fs_path()
+    if os.path.exists(stamp_path):
+        with open(stamp_path, "r") as s:
+            if s.read().strip() == ARTIFACT_VERSION:
+                return False
+    # Missing or outdated stamp. Previously an outdated stamp was never
+    # rewritten, so every subsequent run re-downloaded all artifacts.
+    # NOTE: the stamp is written before the fetch actions run; if a fetch
+    # fails, delete version.txt to force another refresh.
+    with open(stamp_path, "w") as s:
+        s.write(ARTIFACT_VERSION)
+    return True
+
+
+def needs_file(filename, ctx):
+    """Return True when ``filename`` does not exist at its allocated path."""
+    out_file = ctx.allocate_file(filename).get_fs_path()
+    if os.path.exists(out_file):
+        return False
+    # NOTE(review): clears the executor's entry for this file, presumably so
+    # the following fetch_http() can register it again -- confirm against
+    # iree.build.
+    filekey = f"{ctx.path}/{filename}"
+    ctx.executor.all[filekey] = None
+    return True
+
+
+@entrypoint(description="Retrieves a set of SDXL submodels.")
+def sdxl(
+    model_json=cl_arg(
+        "model_json",
+        default=default_config_json,
+        help="Local config filepath",
+    ),
+    target=cl_arg(
+        "target",
+        default="gfx942",
+        help="IREE target architecture.",
+    ),
+    splat=cl_arg(
+        "splat", default=False, type=str, help="Download empty weights (for testing)"
+    ),
+):
+    """Fetch the MLIR, VMFB and parameter files for the configured submodels.
+
+    A file is fetched when the version stamp is stale or the file is missing
+    locally. Returns the VMFB, parameter and MLIR filenames, in that order.
+    """
+    model_params = ModelParams.load_json(model_json)
+    ctx = executor.BuildContext.current()
+    update = needs_update(ctx)
+
+    mlir_filenames = get_mlir_filenames(model_params)
+    vmfb_filenames = get_vmfb_filenames(model_params, target=target)
+    params_filenames = get_params_filenames(model_params, splat)
+
+    # Same fetch order as before: mlir, then vmfbs, then parameters.
+    downloads = (
+        (mlir_filenames, SDXL_BUCKET + "mlir/"),
+        (vmfb_filenames, SDXL_BUCKET + "vmfbs/"),
+        (params_filenames, SDXL_WEIGHTS_BUCKET),
+    )
+    for filenames, bucket in downloads:
+        for f, url in get_url_map(filenames, bucket).items():
+            if update or needs_file(f, ctx):
+                fetch_http(name=f, url=url)
+
+    return [*vmfb_filenames, *params_filenames, *mlir_filenames]
+
+
+if __name__ == "__main__":
+    iree_build_main()
diff --git a/shortfin/python/shortfin_apps/sd/components/config_struct.py b/shortfin/python/shortfin_apps/sd/components/config_struct.py index 0d68aad8e..3dda6edfc 100644 --- a/shortfin/python/shortfin_apps/sd/components/config_struct.py +++ b/shortfin/python/shortfin_apps/sd/components/config_struct.py @@ -43,9 +43,15 @@ class ModelParams: # Same for VAE. vae_batch_sizes: list[int] + # Same for scheduler. 
+ scheduler_batch_sizes: list[int] + # Height and Width, respectively, for which Unet and VAE are compiled. e.g. [[512, 512], [1024, 1024]] dims: list[list[int]] + # Scheduler id. + scheduler_id: str = "EulerDiscrete" + base_model_name: str = "SDXL" # Name of the IREE module for each submodel. clip_module_name: str = "compiled_clip" @@ -59,11 +65,13 @@ class ModelParams: # Classifer free guidance mode. If set to false, only positive prompts will matter. cfg_mode = True - # DTypes (basically defaults): + # DTypes (not necessarily weights precision): clip_dtype: sfnp.DType = sfnp.float16 unet_dtype: sfnp.DType = sfnp.float16 vae_dtype: sfnp.DType = sfnp.float16 + use_i8_punet: bool = False + # ABI of the module. module_abi_version: int = 1 diff --git a/shortfin/python/shortfin_apps/sd/components/service.py b/shortfin/python/shortfin_apps/sd/components/service.py index f19591bd9..af8423a11 100644 --- a/shortfin/python/shortfin_apps/sd/components/service.py +++ b/shortfin/python/shortfin_apps/sd/components/service.py @@ -48,6 +48,7 @@ def __init__( tokenizers: list[Tokenizer], model_params: ModelParams, fibers_per_device: int, + workers_per_device: int = 1, prog_isolation: str = "per_fiber", show_progress: bool = False, trace_execution: bool = False, @@ -64,16 +65,20 @@ def __init__( self.inference_programs: dict[str, sf.Program] = {} self.trace_execution = trace_execution self.show_progress = show_progress + self.workers_per_device = workers_per_device self.fibers_per_device = fibers_per_device self.prog_isolation = prog_isolations[prog_isolation] self.workers = [] self.fibers = [] self.fiber_status = [] for idx, device in enumerate(self.sysman.ls.devices): - for i in range(self.fibers_per_device): + for i in range(self.workers_per_device): worker = sysman.ls.create_worker(f"{name}-inference-{device.name}-{i}") - fiber = sysman.ls.create_fiber(worker, devices=[device]) self.workers.append(worker) + for i in range(self.fibers_per_device): + fiber = 
sysman.ls.create_fiber( + self.workers[i % len(self.workers)], devices=[device] + ) self.fibers.append(fiber) self.fiber_status.append(0) diff --git a/shortfin/python/shortfin_apps/sd/examples/sdxl_config_fp16.json b/shortfin/python/shortfin_apps/sd/examples/sdxl_config_fp16.json index fa2af4a1e..2e03e4603 100644 --- a/shortfin/python/shortfin_apps/sd/examples/sdxl_config_fp16.json +++ b/shortfin/python/shortfin_apps/sd/examples/sdxl_config_fp16.json @@ -10,6 +10,9 @@ "vae_batch_sizes": [ 1 ], + "scheduler_batch_sizes": [ + 1 + ], "unet_module_name": "compiled_unet", "unet_fn_name": "run_forward", "dims": [ diff --git a/shortfin/python/shortfin_apps/sd/examples/sdxl_config_i8.json b/shortfin/python/shortfin_apps/sd/examples/sdxl_config_i8.json index dff7d4d2d..804947d8f 100644 --- a/shortfin/python/shortfin_apps/sd/examples/sdxl_config_i8.json +++ b/shortfin/python/shortfin_apps/sd/examples/sdxl_config_i8.json @@ -10,9 +10,13 @@ "vae_batch_sizes": [ 1 ], + "scheduler_batch_sizes": [ + 1 + ], "unet_dtype": "float16", "unet_module_name": "compiled_punet", "unet_fn_name": "main", + "use_i8_punet": true, "dims": [ [ 1024, diff --git a/shortfin/python/shortfin_apps/sd/server.py b/shortfin/python/shortfin_apps/sd/server.py index 837afb152..177361c06 100644 --- a/shortfin/python/shortfin_apps/sd/server.py +++ b/shortfin/python/shortfin_apps/sd/server.py @@ -11,6 +11,9 @@ from pathlib import Path import sys import os +import io + +from iree.build import * import uvicorn.logging @@ -29,12 +32,15 @@ from .components.manager import SystemManager from .components.service import GenerateService from .components.tokenizer import Tokenizer +from .components.builders import sdxl from shortfin.support.logging_setup import configure_main_logger logger = configure_main_logger("server") +THIS_DIR = Path(__file__).resolve().parent + @asynccontextmanager async def lifespan(app: FastAPI): @@ -99,25 +105,41 @@ def configure(args) -> SystemManager: show_progress=args.show_progress, 
trace_execution=args.trace_execution, ) - sm.load_inference_module(args.clip_vmfb, component="clip") - sm.load_inference_module(args.unet_vmfb, component="unet") - sm.load_inference_module(args.scheduler_vmfb, component="scheduler") - sm.load_inference_module(args.vae_vmfb, component="vae") - sm.load_inference_parameters( - *args.clip_params, parameter_scope="model", component="clip" - ) - sm.load_inference_parameters( - *args.unet_params, - parameter_scope="model", - component="unet", - ) - sm.load_inference_parameters( - *args.vae_params, parameter_scope="model", component="vae" - ) + vmfbs, params = get_modules(args) + for key, vmfblist in vmfbs.items(): + for vmfb in vmfblist: + sm.load_inference_module(vmfb, component=key) + for key, datasets in params.items(): + sm.load_inference_parameters(*datasets, parameter_scope="model", component=key) services[sm.name] = sm return sysman
+
+
+def get_modules(args):
+    """Fetch the SDXL artifacts and group the resulting files by component.
+
+    Runs the build module in components/builders.py through
+    ``iree_build_main`` with the model config, target and splat settings
+    taken from ``args``, then classifies every name it prints to stdout.
+
+    Returns:
+        ``(vmfbs, params)``: dicts keyed by component ("clip", "unet", "vae",
+        and for ``vmfbs`` also "scheduler") whose values are lists of the
+        printed names.
+    """
+    vmfbs = {"clip": [], "unet": [], "vae": [], "scheduler": []}
+    params = {"clip": [], "unet": [], "vae": []}
+    mod = load_build_module(os.path.join(THIS_DIR, "components", "builders.py"))
+    out_file = io.StringIO()
+    iree_build_main(
+        mod,
+        args=[
+            f"--model_json={args.model_config}",
+            f"--target={args.target}",
+            f"--splat={args.splat}",
+        ],
+        stdout=out_file,
+    )
+    param_suffixes = (".irpa", ".safetensors", ".gguf")
+    for line in out_file.getvalue().splitlines():
+        name = line.strip()
+        if not name:
+            continue
+        # Classify on the file name only: matching the whole path would let a
+        # directory such as ".../vae_cache/" assign every artifact to "vae".
+        leaf = os.path.basename(name)
+        leaf_lower = leaf.lower()
+        for key in vmfbs:
+            if key not in leaf_lower:
+                continue
+            if leaf.endswith(param_suffixes):
+                params[key].append(name)
+            elif leaf.endswith(".vmfb"):
+                vmfbs[key].append(name)
+    return vmfbs, params
+
+
def main(argv, log_config=uvicorn.config.LOGGING_CONFIG): parser = argparse.ArgumentParser() parser.add_argument("--host", type=str, default=None) @@ -138,6 +160,13 @@ def main(argv, log_config=uvicorn.config.LOGGING_CONFIG): choices=["local-task", "hip", "amdgpu"], help="Primary inferencing device", ) + parser.add_argument( + "--target", + type=str, + required=False, + default="gfx942", + 
help="Primary inferencing device LLVM target arch.", + ) parser.add_argument( "--device_ids", type=int, @@ -161,43 +190,6 @@ def main(argv, log_config=uvicorn.config.LOGGING_CONFIG): required=True, help="Path to the model config file", ) - parser.add_argument( - "--clip_vmfb", - type=Path, - required=True, - help="Model VMFB to load", - ) - parser.add_argument( - "--unet_vmfb", - type=Path, - required=True, - help="Model VMFB to load", - ) - parser.add_argument("--scheduler_vmfb", type=Path, help="Scheduler VMFB to load.") - parser.add_argument( - "--vae_vmfb", - type=Path, - required=True, - help="Model VMFB to load", - ) - parser.add_argument( - "--clip_params", - type=Path, - nargs="*", - help="Parameter archives to load", - ) - parser.add_argument( - "--unet_params", - type=Path, - nargs="*", - help="Parameter archives to load", - ) - parser.add_argument( - "--vae_params", - type=Path, - nargs="*", - help="Parameter archives to load", - ) parser.add_argument( "--fibers_per_device", type=int, @@ -224,6 +216,11 @@ def main(argv, log_config=uvicorn.config.LOGGING_CONFIG): action="store_true", help="Enable tracing of program modules.", ) + parser.add_argument( + "--splat", + action="store_true", + help="Use splat (empty) parameter files, usually for testing.", + ) log_levels = { "info": logging.INFO, "debug": logging.DEBUG, @@ -234,7 +231,7 @@ def main(argv, log_config=uvicorn.config.LOGGING_CONFIG): log_level = log_levels[args.log_level] logger.setLevel(log_level) - + logger.addHandler(logging.FileHandler("shortfin_sd.log")) global sysman sysman = configure(args) uvicorn.run( diff --git a/shortfin/requirements-iree-compiler.txt b/shortfin/requirements-iree-compiler.txt index 4afd96c5e..65d809a8a 100644 --- a/shortfin/requirements-iree-compiler.txt +++ b/shortfin/requirements-iree-compiler.txt @@ -1,4 +1,4 @@ # Keep in sync with IREE_REF in CI and GIT_TAG in CMakeLists.txt -f https://iree.dev/pip-release-links.html -iree-compiler==20241029.1062 
-iree-runtime==20241029.1062 +iree-compiler==20241101.1065 +iree-runtime==20241101.1065 diff --git a/shortfin/tests/apps/sd/e2e_test.py b/shortfin/tests/apps/sd/e2e_test.py index 366fb9c2d..8fe8de5b3 100644 --- a/shortfin/tests/apps/sd/e2e_test.py +++ b/shortfin/tests/apps/sd/e2e_test.py @@ -1,6 +1,7 @@ import json import requests import time +import asyncio import base64 import pytest import subprocess @@ -8,6 +9,8 @@ import socket import sys import copy +import math +import tempfile from contextlib import closing from datetime import datetime as dt @@ -30,58 +33,25 @@ } -def sd_artifacts(target: str = "gfx942"): - return { - "model_config": "sdxl_config_i8.json", - "clip_vmfb": f"stable_diffusion_xl_base_1_0_bs1_64_fp16_text_encoder_{target}.vmfb", - "scheduler_vmfb": f"stable_diffusion_xl_base_1_0_EulerDiscreteScheduler_bs1_1024x1024_fp16_{target}.vmfb", - "unet_vmfb": f"stable_diffusion_xl_base_1_0_bs1_64_1024x1024_i8_punet_{target}.vmfb", - "vae_vmfb": f"stable_diffusion_xl_base_1_0_bs1_1024x1024_fp16_vae_{target}.vmfb", - "clip_params": "clip_splat_fp16.irpa", - "unet_params": "punet_splat_i8.irpa", - "vae_params": "vae_splat_fp16.irpa", - } - - -cache = os.path.abspath("./tmp/sharktank/sd/") - - def start_server(fibers_per_device=1, isolation="per_fiber"): - # Download model if it doesn't exist - vmfbs_bucket = "https://sharkpublic.blob.core.windows.net/sharkpublic/sdxl/vmfbs/" - weights_bucket = ( - "https://sharkpublic.blob.core.windows.net/sharkpublic/sdxl/weights/" - ) - configs_bucket = ( - "https://sharkpublic.blob.core.windows.net/sharkpublic/sdxl/configs/" - ) - for artifact, path in sd_artifacts().items(): - if "vmfb" in artifact: - bucket = vmfbs_bucket - elif "params" in artifact: - bucket = weights_bucket - else: - bucket = configs_bucket - address = bucket + path - local_file = os.path.join(cache, path) - if not os.path.exists(local_file): - print("Downloading artifact from " + address) - r = requests.get(address, allow_redirects=True) - with 
open(local_file, "wb") as lf: - lf.write(r.content) # Start the server srv_args = [ "python", "-m", "shortfin_apps.sd.server", ] - for arg in sd_artifacts().keys(): - artifact_arg = f"--{arg}={cache}/{sd_artifacts()[arg]}" - srv_args.extend([artifact_arg]) + with open("sdxl_config_i8.json", "wb") as f: + r = requests.get( + "https://sharkpublic.blob.core.windows.net/sharkpublic/sdxl/11022024/configs/sdxl_config_i8.json", + allow_redirects=True, + ) + f.write(r.content) srv_args.extend( [ + f"--model_config=sdxl_config_i8.json", f"--fibers_per_device={fibers_per_device}", f"--isolation={isolation}", + f"--splat", ] ) runner = ServerRunner(srv_args) @@ -92,10 +62,6 @@ def start_server(fibers_per_device=1, isolation="per_fiber"): @pytest.fixture(scope="module") def sd_server_fpd1(): - # Create necessary directories - - os.makedirs(cache, exist_ok=True) - runner = start_server(fibers_per_device=1) yield runner @@ -106,10 +72,6 @@ def sd_server_fpd1(): @pytest.fixture(scope="module") def sd_server_fpd1_per_call(): - # Create necessary directories - - os.makedirs(cache, exist_ok=True) - runner = start_server(fibers_per_device=1, isolation="per_call") yield runner @@ -120,10 +82,6 @@ def sd_server_fpd1_per_call(): @pytest.fixture(scope="module") def sd_server_fpd2(): - # Create necessary directories - - os.makedirs(cache, exist_ok=True) - runner = start_server(fibers_per_device=2) yield runner @@ -134,10 +92,6 @@ def sd_server_fpd2(): @pytest.fixture(scope="module") def sd_server_fpd8(): - # Create necessary directories - - os.makedirs(cache, exist_ok=True) - runner = start_server(fibers_per_device=8) yield runner @@ -181,6 +135,23 @@ def test_sd_server_bs8_dense_fpd8(sd_server_fpd8): assert status_code == 200 +@pytest.mark.slow +@pytest.mark.system("amdgpu") +def test_sd_server_bs64_dense_fpd8(sd_server_fpd8): + """Send a 64-copy request to the fibers_per_device=8 server; expect 64 results and status 200.""" + imgs, status_code = send_json_file(sd_server_fpd8.url, num_copies=64) + assert len(imgs) == 64 + assert status_code == 200 + + +@pytest.mark.slow 
+@pytest.mark.xfail(reason="Unexpectedly large client batch.") +@pytest.mark.system("amdgpu") +def test_sd_server_bs512_dense_fpd8(sd_server_fpd8): + """Send a 512-copy request to the fibers_per_device=8 server; marked xfail (unexpectedly large client batch).""" + imgs, status_code = send_json_file(sd_server_fpd8.url, num_copies=512) + assert len(imgs) == 512 + assert status_code == 200 + + class ServerRunner: def __init__(self, args): port = str(find_free_port()) diff --git a/shortfin/tests/conftest.py b/shortfin/tests/conftest.py index 247876ad5..f085f1047 100644 --- a/shortfin/tests/conftest.py +++ b/shortfin/tests/conftest.py @@ -24,6 +24,9 @@ def pytest_configure(config): config.addinivalue_line( "markers", "system(name): mark test to run only on a named system" ) + config.addinivalue_line( + "markers", "slow: mark test to run in a separate, slow suite." + ) def pytest_runtest_setup(item):