(shortfin-sd) Adds iree.build artifact fetching. (#411)
Also adds two _slow_ tests that exercise larger SDXL server loads; these will not trigger in any workflows yet.

This is still missing a few things:
- ad-hoc artifact fetching (e.g. the server is initialized with only 1024x1024 and later needs to fetch and load modules for other output shapes when a client requests them)
- compile integration (currently pulls precompiled vmfbs and weights)

It should eventually cover:
 - compile (short-term)
 - export (short/medium-term)
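As a rough, hypothetical sketch of how the new fetcher might be driven: builders.py (added below) exposes an `sdxl` entrypoint through iree.build's `iree_build_main()`, with `model_json`, `target`, and `splat` arguments declared via `cl_arg`. The exact CLI surface (flag spelling, output-directory handling) belongs to iree.build, so treat this as an assumption rather than a documented command:

```
# Hypothetical invocation sketch; flag names are assumed to mirror the
# cl_arg definitions in builders.py, and iree.build may require extra
# arguments (e.g. an output directory).
import subprocess

subprocess.run(
    [
        "python", "-m", "shortfin_apps.sd.components.builders",
        "--model_json=./python/shortfin_apps/sd/examples/sdxl_config_i8.json",
        "--target=gfx942",
    ],
    check=True,
)
```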

---------
monorimet authored Nov 5, 2024
1 parent 46debb4 commit e282fbc
Showing 11 changed files with 335 additions and 132 deletions.
11 changes: 11 additions & 0 deletions .gitignore
@@ -8,12 +8,17 @@
*.suo
*.user

# Common IREE source/build paths
iree/
iree-build/

# macOS files
.DS_Store

# CMake artifacts
build/
build-*/
_build/

# Python
__pycache__
@@ -30,6 +35,12 @@ wheelhouse
*.safetensors
*.gguf
*.vmfb
genfiles/
*.zip
tmp/

# Known inference result blobs
*output*.png

# Log files.
*.log
33 changes: 16 additions & 17 deletions shortfin/python/shortfin_apps/sd/README.md
@@ -16,34 +16,33 @@ pip install pillow
python -m shortfin_apps.sd.server --help
```

## Run tests

- From SHARK-Platform/shortfin:
```
pytest --system=amdgpu -k "sd"
```
The tests run with splat weights.


## Run on MI300x

- Follow quick start

- Download runtime artifacts (vmfbs, weights):

- Navigate to shortfin/ (only necessary if you are using the following CLI exactly).
```
mkdir vmfbs
mkdir weights
wget https://sharkpublic.blob.core.windows.net/sharkpublic/sdxl/vmfbs/stable_diffusion_xl_base_1_0_bs1_64_1024x1024_i8_punet_gfx942.vmfb -O vmfbs/stable_diffusion_xl_base_1_0_bs1_64_1024x1024_i8_punet_gfx942.vmfb
wget https://sharkpublic.blob.core.windows.net/sharkpublic/sdxl/vmfbs/stable_diffusion_xl_base_1_0_bs1_64_fp16_text_encoder_gfx942.vmfb -O vmfbs/stable_diffusion_xl_base_1_0_bs1_64_fp16_text_encoder_gfx942.vmfb
wget https://sharkpublic.blob.core.windows.net/sharkpublic/sdxl/vmfbs/stable_diffusion_xl_base_1_0_EulerDiscreteScheduler_bs1_1024x1024_fp16_gfx942.vmfb -O vmfbs/stable_diffusion_xl_base_1_0_EulerDiscreteScheduler_bs1_1024x1024_fp16_gfx942.vmfb
wget https://sharkpublic.blob.core.windows.net/sharkpublic/sdxl/vmfbs/stable_diffusion_xl_base_1_0_bs1_1024x1024_fp16_vae_gfx942.vmfb -O vmfbs/stable_diffusion_xl_base_1_0_bs1_1024x1024_fp16_vae_gfx942.vmfb
# You can download real weights with:
wget https://sharkpublic.blob.core.windows.net/sharkpublic/sdxl/weights/sfsd_weights_1023.zip
# Splat weights:
wget https://sharkpublic.blob.core.windows.net/sharkpublic/sdxl/weights/sfsd_splat_1023.zip
cd shortfin/
```
- Unzip the downloaded weights archive to /weights
- Run the CLI server interface (you can find `sdxl_config_i8.json` in shortfin_apps/sd/examples):

The server will prepare runtime artifacts for you.

```
python -m shortfin_apps.sd.server --model_config=./sdxl_config_i8.json --clip_vmfb=./vmfbs/stable_diffusion_xl_base_1_0_bs1_64_fp16_text_encoder_gfx942.vmfb --unet_vmfb=./vmfbs/stable_diffusion_xl_base_1_0_bs1_64_1024x1024_i8_punet_gfx942.vmfb --scheduler_vmfb=./vmfbs/stable_diffusion_xl_base_1_0_EulerDiscreteScheduler_bs1_1024x1024_fp16_gfx942.vmfb --vae_vmfb=./vmfbs/stable_diffusion_xl_base_1_0_bs1_1024x1024_fp16_vae_gfx942.vmfb --clip_params=./weights/stable_diffusion_xl_base_1_0_text_encoder_fp16.safetensors --unet_params=./weights/stable_diffusion_xl_base_1_0_punet_dataset_i8.irpa --vae_params=./weights/stable_diffusion_xl_base_1_0_vae_fp16.safetensors --device=amdgpu --device_ids=0
python -m shortfin_apps.sd.server --model_config=./python/shortfin_apps/sd/examples/sdxl_config_i8.json --device=amdgpu --device_ids=0
```
with splat:
- Run with splat (empty) weights:
```
python -m shortfin_apps.sd.server --model_config=./sdxl_config_i8.json --clip_vmfb=./vmfbs/stable_diffusion_xl_base_1_0_bs1_64_fp16_text_encoder_gfx942.vmfb --unet_vmfb=./vmfbs/stable_diffusion_xl_base_1_0_bs1_64_1024x1024_i8_punet_gfx942.vmfb --scheduler_vmfb=./vmfbs/stable_diffusion_xl_base_1_0_EulerDiscreteScheduler_bs1_1024x1024_fp16_gfx942.vmfb --vae_vmfb=./vmfbs/stable_diffusion_xl_base_1_0_bs1_1024x1024_fp16_vae_gfx942.vmfb --clip_params=./weights/clip_splat.irpa --unet_params=./weights/punet_splat_18.irpa --vae_params=./weights/vae_splat.irpa --device=amdgpu --device_ids=0
python -m shortfin_apps.sd.server --model_config=./python/shortfin_apps/sd/examples/sdxl_config_i8.json --device=amdgpu --device_ids=0 --splat
```
- Run a request in a separate shell:
```
202 changes: 202 additions & 0 deletions shortfin/python/shortfin_apps/sd/components/builders.py
@@ -0,0 +1,202 @@
from iree.build import *
import itertools
import os
import shortfin.array as sfnp

from shortfin_apps.sd.components.config_struct import ModelParams

this_dir = os.path.dirname(os.path.abspath(__file__))
parent = os.path.dirname(this_dir)
default_config_json = os.path.join(parent, "examples", "sdxl_config_i8.json")

dtype_to_filetag = {
sfnp.float16: "fp16",
sfnp.float32: "fp32",
sfnp.int8: "i8",
sfnp.bfloat16: "bf16",
}

ARTIFACT_VERSION = "11022024"
SDXL_BUCKET = (
f"https://sharkpublic.blob.core.windows.net/sharkpublic/sdxl/{ARTIFACT_VERSION}/"
)
SDXL_WEIGHTS_BUCKET = (
"https://sharkpublic.blob.core.windows.net/sharkpublic/sdxl/weights/"
)


def get_mlir_filenames(model_params: ModelParams):
mlir_filenames = []
file_stems = get_file_stems(model_params)
for stem in file_stems:
mlir_filenames.extend([stem + ".mlir"])
return mlir_filenames


def get_vmfb_filenames(model_params: ModelParams, target: str = "gfx942"):
vmfb_filenames = []
file_stems = get_file_stems(model_params)
for stem in file_stems:
vmfb_filenames.extend([stem + "_" + target + ".vmfb"])
return vmfb_filenames


def get_params_filenames(model_params: ModelParams, splat: bool):
params_filenames = []
base = (
"stable_diffusion_xl_base_1_0"
if model_params.base_model_name.lower() == "sdxl"
else model_params.base_model_name
)
modnames = ["clip", "vae"]
mod_precs = [
dtype_to_filetag[model_params.clip_dtype],
dtype_to_filetag[model_params.unet_dtype],
]
if model_params.use_i8_punet:
modnames.append("punet")
mod_precs.append("i8")
else:
modnames.append("unet")
mod_precs.append(dtype_to_filetag[model_params.unet_dtype])
if splat == "True":
for idx, mod in enumerate(modnames):
params_filenames.extend(
["_".join([mod, "splat", f"{mod_precs[idx]}.irpa"])]
)
else:
for idx, mod in enumerate(modnames):
params_filenames.extend(
[base + "_" + mod + "_dataset_" + mod_precs[idx] + ".irpa"]
)
return params_filenames


def get_file_stems(model_params: ModelParams):
file_stems = []
base = (
["stable_diffusion_xl_base_1_0"]
if model_params.base_model_name.lower() == "sdxl"
else [model_params.base_model_name]
)
mod_names = {
"clip": "clip",
"unet": "punet" if model_params.use_i8_punet else "unet",
"scheduler": model_params.scheduler_id + "Scheduler",
"vae": "vae",
}
for mod, modname in mod_names.items():
ord_params = [
base,
[modname],
]
bsizes = []
for bs in getattr(model_params, f"{mod}_batch_sizes", [1]):
bsizes.extend([f"bs{bs}"])
ord_params.extend([bsizes])
if mod in ["unet", "clip"]:
ord_params.extend([[str(model_params.max_seq_len)]])
if mod in ["unet", "vae", "scheduler"]:
dims = []
for dim_pair in model_params.dims:
dim_pair_str = [str(d) for d in dim_pair]
dims.extend(["x".join(dim_pair_str)])
ord_params.extend([dims])
if mod == "scheduler":
dtype_str = dtype_to_filetag[model_params.unet_dtype]
elif mod != "unet":
dtype_str = dtype_to_filetag[
getattr(model_params, f"{mod}_dtype", sfnp.float16)
]
else:
dtype_str = (
"i8"
if model_params.use_i8_punet
else dtype_to_filetag[model_params.unet_dtype]
)
ord_params.extend([[dtype_str]])
for x in list(itertools.product(*ord_params)):
file_stems.extend(["_".join(x)])
return file_stems


def get_url_map(filenames: list[str], bucket: str):
file_map = {}
for filename in filenames:
file_map[filename] = f"{bucket}{filename}"
return file_map


def needs_update(ctx):
stamp = ctx.allocate_file("version.txt")
stamp_path = stamp.get_fs_path()
if os.path.exists(stamp_path):
with open(stamp_path, "r") as s:
ver = s.read()
if ver != ARTIFACT_VERSION:
return True
else:
with open(stamp_path, "w") as s:
s.write(ARTIFACT_VERSION)
return True
return False


def needs_file(filename, ctx):
out_file = ctx.allocate_file(filename).get_fs_path()
if os.path.exists(out_file):
needed = False
else:
filekey = f"{ctx.path}/{filename}"
ctx.executor.all[filekey] = None
needed = True
return needed


@entrypoint(description="Retrieves a set of SDXL submodels.")
def sdxl(
model_json=cl_arg(
"model_json",
default=default_config_json,
help="Local config filepath",
),
target=cl_arg(
"target",
default="gfx942",
help="IREE target architecture.",
),
splat=cl_arg(
"splat", default=False, type=str, help="Download empty weights (for testing)"
),
):
model_params = ModelParams.load_json(model_json)
ctx = executor.BuildContext.current()
update = needs_update(ctx)

mlir_bucket = SDXL_BUCKET + "mlir/"
vmfb_bucket = SDXL_BUCKET + "vmfbs/"

mlir_filenames = get_mlir_filenames(model_params)
mlir_urls = get_url_map(mlir_filenames, mlir_bucket)
for f, url in mlir_urls.items():
if update or needs_file(f, ctx):
fetch_http(name=f, url=url)

vmfb_filenames = get_vmfb_filenames(model_params, target=target)
vmfb_urls = get_url_map(vmfb_filenames, vmfb_bucket)
for f, url in vmfb_urls.items():
if update or needs_file(f, ctx):
fetch_http(name=f, url=url)
params_filenames = get_params_filenames(model_params, splat)
params_urls = get_url_map(params_filenames, SDXL_WEIGHTS_BUCKET)
for f, url in params_urls.items():
out_file = os.path.join(ctx.executor.output_dir, f)
if update or needs_file(f, ctx):
fetch_http(name=f, url=url)

filenames = [*vmfb_filenames, *params_filenames, *mlir_filenames]
return filenames


if __name__ == "__main__":
iree_build_main()
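
To make the filename scheme concrete, here is a small illustrative sketch (not part of the commit) of how `get_file_stems` composes ordered name components with `itertools.product`, assuming an SDXL i8-punet config with batch size 1, max_seq_len 64, and 1024x1024 output:

```
# Minimal sketch of the stem construction in get_file_stems, assuming an
# SDXL i8-punet config with batch size 1, max_seq_len 64, and 1024x1024.
import itertools

ord_params = [
    ["stable_diffusion_xl_base_1_0"],  # base model
    ["punet"],                         # submodel name
    ["bs1"],                           # unet batch sizes
    ["64"],                            # max sequence length (unet/clip only)
    ["1024x1024"],                     # output dims (unet/vae/scheduler only)
    ["i8"],                            # dtype tag
]
stems = ["_".join(parts) for parts in itertools.product(*ord_params)]
print(stems)
# ['stable_diffusion_xl_base_1_0_punet_bs1_64_1024x1024_i8']
# get_vmfb_filenames then appends "_<target>.vmfb", e.g.
# stable_diffusion_xl_base_1_0_punet_bs1_64_1024x1024_i8_gfx942.vmfb
```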
10 changes: 9 additions & 1 deletion shortfin/python/shortfin_apps/sd/components/config_struct.py
@@ -43,9 +43,15 @@ class ModelParams:
# Same for VAE.
vae_batch_sizes: list[int]

# Same for scheduler.
scheduler_batch_sizes: list[int]

# Height and Width, respectively, for which Unet and VAE are compiled. e.g. [[512, 512], [1024, 1024]]
dims: list[list[int]]

# Scheduler id.
scheduler_id: str = "EulerDiscrete"

base_model_name: str = "SDXL"
# Name of the IREE module for each submodel.
clip_module_name: str = "compiled_clip"
@@ -59,11 +65,13 @@
# Classifier free guidance mode. If set to false, only positive prompts will matter.
cfg_mode = True

# DTypes (basically defaults):
# DTypes (not necessarily weights precision):
clip_dtype: sfnp.DType = sfnp.float16
unet_dtype: sfnp.DType = sfnp.float16
vae_dtype: sfnp.DType = sfnp.float16

use_i8_punet: bool = False

# ABI of the module.
module_abi_version: int = 1

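A short usage sketch (assuming you run from the shortfin/ directory) of loading the example config and reading the fields added here, `scheduler_batch_sizes`, `scheduler_id`, and `use_i8_punet`:

```
# Sketch only: loads the example i8 config and inspects the new fields.
from shortfin_apps.sd.components.config_struct import ModelParams

params = ModelParams.load_json(
    "python/shortfin_apps/sd/examples/sdxl_config_i8.json"
)
print(params.scheduler_batch_sizes)  # [1] in the example config
print(params.scheduler_id)           # "EulerDiscrete" unless overridden
print(params.use_i8_punet)           # True for the i8 config
```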
9 changes: 7 additions & 2 deletions shortfin/python/shortfin_apps/sd/components/service.py
@@ -48,6 +48,7 @@ def __init__(
tokenizers: list[Tokenizer],
model_params: ModelParams,
fibers_per_device: int,
workers_per_device: int = 1,
prog_isolation: str = "per_fiber",
show_progress: bool = False,
trace_execution: bool = False,
@@ -64,16 +65,20 @@
self.inference_programs: dict[str, sf.Program] = {}
self.trace_execution = trace_execution
self.show_progress = show_progress
self.workers_per_device = workers_per_device
self.fibers_per_device = fibers_per_device
self.prog_isolation = prog_isolations[prog_isolation]
self.workers = []
self.fibers = []
self.fiber_status = []
for idx, device in enumerate(self.sysman.ls.devices):
for i in range(self.fibers_per_device):
for i in range(self.workers_per_device):
worker = sysman.ls.create_worker(f"{name}-inference-{device.name}-{i}")
fiber = sysman.ls.create_fiber(worker, devices=[device])
self.workers.append(worker)
for i in range(self.fibers_per_device):
fiber = sysman.ls.create_fiber(
self.workers[i % len(self.workers)], devices=[device]
)
self.fibers.append(fiber)
self.fiber_status.append(0)

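For clarity, a standalone illustration (not the service code itself) of the new worker/fiber split above: workers are created per device, and fibers are assigned to them round-robin, so `fibers_per_device` can exceed `workers_per_device`:

```
# Illustrative only: round-robin assignment of fibers to workers.
workers_per_device = 2
fibers_per_device = 4

workers = [f"inference-worker-{i}" for i in range(workers_per_device)]
fibers = [
    (f"fiber-{i}", workers[i % len(workers)]) for i in range(fibers_per_device)
]
print(fibers)
# [('fiber-0', 'inference-worker-0'), ('fiber-1', 'inference-worker-1'),
#  ('fiber-2', 'inference-worker-0'), ('fiber-3', 'inference-worker-1')]
```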
@@ -10,6 +10,9 @@
"vae_batch_sizes": [
1
],
"scheduler_batch_sizes": [
1
],
"unet_module_name": "compiled_unet",
"unet_fn_name": "run_forward",
"dims": [
4 changes: 4 additions & 0 deletions shortfin/python/shortfin_apps/sd/examples/sdxl_config_i8.json
@@ -10,9 +10,13 @@
"vae_batch_sizes": [
1
],
"scheduler_batch_sizes": [
1
],
"unet_dtype": "float16",
"unet_module_name": "compiled_punet",
"unet_fn_name": "main",
"use_i8_punet": true,
"dims": [
[
1024,
