Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

(shortfin-sd) Adds iree.build artifact fetching. #411

Merged
merged 11 commits into from
Nov 5, 2024
11 changes: 11 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,17 @@
*.suo
*.user

# Common IREE source/build paths
iree/
iree-build/

# macOS files
.DS_Store

# CMake artifacts
build/
build-*/
_build/

# Python
__pycache__
Expand All @@ -30,6 +35,12 @@ wheelhouse
*.safetensors
*.gguf
*.vmfb
genfiles/
*.zip
tmp/

# Known inference result blobs
*output*.png

# Log files.
*.log
33 changes: 16 additions & 17 deletions shortfin/python/shortfin_apps/sd/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,34 +16,33 @@ pip install pillow
python -m shortfin_apps.sd.server --help
```

## Run tests

- From SHARK-Platform/shortfin:
```
pytest --system=amdgpu -k "sd"
```
The tests run with splat weights.


## Run on MI300x

- Follow quick start

- Download runtime artifacts (vmfbs, weights):

- Navigate to shortfin/ (only necessary if you're using following CLI exactly.)
```
mkdir vmfbs
mkdir weights
wget https://sharkpublic.blob.core.windows.net/sharkpublic/sdxl/vmfbs/stable_diffusion_xl_base_1_0_bs1_64_1024x1024_i8_punet_gfx942.vmfb -O vmfbs/stable_diffusion_xl_base_1_0_bs1_64_1024x1024_i8_punet_gfx942.vmfb
wget https://sharkpublic.blob.core.windows.net/sharkpublic/sdxl/vmfbs/stable_diffusion_xl_base_1_0_bs1_64_fp16_text_encoder_gfx942.vmfb -O vmfbs/stable_diffusion_xl_base_1_0_bs1_64_fp16_text_encoder_gfx942.vmfb
wget https://sharkpublic.blob.core.windows.net/sharkpublic/sdxl/vmfbs/stable_diffusion_xl_base_1_0_EulerDiscreteScheduler_bs1_1024x1024_fp16_gfx942.vmfb -O vmfbs/stable_diffusion_xl_base_1_0_EulerDiscreteScheduler_bs1_1024x1024_fp16_gfx942.vmfb
wget https://sharkpublic.blob.core.windows.net/sharkpublic/sdxl/vmfbs/stable_diffusion_xl_base_1_0_bs1_1024x1024_fp16_vae_gfx942.vmfb -O vmfbs/stable_diffusion_xl_base_1_0_bs1_1024x1024_fp16_vae_gfx942.vmfb

