(shortfin-sd) Adds iree.build artifact fetching. #411
Conversation
clip_dtype: sfnp.DType = sfnp.float16
unet_dtype: sfnp.DType = sfnp.float16
vae_dtype: sfnp.DType = sfnp.float16

use_i8_punet: bool = False
I really don't like this, but the module expects fp16 I/O...
Maybe a specific place for I/O dtype / params type is in order, but it's quite a distinction to start making over one inconsistency. One (`*_dtype`) is used for instructing device array creation, and the other (`use_i8_punet`) is used when inferring artifact names. Perhaps the filename convention should account for these cases, i.e., keep the precision spec for I/O and add a `_pi8_` to denote "int8 params" or whatever fnuz datatype we need to parametrize for.
Drive-by comment: I agree with the above. We have multiple `punet` models to support, like `int8` and `fp8`, so it would be better to keep them separate.
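A minimal sketch of the naming idea above, with hypothetical names (`PunetSpec`, `punet_artifact_name`, and the `.irpa` suffix are illustrative, not the actual config or filenames): the I/O dtype and the params precision become separate fields, and both feed the artifact name.

```python
from dataclasses import dataclass


@dataclass
class PunetSpec:
    io_dtype_tag: str = "fp16"  # precision used for device array I/O (clip/unet/vae)
    params_tag: str = "fp16"    # params precision used only for artifact naming: "fp16", "i8", "fp8", ...


def punet_artifact_name(base: str, spec: PunetSpec) -> str:
    # Keep the precision spec for I/O and add a `_p<params>_` tag for the params,
    # e.g. fp16 I/O with int8 params -> "..._fp16_pi8_...".
    return f"{base}_{spec.io_dtype_tag}_p{spec.params_tag}_.irpa"


# Example (hypothetical base name):
# punet_artifact_name("punet", PunetSpec(params_tag="i8")) -> "punet_fp16_pi8_.irpa"
```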
parent = os.path.dirname(this_dir)
default_config_json = os.path.join(parent, "examples", "sdxl_config_i8.json")

dtype_to_filetag = {
Maybe we rename the artifacts to match sfnp.DType attributes instead of doing little workarounds like this for old naming conventions. Once the exports are spinning and publishing regularly, we can make changes with control.
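For reference, a sketch of the kind of table being discussed; the keys and entries here are illustrative (keyed by DType attribute name), not the actual `dtype_to_filetag` contents:

```python
# The table only exists to bridge legacy artifact tags ("fp16") and the
# sfnp.DType attribute names ("float16"). Renaming the published artifacts
# to the DType names would let it disappear entirely.
dtype_to_filetag = {
    "float16": "fp16",
    "float32": "fp32",
}
```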
NOTE: needs rebase on main once #413 lands
Force-pushed df9a0fd to f073eae
params_urls = get_url_map(params_filenames, SDXL_WEIGHTS_BUCKET)
ctx = executor.BuildContext.current()
for f, url in params_urls.items():
    out_file = os.path.join(ctx.executor.output_dir, f)
This isn't very robust. It should fetch an md5sum checklist from the bucket (if downloads are enabled) and compare it with a local checklist to determine which artifacts, if any, need updating.
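A rough sketch of that checklist comparison, assuming a hypothetical `md5sums.json` in the bucket and a locally stored manifest; none of this is existing iree.build machinery:

```python
import json
import os
import urllib.request


def stale_artifacts(bucket_url: str, local_manifest_path: str) -> set[str]:
    # Remote checklist: {"filename": "md5hex", ...} published alongside the artifacts.
    with urllib.request.urlopen(f"{bucket_url}/md5sums.json") as resp:
        remote = json.load(resp)
    # Local checklist recorded after previous successful fetches.
    if os.path.exists(local_manifest_path):
        with open(local_manifest_path) as f:
            local = json.load(f)
    else:
        local = {}
    # Anything missing locally or with a mismatched hash needs re-fetching.
    return {name for name, md5 in remote.items() if local.get(name) != md5}
```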
Yeah, I hadn't yet gotten to stamp and change detection... Will in a bit.
Do you already have file hashes stored in the bucket somewhere?
Not yet. Right now, the mlir/vmfbs are always downloaded from a bucket versioned by date only.
That can work. You basically need something to derive a stamp value from. That can come from some part of the URL.
For mammoth files, a manual version of some kind can be best anyway: it can take a long time to compute a hash of such things.
Does it seem too heavyweight to keep an md5sums.json in each bucket and have the builder generate and keep a local set of hashes for its outputs? That way we can filter exactly what's needed before doing fetch_http. (edit: I'm pretty sure that's the same thing, just more fine-grained and expensive, I suppose -- I just never liked having to download a new set of HF weights because someone added a completely unrelated file to the repo.)
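A companion sketch of the local-hash half of that idea (file names and layout are hypothetical); this is the "more fine-grained and expensive" variant, since it streams and hashes every output:

```python
import hashlib
import json
import os


def write_local_manifest(output_dir: str, manifest_name: str = "md5sums.local.json") -> None:
    # Record an md5 for each builder output, hashed in chunks so large
    # parameter files are never read into memory at once.
    manifest = {}
    for name in sorted(os.listdir(output_dir)):
        path = os.path.join(output_dir, name)
        if not os.path.isfile(path) or name == manifest_name:
            continue
        h = hashlib.md5()
        with open(path, "rb") as f:
            for chunk in iter(lambda: f.read(1 << 20), b""):
                h.update(chunk)
        manifest[name] = h.hexdigest()
    with open(os.path.join(output_dir, manifest_name), "w") as f:
        json.dump(manifest, f, indent=2)
```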
Yeah, that's the basic mechanism. We wouldn't actually compute the hash in the builder in typical use. Instead, you would tell it how to get the stamp artifact (i.e., some fixed string, a hash file, etc.). If it's a hash file, we compute a running hash only during download and store the result, erroring if it mismatches. But just an opaque stamp value drives the up-to-date check.
It's better for everyone if such artifacts are in write-once storage (i.e., the same URL produces the same content for all of time). Then the stamp is just the URL, and any hash checking is just for verifying the integrity of the transfer. That avoids several kinds of update race issues, and it means you can do the up-to-date check without network access.
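A sketch of the URL-as-stamp flow described above, using only the standard library (the helper name and `.stamp` file convention are made up, not iree.build's actual mechanism): the up-to-date check reads a stamp file next to the output and needs no network access, while hashing happens only while the bytes stream down, purely to verify transfer integrity.

```python
import hashlib
import os
import urllib.request


def fetch_if_stale(url: str, out_file: str, expected_md5: str | None = None) -> None:
    stamp_file = out_file + ".stamp"
    # Up-to-date check: for write-once storage, the URL itself is the stamp.
    if os.path.exists(out_file) and os.path.exists(stamp_file):
        with open(stamp_file) as f:
            if f.read().strip() == url:
                return  # Same URL, same content: nothing to do.
    h = hashlib.md5()
    with urllib.request.urlopen(url) as resp, open(out_file, "wb") as out:
        for chunk in iter(lambda: resp.read(1 << 20), b""):
            h.update(chunk)  # Running hash computed only during download.
            out.write(chunk)
    if expected_md5 is not None and h.hexdigest() != expected_md5:
        raise ValueError(f"md5 mismatch for {url}")
    with open(stamp_file, "w") as f:
        f.write(url)
```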
Also adds two slow tests covering larger SDXL server loads; these will not trigger in any workflows yet.
This is missing a few things:
- Builder isn't very smart yet, and will just blindly try to download composed filenames from sharkpublic buckets.

It should eventually cover: