16
16
DiffusionPipeline ,
17
17
ControlNetModel
18
18
)
19
+ from diffusionkit .tests .torch2coreml import (
20
+ convert_mmdit_to_mlpackage ,
21
+ convert_vae_to_mlpackage
22
+ )
19
23
import gc
24
+ from huggingface_hub import snapshot_download
20
25
21
26
import logging
22
27
@@ -207,6 +212,26 @@ def _compile_coreml_model(source_model_path, output_dir, final_name):
207
212
return target_path
208
213
209
214
215
+ def _download_t5_model (args , t5_save_path ):
216
+ t5_url = args .text_encoder_t5_url
217
+ match = re .match (r'https://huggingface.co/(.+)/resolve/main/(.+)' , t5_url )
218
+ if not match :
219
+ raise ValueError (f"Invalid Hugging Face URL: { t5_url } " )
220
+ repo_id , model_subpath = match .groups ()
221
+
222
+ download_path = snapshot_download (
223
+ repo_id = repo_id ,
224
+ revision = "main" ,
225
+ allow_patterns = [f"{ model_subpath } /*" ]
226
+ )
227
+ logger .info (f"Downloaded T5 model to { download_path } " )
228
+
229
+ # Move the downloaded model to the top level of the Resources directory
230
+ logger .info (f"Copying T5 model from { download_path } to { t5_save_path } " )
231
+ cache_path = os .path .join (download_path , model_subpath )
232
+ shutil .copytree (cache_path , t5_save_path )
233
+
234
+
210
235
def bundle_resources_for_swift_cli (args ):
211
236
"""
212
237
- Compiles Core ML models from mlpackage into mlmodelc format
@@ -228,6 +253,7 @@ def bundle_resources_for_swift_cli(args):
228
253
("refiner" , "UnetRefiner" ),
229
254
("refiner_chunk1" , "UnetRefinerChunk1" ),
230
255
("refiner_chunk2" , "UnetRefinerChunk2" ),
256
+ ("mmdit" , "MultiModalDiffusionTransformer" ),
231
257
("control-unet" , "ControlledUnet" ),
232
258
("control-unet_chunk1" , "ControlledUnetChunk1" ),
233
259
("control-unet_chunk2" , "ControlledUnetChunk2" ),
@@ -241,7 +267,7 @@ def bundle_resources_for_swift_cli(args):
241
267
logger .warning (
242
268
f"{ source_path } not found, skipping compilation to { target_name } .mlmodelc"
243
269
)
244
-
270
+
245
271
if args .convert_controlnet :
246
272
for controlnet_model_version in args .convert_controlnet :
247
273
controlnet_model_name = controlnet_model_version .replace ("/" , "_" )
@@ -271,6 +297,25 @@ def bundle_resources_for_swift_cli(args):
271
297
f .write (requests .get (args .text_encoder_merges_url ).content )
272
298
logger .info ("Done" )
273
299
300
+ # Fetch and save pre-converted T5 text encoder model
301
+ t5_model_name = "TextEncoderT5.mlmodelc"
302
+ t5_save_path = os .path .join (resources_dir , t5_model_name )
303
+ if args .include_t5 :
304
+ if not os .path .exists (t5_save_path ):
305
+ logger .info ("Downloading pre-converted T5 encoder model TextEncoderT5.mlmodelc" )
306
+ _download_t5_model (args , t5_save_path )
307
+ logger .info ("Done" )
308
+ else :
309
+ logger .info (f"Skipping T5 download as { t5_save_path } already exists" )
310
+
311
+ # Fetch and save T5 text tokenizer JSON files
312
+ logger .info ("Downloading and saving T5 tokenizer files tokenizer_config.json and tokenizer.json" )
313
+ with open (os .path .join (resources_dir , "tokenizer_config.json" ), "wb" ) as f :
314
+ f .write (requests .get (args .text_encoder_t5_config_url ).content )
315
+ with open (os .path .join (resources_dir , "tokenizer.json" ), "wb" ) as f :
316
+ f .write (requests .get (args .text_encoder_t5_data_url ).content )
317
+ logger .info ("Done" )
318
+
274
319
return resources_dir
275
320
276
321
@@ -557,6 +602,61 @@ def forward(self, z):
557
602
del traced_vae_decoder , pipe .vae .decoder , coreml_vae_decoder
558
603
gc .collect ()
559
604
605
def convert_vae_decoder_sd3(args):
    """Converts the VAE Decoder component of Stable Diffusion 3.

    Delegates the conversion to DiffusionKit's `convert_vae_to_mlpackage`,
    then loads the resulting `.mlpackage`, attaches model card metadata and
    input/output descriptions, saves it at the canonical output path, and
    removes the intermediate package.

    Args:
        args: Parsed CLI namespace; reads `model_version`, `latent_h`,
            `latent_w`, and `o` (output directory).
    """
    out_path = _get_out_path(args, "vae_decoder")
    if os.path.exists(out_path):
        logger.info(
            f"`vae_decoder` already exists at {out_path}, skipping conversion."
        )
        return

    # Convert the VAE Decoder model via DiffusionKit
    converted_vae_path = convert_vae_to_mlpackage(
        model_version=args.model_version,
        latent_h=args.latent_h,
        latent_w=args.latent_w,
        output_dir=args.o,
    )

    # Load converted model
    coreml_vae_decoder = ct.models.MLModel(converted_vae_path)

    # Set model metadata
    coreml_vae_decoder.author = f"Please refer to the Model Card available at huggingface.co/{args.model_version}"
    coreml_vae_decoder.license = "Stability AI Community License (https://huggingface.co/stabilityai/stable-diffusion-3-medium/blob/main/LICENSE.md)"
    coreml_vae_decoder.version = args.model_version
    # FIX: was `coreml_vae_decodershort_description = ...` (missing `.`), which
    # bound the text to a throwaway local instead of the model metadata field.
    coreml_vae_decoder.short_description = \
        "Stable Diffusion 3 generates images conditioned on text or other images as input through the diffusion process. " \
        "Please refer to https://arxiv.org/pdf/2403.03206 for details."

    # Set the input descriptions
    coreml_vae_decoder.input_description["z"] = \
        "The denoised latent embeddings from the unet model after the last step of reverse diffusion"

    # Set the output descriptions
    coreml_vae_decoder.output_description[
        "image"] = "Generated image normalized to range [-1, 1]"

    # Set package version metadata
    from python_coreml_stable_diffusion._version import __version__
    coreml_vae_decoder.user_defined_metadata["com.github.apple.ml-stable-diffusion.version"] = __version__
    from diffusionkit.version import __version__
    coreml_vae_decoder.user_defined_metadata["com.github.argmax.diffusionkit.version"] = __version__

    # Save the updated model
    coreml_vae_decoder.save(out_path)

    logger.info(f"Saved vae_decoder into {out_path}")

    # Delete the original file
    if os.path.exists(converted_vae_path):
        shutil.rmtree(converted_vae_path)

    del coreml_vae_decoder
    gc.collect()
659
+
560
660
561
661
def convert_vae_encoder (pipe , args ):
562
662
""" Converts the VAE Encoder component of Stable Diffusion
@@ -909,6 +1009,72 @@ def convert_unet(pipe, args, model_name = None):
909
1009
chunk_mlprogram .main (args )
910
1010
911
1011
1012
def convert_mmdit(args):
    """Converts the MMDiT (denoiser) component of Stable Diffusion 3.

    Delegates the conversion to DiffusionKit's `convert_mmdit_to_mlpackage`,
    then loads the resulting `.mlpackage`, attaches model card metadata and
    input/output descriptions, saves it at the canonical output path, and
    removes the intermediate package.

    Args:
        args: Parsed CLI namespace; reads `model_version`, `latent_h`,
            `latent_w`, and `o` (output directory).
    """
    out_path = _get_out_path(args, "mmdit")
    if os.path.exists(out_path):
        logger.info(
            f"`mmdit` already exists at {out_path}, skipping conversion."
        )
        return

    # Convert the MMDiT model via DiffusionKit
    converted_mmdit_path = convert_mmdit_to_mlpackage(
        model_version=args.model_version,
        latent_h=args.latent_h,
        latent_w=args.latent_w,
        output_dir=args.o,
        # FIXME: Hardcoding to CPU_AND_GPU since ANE doesn't support FLOAT32
        compute_precision=ct.precision.FLOAT32,
        compute_unit=ct.ComputeUnit.CPU_AND_GPU,
    )

    # Load converted model
    coreml_mmdit = ct.models.MLModel(converted_mmdit_path)

    # Set model metadata
    coreml_mmdit.author = f"Please refer to the Model Card available at huggingface.co/{args.model_version}"
    coreml_mmdit.license = "Stability AI Community License (https://huggingface.co/stabilityai/stable-diffusion-3-medium/blob/main/LICENSE.md)"
    coreml_mmdit.version = args.model_version
    coreml_mmdit.short_description = \
        "Stable Diffusion 3 generates images conditioned on text or other images as input through the diffusion process. " \
        "Please refer to https://arxiv.org/pdf/2403.03206 for details."

    # Set the input descriptions
    coreml_mmdit.input_description["latent_image_embeddings"] = \
        "The low resolution latent feature maps being denoised through reverse diffusion"
    coreml_mmdit.input_description["token_level_text_embeddings"] = \
        "Output embeddings from the associated text_encoder model to condition to generated image on text. " \
        "A maximum of 77 tokens (~40 words) are allowed. Longer text is truncated. "
    coreml_mmdit.input_description["pooled_text_embeddings"] = \
        "Additional embeddings that if specified are added to the embeddings that are passed along to the MMDiT model."
    coreml_mmdit.input_description["timestep"] = \
        "A value emitted by the associated scheduler object to condition the model on a given noise schedule"

    # Set the output descriptions
    coreml_mmdit.output_description["denoiser_output"] = \
        "Same shape and dtype as the `latent_image_embeddings` input. " \
        "The predicted noise to facilitate the reverse diffusion (denoising) process"

    # Set package version metadata
    from python_coreml_stable_diffusion._version import __version__
    coreml_mmdit.user_defined_metadata["com.github.apple.ml-stable-diffusion.version"] = __version__
    from diffusionkit.version import __version__
    coreml_mmdit.user_defined_metadata["com.github.argmax.diffusionkit.version"] = __version__

    # Save the updated model
    coreml_mmdit.save(out_path)

    # FIX: log message previously said "Saved vae_decoder" — copy-paste from
    # the VAE converter; this function saves the MMDiT model.
    logger.info(f"Saved mmdit into {out_path}")

    # Delete the original file
    if os.path.exists(converted_mmdit_path):
        shutil.rmtree(converted_mmdit_path)

    del coreml_mmdit
    gc.collect()
1077
+
912
1078
def convert_safety_checker (pipe , args ):
913
1079
""" Converts the Safety Checker component of Stable Diffusion
914
1080
"""
@@ -1288,6 +1454,16 @@ def get_pipeline(args):
1288
1454
use_safetensors = True ,
1289
1455
vae = vae ,
1290
1456
use_auth_token = True )
1457
+ elif args .sd3_version :
1458
+ # SD3 uses standard SDXL diffusers pipeline besides the vae, denoiser, and T5 text encoder
1459
+ sdxl_base_version = "stabilityai/stable-diffusion-xl-base-1.0"
1460
+ args .xl_version = True
1461
+ logger .info (f"SD3 version specified, initializing DiffusionPipeline with { sdxl_base_version } for non-SD3 components.." )
1462
+ pipe = DiffusionPipeline .from_pretrained (sdxl_base_version ,
1463
+ torch_dtype = torch .float16 ,
1464
+ variant = "fp16" ,
1465
+ use_safetensors = True ,
1466
+ use_auth_token = True )
1291
1467
else :
1292
1468
pipe = DiffusionPipeline .from_pretrained (model_version ,
1293
1469
torch_dtype = torch .float16 ,
@@ -1316,7 +1492,10 @@ def main(args):
1316
1492
# Convert models
1317
1493
if args .convert_vae_decoder :
1318
1494
logger .info ("Converting vae_decoder" )
1319
- convert_vae_decoder (pipe , args )
1495
+ if args .sd3_version :
1496
+ convert_vae_decoder_sd3 (args )
1497
+ else :
1498
+ convert_vae_decoder (pipe , args )
1320
1499
logger .info ("Converted vae_decoder" )
1321
1500
1322
1501
if args .convert_vae_encoder :
@@ -1363,6 +1542,11 @@ def main(args):
1363
1542
del pipe
1364
1543
gc .collect ()
1365
1544
logger .info (f"Converted refiner" )
1545
+
1546
+ if args .convert_mmdit :
1547
+ logger .info ("Converting mmdit" )
1548
+ convert_mmdit (args )
1549
+ logger .info ("Converted mmdit" )
1366
1550
1367
1551
if args .quantize_nbits is not None :
1368
1552
logger .info (f"Quantizing weights to { args .quantize_nbits } -bit precision" )
@@ -1383,6 +1567,7 @@ def parser_spec():
1383
1567
parser .add_argument ("--convert-vae-decoder" , action = "store_true" )
1384
1568
parser .add_argument ("--convert-vae-encoder" , action = "store_true" )
1385
1569
parser .add_argument ("--convert-unet" , action = "store_true" )
1570
+ parser .add_argument ("--convert-mmdit" , action = "store_true" )
1386
1571
parser .add_argument ("--convert-safety-checker" , action = "store_true" )
1387
1572
parser .add_argument (
1388
1573
"--convert-controlnet" ,
@@ -1489,6 +1674,7 @@ def parser_spec():
1489
1674
"If specified, enable unet to receive additional inputs from controlnet. "
1490
1675
"Each input added to corresponding resnet output."
1491
1676
)
1677
+ parser .add_argument ("--include-t5" , action = "store_true" )
1492
1678
1493
1679
# Swift CLI Resource Bundling
1494
1680
parser .add_argument (
@@ -1508,11 +1694,30 @@ def parser_spec():
1508
1694
default =
1509
1695
"https://huggingface.co/openai/clip-vit-base-patch32/resolve/main/merges.txt" ,
1510
1696
help = "The URL to the merged pairs used in by the text tokenizer." )
1697
+ parser .add_argument (
1698
+ "--text-encoder-t5-url" ,
1699
+ default =
1700
+ "https://huggingface.co/argmaxinc/coreml-stable-diffusion-3-medium/resolve/main/TextEncoderT5.mlmodelc" ,
1701
+ help = "The URL to the pre-converted T5 encoder model." )
1702
+ parser .add_argument (
1703
+ "--text-encoder-t5-config-url" ,
1704
+ default =
1705
+ "https://huggingface.co/google-t5/t5-small/resolve/main/tokenizer_config.json" ,
1706
+ help = "The URL to the merged pairs used in by the text tokenizer." )
1707
+ parser .add_argument (
1708
+ "--text-encoder-t5-data-url" ,
1709
+ default =
1710
+ "https://huggingface.co/google-t5/t5-small/resolve/main/tokenizer.json" ,
1711
+ help = "The URL to the merged pairs used in by the text tokenizer." )
1511
1712
parser .add_argument (
1512
1713
"--xl-version" ,
1513
1714
action = "store_true" ,
1514
1715
help = ("If specified, the pre-trained model will be treated as an instantiation of "
1515
1716
"`diffusers.pipelines.StableDiffusionXLPipeline` instead of `diffusers.pipelines.StableDiffusionPipeline`" ))
1717
+ parser .add_argument (
1718
+ "--sd3-version" ,
1719
+ action = "store_true" ,
1720
+ help = ("If specified, the pre-trained model will be treated as an SD3 model." ))
1516
1721
1517
1722
return parser
1518
1723
0 commit comments