# You can download real weights with:
wget https://sharkpublic.blob.core.windows.net/sharkpublic/sdxl/weights/sfsd_weights_1023.zip
# Splat weights:
wget https://sharkpublic.blob.core.windows.net/sharkpublic/sdxl/weights/sfsd_splat_1023.zip
cd shortfin/
```
- Unzip the downloaded weights archive to /weights
- Run CLI server interface (you can find `sdxl_config_i8.json` in shortfin_apps/sd/examples):

The server will prepare runtime artifacts for you.

```
python -m shortfin_apps.sd.server --model_config=./sdxl_config_i8.json --clip_vmfb=./vmfbs/stable_diffusion_xl_base_1_0_bs1_64_fp16_text_encoder_gfx942.vmfb --unet_vmfb=./vmfbs/stable_diffusion_xl_base_1_0_bs1_64_1024x1024_i8_punet_gfx942.vmfb --scheduler_vmfb=./vmfbs/stable_diffusion_xl_base_1_0_EulerDiscreteScheduler_bs1_1024x1024_fp16_gfx942.vmfb --vae_vmfb=./vmfbs/stable_diffusion_xl_base_1_0_bs1_1024x1024_fp16_vae_gfx942.vmfb --clip_params=./weights/stable_diffusion_xl_base_1_0_text_encoder_fp16.safetensors --unet_params=./weights/stable_diffusion_xl_base_1_0_punet_dataset_i8.irpa --vae_params=./weights/stable_diffusion_xl_base_1_0_vae_fp16.safetensors --device=amdgpu --device_ids=0
python -m shortfin_apps.sd.server --model_config=./python/shortfin_apps/sd/examples/sdxl_config_i8.json --device=amdgpu --device_ids=0
```
with splat:
- Run with splat (empty) weights:
```
python -m shortfin_apps.sd.server --model_config=./sdxl_config_i8.json --clip_vmfb=./vmfbs/stable_diffusion_xl_base_1_0_bs1_64_fp16_text_encoder_gfx942.vmfb --unet_vmfb=./vmfbs/stable_diffusion_xl_base_1_0_bs1_64_1024x1024_i8_punet_gfx942.vmfb --scheduler_vmfb=./vmfbs/stable_diffusion_xl_base_1_0_EulerDiscreteScheduler_bs1_1024x1024_fp16_gfx942.vmfb --vae_vmfb=./vmfbs/stable_diffusion_xl_base_1_0_bs1_1024x1024_fp16_vae_gfx942.vmfb --clip_params=./weights/clip_splat.irpa --unet_params=./weights/punet_splat_18.irpa --vae_params=./weights/vae_splat.irpa --device=amdgpu --device_ids=0
python -m shortfin_apps.sd.server --model_config=./python/shortfin_apps/sd/examples/sdxl_config_i8.json --device=amdgpu --device_ids=0 --splat
```
- Run a request in a separate shell:
```
Expand Down
202 changes: 202 additions & 0 deletions shortfin/python/shortfin_apps/sd/components/builders.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,202 @@
from iree.build import *
import itertools
import os
import shortfin.array as sfnp

from shortfin_apps.sd.components.config_struct import ModelParams

# Paths used to locate the default model config that ships with the package.
this_dir = os.path.dirname(os.path.abspath(__file__))
parent = os.path.dirname(this_dir)
default_config_json = os.path.join(parent, "examples", "sdxl_config_i8.json")

# Maps shortfin array dtypes to the short precision tags used in published
# artifact filenames (e.g. "..._fp16_...", "..._i8_...").
dtype_to_filetag = {
    sfnp.float16: "fp16",
    sfnp.float32: "fp32",
    sfnp.int8: "i8",
    sfnp.bfloat16: "bf16",
}

# Date-stamped version of the published artifact set. Bumping this value
# invalidates local stamps (see needs_update) and forces a re-fetch.
ARTIFACT_VERSION = "11022024"
SDXL_BUCKET = (
    f"https://sharkpublic.blob.core.windows.net/sharkpublic/sdxl/{ARTIFACT_VERSION}/"
)
SDXL_WEIGHTS_BUCKET = (
    "https://sharkpublic.blob.core.windows.net/sharkpublic/sdxl/weights/"
)


def get_mlir_filenames(model_params: ModelParams):
    """Return the expected ``.mlir`` artifact filenames for *model_params*.

    One filename per submodel stem produced by :func:`get_file_stems`.
    """
    return [f"{stem}.mlir" for stem in get_file_stems(model_params)]


def get_vmfb_filenames(model_params: ModelParams, target: str = "gfx942"):
    """Return the expected ``.vmfb`` artifact filenames for *model_params*.

    Each submodel stem from :func:`get_file_stems` is suffixed with the IREE
    *target* architecture (e.g. ``..._gfx942.vmfb``).
    """
    return [f"{stem}_{target}.vmfb" for stem in get_file_stems(model_params)]


