Add SD3 Pipeline #329


Merged · 8 commits merged into apple:main on Jul 23, 2024

Conversation

@ZachNagengast (Contributor) commented Jun 12, 2024

SD3 on Core ML 🎉

Brought to Apple Silicon by your friends at @argmaxinc

Paper: https://stability.ai/news/stable-diffusion-3-research-paper

What's new:

  • StableDiffusion3Pipeline
    • Main entry point with standard protocol usage
  • MultiModalDiffusionTransformer (MMDiT)
    • The latest and greatest in diffusion technology from StabilityAI; it uses a new architecture and several new supporting models
  • TextEncoderT5
  • DecoderSD3
    • This new VAE has 16 channels, up from 4 with previous models
  • DiscreteFlowScheduler
    • A new scheduler that uses shifting to achieve better denoising at high resolutions
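As background on the "shifting" mentioned for DiscreteFlowScheduler: the SD3 paper applies a resolution-dependent shift to the noise schedule. The snippet below is a minimal Swift sketch of that idea only; the function name and the default shift value are illustrative assumptions, not the actual DiscreteFlowScheduler API in this PR.

// Minimal sketch of the sigma/timestep shift described in the SD3 paper.
// Illustrative only — not the DiscreteFlowScheduler implementation from this PR.
// shift == 1 leaves the schedule unchanged; larger values spend more steps in
// the high-noise region, which helps when denoising at higher resolutions.
func shiftedSigma(_ sigma: Float, shift: Float = 3.0) -> Float {
    (shift * sigma) / (1.0 + (shift - 1.0) * sigma)
}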

How to use it:

For the models that didn't change, the existing conversion pipelines should all work as is:

python -m python_coreml_stable_diffusion.torch2coreml --convert-text-encoder --xl-version --model-version stabilityai/stable-diffusion-xl-base-1.0 --bundle-resources-for-swift-cli --attention-implementation ORIGINAL -o <output-dir>

We also created an entire repo dedicated to the new models, called DiffusionKit, which comes with conversion pipelines for the new VAE and MMDiT models.

To install:

git clone https://github.com/argmaxinc/DiffusionKit.git
cd DiffusionKit
pip install -e .

Convert MMDiT:

python -m tests.torch2coreml.test_mmdit --sd3-ckpt-path <path-to-sd3-mmdit.safetensors> --model-version {2b} -o <output-mlpackages-directory> --latent-size {64, 128}

Convert VAE:

python -m tests.torch2coreml.test_vae --sd3-ckpt-path <path-to-sd3-mmdit.safetensors> -o <output-mlpackages-directory> --latent-size {64, 128}

Finally, combine all of these models into the same folder and point the CLI at that path to test it out with the new --sd3 flag:

swift run StableDiffusionSample <prompt> --resource-path <output-mlmodelc-directory/Resources> --output-path <output-dir> --sd3

You should see a new image in your output-dir that might look something like this:

[example output image]
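For programmatic use (as opposed to the sample CLI), the flow should mirror the existing StableDiffusionPipeline. The sketch below is assumption-heavy: the initializer, Configuration type, and method names shown for StableDiffusion3Pipeline may differ from what this PR actually exposes.

import Foundation
import CoreML
import StableDiffusion

// Rough sketch of programmatic SD3 usage, assuming the new pipeline follows the
// same protocol shape as StableDiffusionPipeline. Type and parameter names here
// are assumptions for illustration, not the PR's exact public API.
let resourceURL = URL(fileURLWithPath: "<output-mlmodelc-directory>/Resources")

let mlConfig = MLModelConfiguration()
mlConfig.computeUnits = .cpuAndGPU

let pipeline = try StableDiffusion3Pipeline(resourcesAt: resourceURL, configuration: mlConfig)
try pipeline.loadResources()

var config = StableDiffusion3Pipeline.Configuration(prompt: "a photo of an astronaut riding a horse")
config.stepCount = 28

let images = try pipeline.generateImages(configuration: config) { _ in true }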

Try it out today via this PR into Hugging Face's excellent swift-coreml-diffusers app (includes pre-converted models and a pipeline usage example).

Co-authored-by: atiorh <atiorh@users.noreply.github.com>
Co-authored-by: arda-argmax <arda-argmax@users.noreply.github.com>
@msiracusa (Collaborator) commented:

Thank you for opening this PR and adding support for Stable Diffusion 3!

Two high level topics I think will be important to cover here are:

  • Keeping a unified interface / entry point for model conversion
  • Reducing swift code duplication across the pipelines and model interfaces

Reviewers have been assigned and will provide more detailed feedback.

@alejandro-isaza (Collaborator) left a comment:

My main concerns are:

  1. T5Tokenizer.swift is too long; please break it out.
  2. Too many new public types, let's try to keep the public interface small.
  3. Some duplication. I know some of it needs a larger refactor, but there are some easy wins here.

public func decode(
_ latents: [MLShapedArray<Float32>],
scaleFactor: Float32,
shiftFactor: Float32
Collaborator:

As far as I can tell, shiftFactor is the only difference between Decoder and DecoderSD3. Let's add the shift to Decoder and default it to 0.
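For illustration, folding the shift into the existing Decoder might look like the sketch below, showing only the latent un-scaling step (the helper name is made up; the signature mirrors the quoted decode call):

import CoreML

// Sketch of the suggestion: give Decoder a shiftFactor that defaults to 0 so
// existing SD/SDXL callers keep their behavior while SD3 can pass its VAE shift.
// Only the latent un-scaling step is shown; "unscale" is a hypothetical helper.
func unscale(_ latent: MLShapedArray<Float32>,
             scaleFactor: Float32,
             shiftFactor: Float32 = 0) -> MLShapedArray<Float32> {
    // SD3's 16-channel VAE expects latent / scaleFactor + shiftFactor;
    // with shiftFactor == 0 this reduces to the previous latent / scaleFactor.
    MLShapedArray(scalars: latent.scalars.map { $0 / scaleFactor + shiftFactor },
                  shape: latent.shape)
}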

Collaborator:

This comment still stands, let's reuse Decoder instead of introducing DecoderSD3

// MARK: - Configuration files with dynamic lookup

@dynamicMemberLookup
public struct Config {
Collaborator:

Many types in this file are public, but I don't think they should be.

Config seems more like a general-purpose JSON structure. Please give a better name (AnyJSON?). An alternative is to use Codable structs instead of this generic type.
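As a sketch of the Codable alternative mentioned here (the field names are illustrative, modeled on typical tokenizer config JSON rather than taken from this PR):

import Foundation

// Sketch of the Codable-struct alternative: model only the fields the tokenizer
// actually needs instead of a dynamic-member-lookup JSON wrapper.
// Field names are illustrative examples, not the PR's actual schema.
struct TokenizerConfigFile: Codable {
    let unkToken: String?
    let eosToken: String?
    let modelMaxLength: Int?

    enum CodingKeys: String, CodingKey {
        case unkToken = "unk_token"
        case eosToken = "eos_token"
        case modelMaxLength = "model_max_length"
    }
}

func loadTokenizerConfig(from url: URL) throws -> TokenizerConfigFile {
    let data = try Data(contentsOf: url)
    return try JSONDecoder().decode(TokenizerConfigFile.self, from: data)
}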

case tooLong(String)
}

public protocol TokenizingModel {
Collaborator:

Do we need this protocol? Looks like it's only used once.

}
}

public protocol Tokenizer {
Collaborator:

Do we need this protocol? It's only used once.


extension TokenLatticeNode {
// This is a reference type because structs can't contain references to the same type
// We could implement NSCopying, but frankly I don't see the point
Collaborator:

I'm not convinced this comment is accurate or relevant. Consider converting to an init method: init(from: TokenLatticeNode)
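A sketch of the suggested init(from:), assuming a memberwise-style initializer and property names that are purely illustrative:

// Sketch of the suggested copy initializer. The property names and the
// memberwise init used here are assumptions for illustration; the real
// TokenLatticeNode may store different fields.
extension TokenLatticeNode {
    convenience init(from node: TokenLatticeNode) {
        self.init(tokenId: node.tokenId,
                  startOffset: node.startOffset,
                  length: node.length,
                  score: node.score)
    }
}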

}
}

public extension Trie {
Collaborator:

This file is too long. Let's move some of these types to their own files. (Also not public)


public extension Trie {
// Only used for testing, could migrate to collection
func get(_ element: any Sequence<T>) -> Node? {
Collaborator:

doesn't look like this is being used, please remove.

enum PostProcessorType: String {
case TemplateProcessing
case ByteLevel
case RobertaProcessing
Collaborator:

Here and in other enums: enum cases should start with a lowercase letter. Please review https://www.swift.org/documentation/api-design-guidelines/
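A small sketch of the convention being asked for, keeping the original JSON strings via explicit raw values:

// Per the Swift API Design Guidelines, enum cases are lowerCamelCase.
// Explicit raw values keep the strings that appear in the tokenizer JSON files.
enum PostProcessorType: String {
    case templateProcessing = "TemplateProcessing"
    case byteLevel = "ByteLevel"
    case robertaProcessing = "RobertaProcessing"
}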

}
}

let byteEncoder: Dictionary<UTF8.CodeUnit, String> = [
Collaborator:

Please consider using Unicode.Scalar(_:) or String(cString:), if that is not viable please use a better name for this dictionary.
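For reference, the Unicode.Scalar route mentioned here would look roughly like the sketch below; whether it can replace the table depends on what the table encodes (byte-level BPE remaps some bytes to other printable characters, in which case a better-named lookup table is still needed):

// Sketch of the reviewer's suggestion: derive the string for a byte directly
// instead of via a hand-written dictionary. This only works if the table is a
// plain byte-to-character mapping.
func string(forByte byte: UTF8.CodeUnit) -> String {
    String(Character(Unicode.Scalar(byte)))
}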

Contributor Author:

Thanks for the review @alejandro-isaza. As a general note, this T5Tokenizer code is actually copied/adapted from swift-transformers. We may want to simply add swift-transformers as a dependency and bring these suggestions over there. Curious about your thoughts; otherwise I will be happy to adapt this code further based on your notes.

Collaborator:

I think adding a dependency is a good idea. Ideally it would be swift-tokenizers :)
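For reference, pulling in swift-transformers via SwiftPM would look roughly like the sketch below (the version requirement and product name are assumptions for illustration; the package lives at github.com/huggingface/swift-transformers):

// swift-tools-version:5.8
// Sketch of a Package.swift that adds swift-transformers as a dependency.
// The version requirement and product name are illustrative assumptions.
import PackageDescription

let package = Package(
    name: "stable-diffusion",
    platforms: [.macOS(.v13), .iOS(.v16)],
    dependencies: [
        .package(url: "https://github.com/huggingface/swift-transformers", from: "0.1.0"),
    ],
    targets: [
        .target(
            name: "StableDiffusion",
            dependencies: [.product(name: "Transformers", package: "swift-transformers")]
        ),
    ]
)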

return noise
}

func predictions(from batch: MLBatchProvider) throws -> MLBatchProvider {
Collaborator:

Looks like this was copied from Unet.swift, can we refactor instead?

Contributor Author:

By refactor do you mean using the same function from the Unet class? Or refactor as in adjusting the Unet class to support MMDiT?

Collaborator:

Refactor as in move the common code out into a shared function. For instance a free function that takes both the batch and the models.
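A rough sketch of that kind of free function, written against plain Core ML types for illustration (the function name is made up, and the real code would likely use the repo's managed model wrapper instead of bare MLModel):

import CoreML

// Sketch of the suggested refactor: one free function holding the chunked-model
// prediction loop that both Unet and MMDiT could call.
func makeChunkedPredictions(from batch: MLBatchProvider, using models: [MLModel]) throws -> MLBatchProvider {
    precondition(!models.isEmpty, "at least one model chunk is required")
    var results = try models[0].predictions(fromBatch: batch)

    // Feed each later chunk the previous chunk's outputs merged with the
    // original inputs it still needs (outputs win on name collisions).
    for stage in models.dropFirst() {
        var providers = [MLFeatureProvider]()
        for index in 0..<results.count {
            let original = batch.features(at: index)
            let previous = results.features(at: index)
            var merged = [String: MLFeatureValue]()
            for name in original.featureNames {
                merged[name] = original.featureValue(for: name)
            }
            for name in previous.featureNames {
                merged[name] = previous.featureValue(for: name)
            }
            providers.append(try MLDictionaryFeatureProvider(dictionary: merged))
        }
        results = try stage.predictions(fromBatch: MLArrayBatchProvider(array: providers))
    }
    return results
}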

Collaborator:

This is still pending.

@ZachNagengast (Contributor Author) commented Jul 9, 2024:

Quick update here, planning on the following changes:

  1. Bring in swift-transformers as an SPM dependency to remove the need for the reference T5 tokenizer code
  2. Add the python package diffusionkit to the torch2coreml pipeline with helpers to convert the MMDiT and T5 models
  3. Resolve other pending code comments from this review in the meantime

Estimating a week or so to complete this work; it will be ready for another review at that time ⏱️.

@ZachNagengast (Contributor Author) commented:

@alejandro-isaza @aseemw I've updated this PR with the discussed changes:

  1. T5 Tokenizer logic has been offloaded to swift-transformers as a dependency
  2. The torch2coreml conversion script now uses the DiffusionKit Python package to convert the new SD3 models alongside the existing SDXL models, so a single script outputs all of the required models.
  3. Updated the documentation with relevant conversion script examples

E.g. here is a script that will export all models to support 1024x1024 image output with the optional T5 text encoder included:
python -m python_coreml_stable_diffusion.torch2coreml --model-version stabilityai/stable-diffusion-3-medium --bundle-resources-for-swift-cli -o ~/Downloads --sd3-version --convert-text-encoder --convert-vae-decoder --convert-mmdit --include-t5 --latent-w 128 --latent-h 128

I believe this PR is ready for a second review so let me know if you have further comments 🙏.

Worth calling out as well that I'm testing on the macOS 15 beta, and it appears to have an issue running these models, whereas macOS 14 does not. What I'm seeing is that each step takes a very long time, with fluctuating memory that suggests the model is being unloaded and reloaded on each timestep during inference. For an easy way to replicate, you can try out the Hugging Face Diffusers app on macOS 15, which has already rolled out with these SD3 models.

@alejandro-isaza (Collaborator) left a comment:

There are a couple of places where the code could be refactored, but other than that the Swift side looks good to me.

public func decode(
_ latents: [MLShapedArray<Float32>],
scaleFactor: Float32,
shiftFactor: Float32
Collaborator:

This comment still stands, let's reuse Decoder instead of introducing DecoderSD3

return noise
}

func predictions(from batch: MLBatchProvider) throws -> MLBatchProvider {
Collaborator:

This is still pending.

@ZachNagengast (Contributor Author) commented:

@alejandro-isaza Thanks for the note, just pushed updates for these parts.

@atiorh atiorh merged commit c891f43 into apple:main Jul 23, 2024
gdbing added a commit to MochiDiffusion/MochiDiffusion that referenced this pull request Oct 7, 2024
gdbing added a commit to ZachNagengast/MochiDiffusion that referenced this pull request Oct 7, 2024
gdbing added a commit to MochiDiffusion/MochiDiffusion that referenced this pull request Oct 7, 2024
* Add sd3 pipeline to app

* Revert to ml-stable-diffusion main branch

- apple/ml-stable-diffusion#329 merged to main

---------

Co-authored-by: Graham Bing <gdbing@users.noreply.github.com>