Add SD3 Pipeline #329
Co-authored-by: atiorh <atiorh@users.noreply.github.com> Co-authored-by: arda-argmax <arda-argmax@users.noreply.github.com>
Thank you for opening this PR and adding support for Stable Diffusion 3! Two high-level topics I think will be important to cover here are:
Reviewers have been assigned and will provide more detailed feedback.
My main concerns are:
- T5Tokenizer.swift is too long, please break it up.
- Too many new public types, let's try to keep the public interface small.
- Some duplication. I know some of it needs a larger refactor, but there are some easy wins here.
public func decode(
    _ latents: [MLShapedArray<Float32>],
    scaleFactor: Float32,
    shiftFactor: Float32
As far as I can tell, shiftFactor is the only difference between Decoder and DecoderSD3. Let's add the shift to Decoder and default it to 0.
This comment still stands: let's reuse Decoder instead of introducing DecoderSD3.
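A minimal sketch of the suggestion above, written on plain Float32 arrays instead of MLShapedArray so that it runs standalone. The `unscale` helper and the formula `latent / scaleFactor + shiftFactor` are assumptions for illustration, not code from this PR. The point is only that a defaulted `shiftFactor` lets one function serve both the existing and the SD3 call sites.

```swift
/// Hypothetical helper: undo the latent scaling before VAE decoding.
/// `shiftFactor` defaults to 0, so callers that never pass it are unaffected.
func unscale(
    _ latents: [Float32],
    scaleFactor: Float32,
    shiftFactor: Float32 = 0.0
) -> [Float32] {
    // Assumed formula; the real Decoder may apply the shift differently.
    latents.map { $0 / scaleFactor + shiftFactor }
}

// Existing call sites keep compiling and behave as before.
let plain = unscale([1.0, 2.0], scaleFactor: 2.0)

// An SD3-style call site passes the extra argument.
let shifted = unscale([1.0, 2.0], scaleFactor: 2.0, shiftFactor: 0.5)
```

With this shape there is no need for a second decoder type: the SD3 path differs only in the argument it passes.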
// MARK: - Configuration files with dynamic lookup

@dynamicMemberLookup
public struct Config {
Many types in this file are public, but I don't think they should be. Config seems more like a general-purpose JSON structure. Please give it a better name (AnyJSON?). An alternative is to use Codable structs instead of this generic type.
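A rough sketch of the Codable alternative mentioned above. The struct name, its fields, and the JSON keys are hypothetical and chosen only for illustration; the actual configuration files read by this PR may contain different keys.

```swift
import Foundation

/// Hypothetical typed replacement for a dynamic-lookup JSON wrapper.
/// Field names and JSON keys are illustrative assumptions.
struct TokenizerConfig: Codable {
    let modelMaxLength: Int
    let padToken: String?

    enum CodingKeys: String, CodingKey {
        case modelMaxLength = "model_max_length"
        case padToken = "pad_token"
    }
}

// Decoding fails loudly on a missing or mistyped required key,
// instead of returning nil at some later dynamic lookup.
let json = Data(#"{"model_max_length": 512, "pad_token": "<pad>"}"#.utf8)
let config = try JSONDecoder().decode(TokenizerConfig.self, from: json)
```

The trade-off is that each configuration file needs its own struct, while a generic JSON wrapper handles any file at the cost of type safety.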
    case tooLong(String)
}

public protocol TokenizingModel {
Do we need this protocol? Looks like it's only used once.
    }
}

public protocol Tokenizer {
Do we need this protocol? It's only used once.
extension TokenLatticeNode {
    // This is a reference type because structs can't contain references to the same type
    // We could implement NSCopying, but frankly I don't see the point
I'm not convinced this comment is accurate, or relevant. Consider converting to an init method: init(from: TokenLatticeNode)
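A sketch of what such a copy initializer could look like. `LatticeNode` and its fields are a simplified, hypothetical stand-in; the real TokenLatticeNode has its own stored properties.

```swift
/// Hypothetical, simplified node type used only to illustrate the pattern.
final class LatticeNode {
    let tokenId: Int
    var score: Float
    var prev: LatticeNode?

    init(tokenId: Int, score: Float, prev: LatticeNode? = nil) {
        self.tokenId = tokenId
        self.score = score
        self.prev = prev
    }

    /// Copy initializer in place of a separate copying method.
    /// This is a shallow copy: the new node shares `prev` with the original.
    convenience init(from other: LatticeNode) {
        self.init(tokenId: other.tokenId, score: other.score, prev: other.prev)
    }
}

let original = LatticeNode(tokenId: 7, score: 0.5)
let copy = LatticeNode(from: original)
copy.score = 1.0  // mutating the copy leaves the original untouched
```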
    }
}

public extension Trie {
This file is too long. Let's move some of these types to their own files. (Also not public)
public extension Trie {
    // Only used for testing, could migrate to collection
    func get(_ element: any Sequence<T>) -> Node? {
Doesn't look like this is being used, please remove.
enum PostProcessorType: String {
    case TemplateProcessing
    case ByteLevel
    case RobertaProcessing
Here and in other enums: enum cases should start with a lowercase letter. Please review https://www.swift.org/documentation/api-design-guidelines/
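For illustration, the enum from the diff could be renamed as below. The explicit raw values are an assumption: they matter only if the cases are looked up from capitalized strings (for example, values read from a tokenizer configuration file), in which case they keep that lookup working after the rename.

```swift
enum PostProcessorType: String {
    // lowerCamelCase case names; raw values preserve the original spellings.
    case templateProcessing = "TemplateProcessing"
    case byteLevel = "ByteLevel"
    case robertaProcessing = "RobertaProcessing"
}

// Lookup by the capitalized string still succeeds.
let parsed = PostProcessorType(rawValue: "ByteLevel")
```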
    }
}

let byteEncoder: Dictionary<UTF8.CodeUnit, String> = [
Please consider using Unicode.Scalar(_:) or String(cString:). If that is not viable, please use a better name for this dictionary.
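A sketch of how such a table could be computed with Unicode.Scalar(_:) instead of being written out as a literal. It assumes the dictionary follows the GPT-2 byte-level BPE convention (printable bytes map to themselves, all other bytes map to code points starting at U+0100); the dictionary contents are not shown in this thread, so that is an assumption, and `makeByteEncoder` is a hypothetical name.

```swift
/// Builds a byte-to-string table, assuming the GPT-2 byte-level convention.
func makeByteEncoder() -> [UTF8.CodeUnit: String] {
    var table: [UTF8.CodeUnit: String] = [:]
    var extra: UInt32 = 0
    for byte in UInt8.min...UInt8.max {
        let isPrintable = (33...126).contains(byte)
            || (161...172).contains(byte)
            || (174...255).contains(byte)
        if isPrintable {
            // Unicode.Scalar(_: UInt8) is non-failable.
            table[byte] = String(Unicode.Scalar(byte))
        } else {
            // 256 + extra never exceeds U+0143, well below the surrogate
            // range, so the failable UInt32 initializer cannot return nil.
            table[byte] = String(Unicode.Scalar(256 + extra)!)
            extra += 1
        }
    }
    return table
}

let byteEncoder = makeByteEncoder()
```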
Thanks for the review @alejandro-isaza. As a general note, this T5Tokenizer code is actually copied and adapted from swift-transformers. We may want to simply add swift-transformers as a dependency and bring these suggestions over there. Curious what you think; otherwise I will be happy to adapt this code further based on your notes.
I think adding a dependency is a good idea. Ideally it would be swift-tokenizers :)
        return noise
    }

    func predictions(from batch: MLBatchProvider) throws -> MLBatchProvider {
Looks like this was copied from Unet.swift, can we refactor instead?
By refactor do you mean using the same function from the Unet class? Or refactor as in adjusting the Unet class to support MMDiT?
Refactor as in move the common code out into a shared function. For instance, a free function that takes both the batch and the models.
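A rough sketch of the shape such a free function could take, using plain dictionaries in place of MLBatchProvider so that it runs without Core ML. `BatchStage`, `pipelinedPredictions`, and the merge-then-forward chaining are all assumptions for illustration; the real logic in Unet.swift is not shown in this thread and may chain its stages differently.

```swift
/// Hypothetical stand-in for one model stage in a chunked pipeline.
protocol BatchStage {
    func predict(_ batch: [[String: Float]]) throws -> [[String: Float]]
}

/// Free function shared by every denoiser: it takes the batch and the stages,
/// so neither the UNet nor the MMDiT wrapper needs its own copy of the loop.
func pipelinedPredictions(
    from batch: [[String: Float]],
    using stages: [any BatchStage]
) throws -> [[String: Float]] {
    guard let first = stages.first else { return batch }
    var results = try first.predict(batch)
    for stage in stages.dropFirst() {
        // Assumed chaining: later stages see the original inputs plus the
        // previous stage's outputs, with outputs winning on key collisions.
        let next = zip(batch, results).map { input, output in
            input.merging(output) { _, new in new }
        }
        results = try stage.predict(next)
    }
    return results
}

// Usage with a trivial stage that adds 1 to every feature.
struct AddOne: BatchStage {
    func predict(_ batch: [[String: Float]]) throws -> [[String: Float]] {
        batch.map { features in features.mapValues { $0 + 1 } }
    }
}

let out = try pipelinedPredictions(from: [["x": 1]], using: [AddOne(), AddOne()])
```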
This is still pending.
Quick update here, planning on the following changes:
Estimating a week or so to complete this work and will be ready for another review at that time ⏱️.
@alejandro-isaza @aseemw I've updated this PR with the discussed changes:
E.g. here is a script that will export all models to support 1024x1024 image output with the optional T5 text encoder included:
I believe this PR is ready for a second review so let me know if you have further comments 🙏.
Worth calling out here as well that I'm also testing on the macOS 15 beta, and it appears to have an issue running these models, whereas macOS 14 does not. What I'm seeing is that each step takes a very long time, with fluctuating memory that seems like it is unloading and reloading the model on each timestep during inference. For an easy way to replicate you can try out HuggingFace Diffusers app on macOS 15, which has already rolled out with these SD3 models.
There are a couple of places where the code could be refactored, but other than that the Swift side looks good to me.
@alejandro-isaza Thanks for the note, just pushed updates for these parts.
* Add sd3 pipeline to app
* Revert to ml-stable-diffusion main branch
- apple/ml-stable-diffusion#329 merged to main
Co-authored-by: Graham Bing <gdbing@users.noreply.github.com>
SD3 on Core ML 🎉
Brought to Apple Silicon by your friends at @argmaxinc
Paper: https://stability.ai/news/stable-diffusion-3-research-paper
What's new:
How to use it:
For the models that didn't change, the existing conversion pipelines should all work as is:
We also created an entire repo dedicated to the new models, called DiffusionKit, which comes with conversion pipelines for the new VAE and MMDiT models.
To install:
Convert MMDiT:
Convert VAE:
Finally, combine all of these models into the same folder and point this CLI to the path they are in to test it out with the new CLI flag --sd3:
You should see a new image in your output-dir that might look something like this:
Try it out today via this PR into Hugging Face's excellent swift-coreml-diffusers app (includes pre-converted models and a pipeline usage example).