Skip to content

Commit c891f43

Browse files
ZachNagengastatiorharda-argmax
authored
Add SD3 Pipeline (#329)
* Add SD3 Pipeline Co-authored-by: atiorh <atiorh@users.noreply.github.com> Co-authored-by: arda-argmax <arda-argmax@users.noreply.github.com> * Use swift-transformers for tokenization * Use diffusionkit converters in torch2coreml * Documentation and cleanup * Add model link * Consolidate batch prediction logic * Remove DecoderSD3.swift and consolidate logic into Decoder.swift * Remove DiffusionKit MLX inference reference from README --------- Co-authored-by: atiorh <atiorh@users.noreply.github.com> Co-authored-by: arda-argmax <arda-argmax@users.noreply.github.com> Co-authored-by: atila <atiorh@icloud.com>
1 parent 5a170d2 commit c891f43

17 files changed

+1326
-59
lines changed

Package.swift

+7-4
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@ import PackageDescription
66
let package = Package(
77
name: "stable-diffusion",
88
platforms: [
9-
.macOS(.v11),
10-
.iOS(.v14),
9+
.macOS(.v13),
10+
.iOS(.v16),
1111
],
1212
products: [
1313
.library(
@@ -18,12 +18,15 @@ let package = Package(
1818
targets: ["StableDiffusionCLI"])
1919
],
2020
dependencies: [
21-
.package(url: "https://github.com/apple/swift-argument-parser.git", from: "1.2.3")
21+
.package(url: "https://github.com/apple/swift-argument-parser.git", from: "1.2.3"),
22+
.package(url: "https://github.com/huggingface/swift-transformers.git", exact: "0.1.8"),
2223
],
2324
targets: [
2425
.target(
2526
name: "StableDiffusion",
26-
dependencies: [],
27+
dependencies: [
28+
.product(name: "Transformers", package: "swift-transformers"),
29+
],
2730
path: "swift/StableDiffusion"),
2831
.executableTarget(
2932
name: "StableDiffusionCLI",

README.md

+47
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,52 @@ An example `<selected-recipe-string-key>` would be `"recipe_4.50_bit_mixedpalett
246246

247247
</details>
248248

249+
250+
## <a name="using-stable-diffusion-3"></a> Using Stable Diffusion 3
251+
252+
<details>
253+
<summary> Details (Click to expand) </summary>
254+
255+
### Model Conversion
256+
257+
Stable Diffusion 3 uses some new and some old models to run. For the text encoders, the conversion can be done using a similar command as before with the `--sd3-version` flag.
258+
259+
```bash
260+
python -m python_coreml_stable_diffusion.torch2coreml --model-version stabilityai/stable-diffusion-3-medium --bundle-resources-for-swift-cli --convert-text-encoder --sd3-version -o <output-dir>
261+
```
262+
263+
For the new models (MMDiT, a new VAE with 16 channels, and the T5 text encoder), there are a number of new CLI flags that utilize the [DiffusionKit](https://www.github.com/argmaxinc/DiffusionKit) repo:
264+
265+
- `--sd3-version`: Indicates to the converter to treat this as a Stable Diffusion 3 model
266+
- `--convert-mmdit`: Convert the MMDiT model
267+
- `--convert-vae-decoder`: Convert the new VAE model (this will use the 16 channel version if --sd3-version is set)
268+
- `--include-t5`: Downloads and includes a pre-converted T5 text encoder in the conversion
269+
270+
e.g.:
271+
```bash
272+
python -m python_coreml_stable_diffusion.torch2coreml --model-version stabilityai/stable-diffusion-3-medium --bundle-resources-for-swift-cli --convert-vae-decoder --convert-mmdit --include-t5 --sd3-version -o <output-dir>
273+
```
274+
275+
To convert the full pipeline with at 1024x1024 resolution, the following command may be used:
276+
277+
```bash
278+
python -m python_coreml_stable_diffusion.torch2coreml --model-version stabilityai/stable-diffusion-3-medium --bundle-resources-for-swift-cli --convert-text-encoder --convert-vae-decoder --convert-mmdit --include-t5 --sd3-version --latent-h 128 --latent-w 128 -o <output-dir>
279+
```
280+
281+
Keep in mind that the MMDiT model is quite large and will require increasingly more memory and time to convert as the latent resolution increases.
282+
283+
Also note that currently the MMDiT model requires fp32 and therefore only supports `CPU_AND_GPU` compute units and `ORIGINAL` attention implementation (the default for this pipeline).
284+
285+
### Swift Inference
286+
287+
Swift inference for Stable Diffusion 3 is similar to the previous versions. The only difference is that the `--sd3` flag should be used to indicate that the model is a Stable Diffusion 3 model.
288+
289+
```bash
290+
swift run StableDiffusionSample <prompt> --resource-path <output-mlpackages-directory/Resources> --output-path <output-dir> --compute-units cpuAndGPU --sd3
291+
```
292+
293+
</details>
294+
249295
## <a name="using-stable-diffusion-xl"></a> Using Stable Diffusion XL
250296

251297
<details>
@@ -356,6 +402,7 @@ Resources:
356402
- [`stabilityai/stable-diffusion-2-1-base`](https://huggingface.co/apple/coreml-stable-diffusion-2-1-base)
357403
- [`stabilityai/stable-diffusion-xl-base-1.0`](https://huggingface.co/apple/coreml-stable-diffusion-xl-base)
358404
- [`stabilityai/stable-diffusion-xl-{base+refiner}-1.0`](https://huggingface.co/apple/coreml-stable-diffusion-xl-base-with-refiner)
405+
- [`stabilityai/stable-diffusion-3-medium`](https://huggingface.co/stabilityai/stable-diffusion-3-medium)
359406

360407
If you want to use any of those models you may download the weights and proceed to [generate images with Python](#image-generation-with-python) or [Swift](#image-generation-with-swift).
361408

python_coreml_stable_diffusion/torch2coreml.py

+207-2
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,12 @@
1616
DiffusionPipeline,
1717
ControlNetModel
1818
)
19+
from diffusionkit.tests.torch2coreml import (
20+
convert_mmdit_to_mlpackage,
21+
convert_vae_to_mlpackage
22+
)
1923
import gc
24+
from huggingface_hub import snapshot_download
2025

2126
import logging
2227

@@ -207,6 +212,26 @@ def _compile_coreml_model(source_model_path, output_dir, final_name):
207212
return target_path
208213

209214

215+
def _download_t5_model(args, t5_save_path):
216+
t5_url = args.text_encoder_t5_url
217+
match = re.match(r'https://huggingface.co/(.+)/resolve/main/(.+)', t5_url)
218+
if not match:
219+
raise ValueError(f"Invalid Hugging Face URL: {t5_url}")
220+
repo_id, model_subpath = match.groups()
221+
222+
download_path = snapshot_download(
223+
repo_id=repo_id,
224+
revision="main",
225+
allow_patterns=[f"{model_subpath}/*"]
226+
)
227+
logger.info(f"Downloaded T5 model to {download_path}")
228+
229+
# Move the downloaded model to the top level of the Resources directory
230+
logger.info(f"Copying T5 model from {download_path} to {t5_save_path}")
231+
cache_path = os.path.join(download_path, model_subpath)
232+
shutil.copytree(cache_path, t5_save_path)
233+
234+
210235
def bundle_resources_for_swift_cli(args):
211236
"""
212237
- Compiles Core ML models from mlpackage into mlmodelc format
@@ -228,6 +253,7 @@ def bundle_resources_for_swift_cli(args):
228253
("refiner", "UnetRefiner"),
229254
("refiner_chunk1", "UnetRefinerChunk1"),
230255
("refiner_chunk2", "UnetRefinerChunk2"),
256+
("mmdit", "MultiModalDiffusionTransformer"),
231257
("control-unet", "ControlledUnet"),
232258
("control-unet_chunk1", "ControlledUnetChunk1"),
233259
("control-unet_chunk2", "ControlledUnetChunk2"),
@@ -241,7 +267,7 @@ def bundle_resources_for_swift_cli(args):
241267
logger.warning(
242268
f"{source_path} not found, skipping compilation to {target_name}.mlmodelc"
243269
)
244-
270+
245271
if args.convert_controlnet:
246272
for controlnet_model_version in args.convert_controlnet:
247273
controlnet_model_name = controlnet_model_version.replace("/", "_")
@@ -271,6 +297,25 @@ def bundle_resources_for_swift_cli(args):
271297
f.write(requests.get(args.text_encoder_merges_url).content)
272298
logger.info("Done")
273299

300+
# Fetch and save pre-converted T5 text encoder model
301+
t5_model_name = "TextEncoderT5.mlmodelc"
302+
t5_save_path = os.path.join(resources_dir, t5_model_name)
303+
if args.include_t5:
304+
if not os.path.exists(t5_save_path):
305+
logger.info("Downloading pre-converted T5 encoder model TextEncoderT5.mlmodelc")
306+
_download_t5_model(args, t5_save_path)
307+
logger.info("Done")
308+
else:
309+
logger.info(f"Skipping T5 download as {t5_save_path} already exists")
310+
311+
# Fetch and save T5 text tokenizer JSON files
312+
logger.info("Downloading and saving T5 tokenizer files tokenizer_config.json and tokenizer.json")
313+
with open(os.path.join(resources_dir, "tokenizer_config.json"), "wb") as f:
314+
f.write(requests.get(args.text_encoder_t5_config_url).content)
315+
with open(os.path.join(resources_dir, "tokenizer.json"), "wb") as f:
316+
f.write(requests.get(args.text_encoder_t5_data_url).content)
317+
logger.info("Done")
318+
274319
return resources_dir
275320

276321

@@ -557,6 +602,61 @@ def forward(self, z):
557602
del traced_vae_decoder, pipe.vae.decoder, coreml_vae_decoder
558603
gc.collect()
559604

605+
def convert_vae_decoder_sd3(args):
606+
""" Converts the VAE component of Stable Diffusion 3
607+
"""
608+
out_path = _get_out_path(args, "vae_decoder")
609+
if os.path.exists(out_path):
610+
logger.info(
611+
f"`vae_decoder` already exists at {out_path}, skipping conversion."
612+
)
613+
return
614+
615+
# Convert the VAE Decoder model via DiffusionKit
616+
converted_vae_path = convert_vae_to_mlpackage(
617+
model_version=args.model_version,
618+
latent_h=args.latent_h,
619+
latent_w=args.latent_w,
620+
output_dir=args.o,
621+
)
622+
623+
# Load converted model
624+
coreml_vae_decoder = ct.models.MLModel(converted_vae_path)
625+
626+
# Set model metadata
627+
coreml_vae_decoder.author = f"Please refer to the Model Card available at huggingface.co/{args.model_version}"
628+
coreml_vae_decoder.license = "Stability AI Community License (https://huggingface.co/stabilityai/stable-diffusion-3-medium/blob/main/LICENSE.md)"
629+
coreml_vae_decoder.version = args.model_version
630+
coreml_vae_decodershort_description = \
631+
"Stable Diffusion 3 generates images conditioned on text or other images as input through the diffusion process. " \
632+
"Please refer to https://arxiv.org/pdf/2403.03206 for details."
633+
634+
# Set the input descriptions
635+
coreml_vae_decoder.input_description["z"] = \
636+
"The denoised latent embeddings from the unet model after the last step of reverse diffusion"
637+
638+
# Set the output descriptions
639+
coreml_vae_decoder.output_description[
640+
"image"] = "Generated image normalized to range [-1, 1]"
641+
642+
# Set package version metadata
643+
from python_coreml_stable_diffusion._version import __version__
644+
coreml_vae_decoder.user_defined_metadata["com.github.apple.ml-stable-diffusion.version"] = __version__
645+
from diffusionkit.version import __version__
646+
coreml_vae_decoder.user_defined_metadata["com.github.argmax.diffusionkit.version"] = __version__
647+
648+
# Save the updated model
649+
coreml_vae_decoder.save(out_path)
650+
651+
logger.info(f"Saved vae_decoder into {out_path}")
652+
653+
# Delete the original file
654+
if os.path.exists(converted_vae_path):
655+
shutil.rmtree(converted_vae_path)
656+
657+
del coreml_vae_decoder
658+
gc.collect()
659+
560660

561661
def convert_vae_encoder(pipe, args):
562662
""" Converts the VAE Encoder component of Stable Diffusion
@@ -909,6 +1009,72 @@ def convert_unet(pipe, args, model_name = None):
9091009
chunk_mlprogram.main(args)
9101010

9111011

1012+
def convert_mmdit(args):
1013+
""" Converts the MMDiT component of Stable Diffusion 3
1014+
"""
1015+
out_path = _get_out_path(args, "mmdit")
1016+
if os.path.exists(out_path):
1017+
logger.info(
1018+
f"`mmdit` already exists at {out_path}, skipping conversion."
1019+
)
1020+
return
1021+
1022+
# Convert the MMDiT model via DiffusionKit
1023+
converted_mmdit_path = convert_mmdit_to_mlpackage(
1024+
model_version=args.model_version,
1025+
latent_h=args.latent_h,
1026+
latent_w=args.latent_w,
1027+
output_dir=args.o,
1028+
# FIXME: Hardcoding to CPU_AND_GPU since ANE doesn't support FLOAT32
1029+
compute_precision=ct.precision.FLOAT32,
1030+
compute_unit=ct.ComputeUnit.CPU_AND_GPU,
1031+
)
1032+
1033+
# Load converted model
1034+
coreml_mmdit = ct.models.MLModel(converted_mmdit_path)
1035+
1036+
# Set model metadata
1037+
coreml_mmdit.author = f"Please refer to the Model Card available at huggingface.co/{args.model_version}"
1038+
coreml_mmdit.license = "Stability AI Community License (https://huggingface.co/stabilityai/stable-diffusion-3-medium/blob/main/LICENSE.md)"
1039+
coreml_mmdit.version = args.model_version
1040+
coreml_mmdit.short_description = \
1041+
"Stable Diffusion 3 generates images conditioned on text or other images as input through the diffusion process. " \
1042+
"Please refer to https://arxiv.org/pdf/2403.03206 for details."
1043+
1044+
# Set the input descriptions
1045+
coreml_mmdit.input_description["latent_image_embeddings"] = \
1046+
"The low resolution latent feature maps being denoised through reverse diffusion"
1047+
coreml_mmdit.input_description["token_level_text_embeddings"] = \
1048+
"Output embeddings from the associated text_encoder model to condition to generated image on text. " \
1049+
"A maximum of 77 tokens (~40 words) are allowed. Longer text is truncated. "
1050+
coreml_mmdit.input_description["pooled_text_embeddings"] = \
1051+
"Additional embeddings that if specified are added to the embeddings that are passed along to the MMDiT model."
1052+
coreml_mmdit.input_description["timestep"] = \
1053+
"A value emitted by the associated scheduler object to condition the model on a given noise schedule"
1054+
1055+
# Set the output descriptions
1056+
coreml_mmdit.output_description["denoiser_output"] = \
1057+
"Same shape and dtype as the `latent_image_embeddings` input. " \
1058+
"The predicted noise to facilitate the reverse diffusion (denoising) process"
1059+
1060+
# Set package version metadata
1061+
from python_coreml_stable_diffusion._version import __version__
1062+
coreml_mmdit.user_defined_metadata["com.github.apple.ml-stable-diffusion.version"] = __version__
1063+
from diffusionkit.version import __version__
1064+
coreml_mmdit.user_defined_metadata["com.github.argmax.diffusionkit.version"] = __version__
1065+
1066+
# Save the updated model
1067+
coreml_mmdit.save(out_path)
1068+
1069+
logger.info(f"Saved vae_decoder into {out_path}")
1070+
1071+
# Delete the original file
1072+
if os.path.exists(converted_mmdit_path):
1073+
shutil.rmtree(converted_mmdit_path)
1074+
1075+
del coreml_mmdit
1076+
gc.collect()
1077+
9121078
def convert_safety_checker(pipe, args):
9131079
""" Converts the Safety Checker component of Stable Diffusion
9141080
"""
@@ -1288,6 +1454,16 @@ def get_pipeline(args):
12881454
use_safetensors=True,
12891455
vae=vae,
12901456
use_auth_token=True)
1457+
elif args.sd3_version:
1458+
# SD3 uses standard SDXL diffusers pipeline besides the vae, denoiser, and T5 text encoder
1459+
sdxl_base_version = "stabilityai/stable-diffusion-xl-base-1.0"
1460+
args.xl_version = True
1461+
logger.info(f"SD3 version specified, initializing DiffusionPipeline with {sdxl_base_version} for non-SD3 components..")
1462+
pipe = DiffusionPipeline.from_pretrained(sdxl_base_version,
1463+
torch_dtype=torch.float16,
1464+
variant="fp16",
1465+
use_safetensors=True,
1466+
use_auth_token=True)
12911467
else:
12921468
pipe = DiffusionPipeline.from_pretrained(model_version,
12931469
torch_dtype=torch.float16,
@@ -1316,7 +1492,10 @@ def main(args):
13161492
# Convert models
13171493
if args.convert_vae_decoder:
13181494
logger.info("Converting vae_decoder")
1319-
convert_vae_decoder(pipe, args)
1495+
if args.sd3_version:
1496+
convert_vae_decoder_sd3(args)
1497+
else:
1498+
convert_vae_decoder(pipe, args)
13201499
logger.info("Converted vae_decoder")
13211500

13221501
if args.convert_vae_encoder:
@@ -1363,6 +1542,11 @@ def main(args):
13631542
del pipe
13641543
gc.collect()
13651544
logger.info(f"Converted refiner")
1545+
1546+
if args.convert_mmdit:
1547+
logger.info("Converting mmdit")
1548+
convert_mmdit(args)
1549+
logger.info("Converted mmdit")
13661550

13671551
if args.quantize_nbits is not None:
13681552
logger.info(f"Quantizing weights to {args.quantize_nbits}-bit precision")
@@ -1383,6 +1567,7 @@ def parser_spec():
13831567
parser.add_argument("--convert-vae-decoder", action="store_true")
13841568
parser.add_argument("--convert-vae-encoder", action="store_true")
13851569
parser.add_argument("--convert-unet", action="store_true")
1570+
parser.add_argument("--convert-mmdit", action="store_true")
13861571
parser.add_argument("--convert-safety-checker", action="store_true")
13871572
parser.add_argument(
13881573
"--convert-controlnet",
@@ -1489,6 +1674,7 @@ def parser_spec():
14891674
"If specified, enable unet to receive additional inputs from controlnet. "
14901675
"Each input added to corresponding resnet output."
14911676
)
1677+
parser.add_argument("--include-t5", action="store_true")
14921678

14931679
# Swift CLI Resource Bundling
14941680
parser.add_argument(
@@ -1508,11 +1694,30 @@ def parser_spec():
15081694
default=
15091695
"https://huggingface.co/openai/clip-vit-base-patch32/resolve/main/merges.txt",
15101696
help="The URL to the merged pairs used in by the text tokenizer.")
1697+
parser.add_argument(
1698+
"--text-encoder-t5-url",
1699+
default=
1700+
"https://huggingface.co/argmaxinc/coreml-stable-diffusion-3-medium/resolve/main/TextEncoderT5.mlmodelc",
1701+
help="The URL to the pre-converted T5 encoder model.")
1702+
parser.add_argument(
1703+
"--text-encoder-t5-config-url",
1704+
default=
1705+
"https://huggingface.co/google-t5/t5-small/resolve/main/tokenizer_config.json",
1706+
help="The URL to the merged pairs used in by the text tokenizer.")
1707+
parser.add_argument(
1708+
"--text-encoder-t5-data-url",
1709+
default=
1710+
"https://huggingface.co/google-t5/t5-small/resolve/main/tokenizer.json",
1711+
help="The URL to the merged pairs used in by the text tokenizer.")
15111712
parser.add_argument(
15121713
"--xl-version",
15131714
action="store_true",
15141715
help=("If specified, the pre-trained model will be treated as an instantiation of "
15151716
"`diffusers.pipelines.StableDiffusionXLPipeline` instead of `diffusers.pipelines.StableDiffusionPipeline`"))
1717+
parser.add_argument(
1718+
"--sd3-version",
1719+
action="store_true",
1720+
help=("If specified, the pre-trained model will be treated as an SD3 model."))
15161721

15171722
return parser
15181723

0 commit comments

Comments
 (0)