def get_params_filenames(model_params: ModelParams, splat: bool):
    """Return the expected ``.irpa`` parameter filenames for *model_params*.

    Args:
        model_params: Parsed model configuration.
        splat: Whether to use splat (empty, test-only) weights. Accepts either
            a real bool or the CLI string form ("True"/"true"), since the
            ``splat`` cl_arg is declared with ``type=str``.
    """
    params_filenames = []
    base = (
        "stable_diffusion_xl_base_1_0"
        if model_params.base_model_name.lower() == "sdxl"
        else model_params.base_model_name
    )
    modnames = ["clip", "vae"]
    # NOTE(review): the "vae" entry uses unet_dtype here, not vae_dtype —
    # presumably matching how the published artifacts are named; confirm.
    mod_precs = [
        dtype_to_filetag[model_params.clip_dtype],
        dtype_to_filetag[model_params.unet_dtype],
    ]
    if model_params.use_i8_punet:
        modnames.append("punet")
        mod_precs.append("i8")
    else:
        modnames.append("unet")
        mod_precs.append(dtype_to_filetag[model_params.unet_dtype])
    # Original only matched the exact CLI string "True"; also honor a real
    # bool and the lowercase string so programmatic callers behave correctly.
    if splat in (True, "True", "true"):
        for idx, mod in enumerate(modnames):
            params_filenames.extend(
                ["_".join([mod, "splat", f"{mod_precs[idx]}.irpa"])]
            )
    else:
        for idx, mod in enumerate(modnames):
            params_filenames.extend(
                [base + "_" + mod + "_dataset_" + mod_precs[idx] + ".irpa"]
            )
    return params_filenames


def get_file_stems(model_params: ModelParams):
    """Build the artifact filename stems for each submodel.

    A stem is the underscore-joined cartesian product of ordered name parts:
    base model name, submodel name, batch size(s), (for unet/clip) max
    sequence length, (for unet/vae/scheduler) HxW dims, and a dtype tag.
    One stem is produced per combination (e.g. per batch size x dim pair).
    """
    file_stems = []
    base = (
        ["stable_diffusion_xl_base_1_0"]
        if model_params.base_model_name.lower() == "sdxl"
        else [model_params.base_model_name]
    )
    # Maps the config-level submodel key to the name used in artifact files.
    mod_names = {
        "clip": "clip",
        "unet": "punet" if model_params.use_i8_punet else "unet",
        "scheduler": model_params.scheduler_id + "Scheduler",
        "vae": "vae",
    }
    for mod, modname in mod_names.items():
        # ord_params collects the ordered name-part alternatives; the product
        # of these lists yields one stem per combination.
        ord_params = [
            base,
            [modname],
        ]
        bsizes = []
        for bs in getattr(model_params, f"{mod}_batch_sizes", [1]):
            bsizes.extend([f"bs{bs}"])
        ord_params.extend([bsizes])
        if mod in ["unet", "clip"]:
            ord_params.extend([[str(model_params.max_seq_len)]])
        if mod in ["unet", "vae", "scheduler"]:
            # Render each [H, W] dim pair as "HxW".
            dims = []
            for dim_pair in model_params.dims:
                dim_pair_str = [str(d) for d in dim_pair]
                dims.extend(["x".join(dim_pair_str)])
            ord_params.extend([dims])
        # Dtype tag: the scheduler follows the unet dtype; the i8 punet is
        # tagged "i8" regardless of unet_dtype; others use their own dtype
        # (defaulting to fp16 when the config omits it).
        if mod == "scheduler":
            dtype_str = dtype_to_filetag[model_params.unet_dtype]
        elif mod != "unet":
            dtype_str = dtype_to_filetag[
                getattr(model_params, f"{mod}_dtype", sfnp.float16)
            ]
        else:
            dtype_str = (
                "i8"
                if model_params.use_i8_punet
                else dtype_to_filetag[model_params.unet_dtype]
            )
        ord_params.extend([[dtype_str]])
        for x in list(itertools.product(*ord_params)):
            file_stems.extend(["_".join(x)])
    return file_stems


def get_url_map(filenames: list[str], bucket: str):
    """Map each artifact filename to its full download URL.

    Args:
        filenames: Artifact basenames to fetch.
        bucket: Base bucket URL; expected to end with a trailing slash.

    Returns:
        dict mapping filename -> ``bucket + filename``.
    """
    # The scraped source had a corrupted f-string here; the URL is simply the
    # bucket prefix followed by the artifact filename.
    return {filename: f"{bucket}{filename}" for filename in filenames}


def needs_update(ctx):
    """Return True when the local artifact set is missing or out of date.

    Compares the contents of a ``version.txt`` stamp file against
    ``ARTIFACT_VERSION``. Side effect: writes the current version to the
    stamp so subsequent runs see the tree as up to date. (The original only
    wrote the stamp when it was missing, so a stale stamp was never
    refreshed and every later run re-fetched everything.)
    """
    stamp = ctx.allocate_file("version.txt")
    stamp_path = stamp.get_fs_path()
    if os.path.exists(stamp_path):
        with open(stamp_path, "r") as s:
            ver = s.read()
        if ver != ARTIFACT_VERSION:
            # Stale: refresh the stamp and signal that a fetch is needed.
            with open(stamp_path, "w") as s:
                s.write(ARTIFACT_VERSION)
            return True
        return False
    # No stamp yet: record the current version and fetch everything.
    with open(stamp_path, "w") as s:
        s.write(ARTIFACT_VERSION)
    return True


def needs_file(filename, ctx):
    """Return True when *filename* is absent from the build output directory.

    When the file is missing, its dependency key is also cleared from the
    executor's action map so the fetch is re-run.
    # NOTE(review): clearing ctx.executor.all[filekey] presumably invalidates
    # a cached action entry — confirm against iree.build executor internals.
    """
    out_file = ctx.allocate_file(filename).get_fs_path()
    if os.path.exists(out_file):
        return False
    # The scraped source had a corrupted f-string here; the key is the
    # context path joined with the artifact filename.
    filekey = f"{ctx.path}/{filename}"
    ctx.executor.all[filekey] = None
    return True


@entrypoint(description="Retrieves a set of SDXL submodels.")
def sdxl(
    model_json=cl_arg(
        "model_json",
        default=default_config_json,
        help="Local config filepath",
    ),
    target=cl_arg(
        "target",
        default="gfx942",
        help="IREE target architecture.",
    ),
    splat=cl_arg(
        "splat", default=False, type=str, help="Download empty weights (for testing)"
    ),
    # NOTE(review): splat defaults to bool False but is parsed with type=str,
    # so CLI use yields "True"/"False" strings — get_params_filenames accepts
    # both forms; consider normalizing to a real bool here.
):
    """Fetch the SDXL MLIR, vmfb, and parameter artifacts described by the
    model config, skipping files that are already present and up to date.

    Returns the list of artifact filenames that make up the model set.
    """
    model_params = ModelParams.load_json(model_json)
    ctx = executor.BuildContext.current()
    # A version-stamp mismatch forces every artifact to be re-fetched.
    update = needs_update(ctx)

    mlir_bucket = SDXL_BUCKET + "mlir/"
    vmfb_bucket = SDXL_BUCKET + "vmfbs/"

    mlir_filenames = get_mlir_filenames(model_params)
    mlir_urls = get_url_map(mlir_filenames, mlir_bucket)
    for f, url in mlir_urls.items():
        if update or needs_file(f, ctx):
            fetch_http(name=f, url=url)

    vmfb_filenames = get_vmfb_filenames(model_params, target=target)
    vmfb_urls = get_url_map(vmfb_filenames, vmfb_bucket)
    for f, url in vmfb_urls.items():
        if update or needs_file(f, ctx):
            fetch_http(name=f, url=url)

    params_filenames = get_params_filenames(model_params, splat)
    params_urls = get_url_map(params_filenames, SDXL_WEIGHTS_BUCKET)
    for f, url in params_urls.items():
        if update or needs_file(f, ctx):
            fetch_http(name=f, url=url)

    filenames = [*vmfb_filenames, *params_filenames, *mlir_filenames]
    return filenames


if __name__ == "__main__":
iree_build_main()
10 changes: 9 additions & 1 deletion shortfin/python/shortfin_apps/sd/components/config_struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,15 @@ class ModelParams:
# Same for VAE.
vae_batch_sizes: list[int]

# Same for scheduler.
scheduler_batch_sizes: list[int]

# Height and Width, respectively, for which Unet and VAE are compiled. e.g. [[512, 512], [1024, 1024]]
dims: list[list[int]]

# Scheduler id.
scheduler_id: str = "EulerDiscrete"

base_model_name: str = "SDXL"
# Name of the IREE module for each submodel.
clip_module_name: str = "compiled_clip"
Expand All @@ -59,11 +65,13 @@ class ModelParams:
# Classifer free guidance mode. If set to false, only positive prompts will matter.
cfg_mode = True

# DTypes (basically defaults):
# DTypes (not necessarily weights precision):
clip_dtype: sfnp.DType = sfnp.float16
unet_dtype: sfnp.DType = sfnp.float16
vae_dtype: sfnp.DType = sfnp.float16

use_i8_punet: bool = False
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

really don't like this but the module expects fp16 I/O...

Copy link
Contributor Author

@monorimet monorimet Nov 1, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe a specific place for IO dtype / params type is in order, but it's quite a distinction to start making over one inconsistency. One (*_dtype) is used for instructing device array creation, and the other (use_i8_punet) is used when inferring artifact names. Perhaps the filename convention should account for these cases, i.e., keep the precision spec for I/O and add a _pi8_ to denote "int8 params" or whatever fnuz924v83 datatype we need to parametrize for.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Drop by comment, I agree with the above. We have multiple punet models to support like int8 and fp8, so it would be better to keep them separate


# ABI of the module.
module_abi_version: int = 1

Expand Down
9 changes: 7 additions & 2 deletions shortfin/python/shortfin_apps/sd/components/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def __init__(
tokenizers: list[Tokenizer],
model_params: ModelParams,
fibers_per_device: int,
workers_per_device: int = 1,
prog_isolation: str = "per_fiber",
show_progress: bool = False,
trace_execution: bool = False,
Expand All @@ -64,16 +65,20 @@ def __init__(
self.inference_programs: dict[str, sf.Program] = {}
self.trace_execution = trace_execution
self.show_progress = show_progress
self.workers_per_device = workers_per_device
self.fibers_per_device = fibers_per_device
self.prog_isolation = prog_isolations[prog_isolation]
self.workers = []
self.fibers = []
self.fiber_status = []
for idx, device in enumerate(self.sysman.ls.devices):
for i in range(self.fibers_per_device):
for i in range(self.workers_per_device):
worker = sysman.ls.create_worker(f"{name}-inference-{device.name}-{i}")
fiber = sysman.ls.create_fiber(worker, devices=[device])
self.workers.append(worker)
for i in range(self.fibers_per_device):
fiber = sysman.ls.create_fiber(
self.workers[i % len(self.workers)], devices=[device]
)
self.fibers.append(fiber)
self.fiber_status.append(0)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@
"vae_batch_sizes": [
1
],
"scheduler_batch_sizes": [
1
],
"unet_module_name": "compiled_unet",
"unet_fn_name": "run_forward",
"dims": [
Expand Down
4 changes: 4 additions & 0 deletions shortfin/python/shortfin_apps/sd/examples/sdxl_config_i8.json
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,13 @@
"vae_batch_sizes": [
1
],
"scheduler_batch_sizes": [
1
],
"unet_dtype": "float16",
"unet_module_name": "compiled_punet",
"unet_fn_name": "main",
"use_i8_punet": true,
"dims": [
[
1024,
Expand Down
Loading
Loading