Add Photon model and pipeline support #12456
base: main
Conversation
This commit adds support for the Photon image generation model:
- PhotonTransformer2DModel: core transformer architecture
- PhotonPipeline: text-to-image generation pipeline
- Attention processor updates for the Photon-specific attention mechanism
- Conversion script for loading Photon checkpoints
- Documentation and tests
print("✓ Created scheduler config") | ||
|
||
|
||
def download_and_save_vae(vae_type: str, output_path: str): |
I'm not sure about this one: I'm saving the VAE weights even though they are already available on the Hub (Flux VAE and DC-AE).
Is there a way to avoid storing them and instead point directly to the original ones?
For now, it's okay to keep this as is. This way, everything is under the same model repo.
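For reference, a minimal sketch of what such a helper could look like if it pulled the publicly available VAEs from the Hub and re-saved them under the Photon repo; the repo ids and `vae_type` values here are assumptions, not necessarily what the conversion script in this PR uses.

```python
# Hypothetical sketch only: repo ids and vae_type values are assumptions.
from diffusers import AutoencoderDC, AutoencoderKL


def download_and_save_vae(vae_type: str, output_path: str):
    if vae_type == "flux":
        # The Flux VAE is already published on the Hub inside the FLUX.1 repos.
        vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-schnell", subfolder="vae")
    elif vae_type == "dc-ae":
        # DC-AE checkpoints are likewise available in diffusers format.
        vae = AutoencoderDC.from_pretrained("mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers")
    else:
        raise ValueError(f"Unsupported VAE type: {vae_type}")
    # Re-save under the Photon model repo so everything lives in one place.
    vae.save_pretrained(f"{output_path}/vae")
    print(f"✓ Saved VAE to {output_path}/vae")
```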
print(f"✓ Saved VAE to {vae_path}") | ||
|
||
|
||
def download_and_save_text_encoder(output_path: str): |
Same here for the Text Encoder.
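Analogously, a hypothetical sketch for the text encoder helper; the T5Gemma repo id below is a placeholder, not the checkpoint this PR actually uses.

```python
# Hypothetical sketch only: the T5Gemma repo id is a placeholder.
from transformers import AutoModel, AutoTokenizer


def download_and_save_text_encoder(output_path: str):
    repo_id = "google/t5gemma-placeholder"  # placeholder, not the actual checkpoint
    text_encoder = AutoModel.from_pretrained(repo_id)
    tokenizer = AutoTokenizer.from_pretrained(repo_id)
    # Re-save under the Photon model repo so the pipeline loads everything from one place.
    text_encoder.save_pretrained(f"{output_path}/text_encoder")
    tokenizer.save_pretrained(f"{output_path}/tokenizer")
    print(f"✓ Saved text encoder to {output_path}/text_encoder")
```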
print("✓ Created scheduler config") | ||
|
||
|
||
def download_and_save_vae(vae_type: str, output_path: str): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For now, it's okay to keep this as is. This way, everything is under the same model repo.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the clean PR! I left some initial feedback for you. LMK if that makes sense.
Also, it would be great to see some samples of Photon!
Thanks! Left a couple more comments. Let's also add the pipeline-level tests.
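For what it's worth, an integration-style smoke test could look roughly like the sketch below; this is not the tiny-component `PipelineTesterMixin` setup diffusers uses for fast tests, and the repo id is an assumption.

```python
# Hypothetical smoke test sketch; the repo id is an assumption.
import torch
from diffusers import PhotonPipeline


def test_photon_pipeline_smoke():
    pipe = PhotonPipeline.from_pretrained("Photoroom/photon-512", torch_dtype=torch.bfloat16)
    pipe.to("cuda")
    image = pipe("a small test prompt", num_inference_steps=2).images[0]
    assert image is not None
```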
<img alt="LoRA" src="https://img.shields.io/badge/LoRA-d8b4fe?style=flat"/> | ||
</div> | ||
|
||
Photon is a text-to-image diffusion model using simplified MMDIT architecture with flow matching for efficient high-quality image generation. The model uses T5Gemma as the text encoder and supports either Flux VAE (AutoencoderKL) or DC-AE (AutoencoderDC) for latent compression. |
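A short usage sketch of the pipeline described above; the checkpoint id is a placeholder since the final Hub repo isn't stated here.

```python
# Hypothetical usage sketch; the repo id is a placeholder.
import torch
from diffusers import PhotonPipeline

pipe = PhotonPipeline.from_pretrained("Photoroom/photon-512", torch_dtype=torch.bfloat16)
pipe.to("cuda")

image = pipe(
    "a watercolor fox in a snowy forest",
    num_inference_steps=28,
    guidance_scale=4.5,
).images[0]
image.save("photon.png")
```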
Cc: @stevhliu for a review on the docs.
return xq_out.reshape(*xq.shape).type_as(xq)


class PhotonAttnProcessor2_0:
Could we write it in a fashion similar to `FluxAttnProcessor`?
I second this suggestion. In particular, I think it would be more in line with other diffusers model implementations to reuse the layers defined in `Attention`, such as `to_q`/`to_k`/`to_v`, instead of defining them in `PhotonBlock` (e.g. `PhotonBlock.img_qkv_proj`), and to keep the entire attention implementation in the `PhotonAttnProcessor2_0` class.
`Attention` also supports things like QK norms and fused projections, so that could potentially be reused as well. If you need some custom logic not found in `Attention`, you could add it there or create a new `Attention`-style class like Flux does with `class FluxAttention(torch.nn.Module, AttentionModuleMixin):`.
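To make the suggestion concrete, here is a rough sketch of a processor built on top of `Attention`'s own projections, in the spirit of `FluxAttnProcessor`; treat it as an illustration under assumed argument names and shapes, not the actual implementation in this PR.

```python
# Hypothetical sketch: reuse Attention's to_q/to_k/to_v/to_out and norm_q/norm_k
# instead of projections defined on PhotonBlock. Argument names follow the usual
# diffusers processor convention and are assumptions here.
import torch.nn.functional as F
from diffusers.models.embeddings import apply_rotary_emb


class PhotonAttnProcessor2_0:
    def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, image_rotary_emb=None):
        batch_size = hidden_states.shape[0]
        context = encoder_hidden_states if encoder_hidden_states is not None else hidden_states

        # Projections come from the Attention instance, not from the block.
        query = attn.to_q(hidden_states)
        key = attn.to_k(context)
        value = attn.to_v(context)

        head_dim = query.shape[-1] // attn.heads
        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # QK norms reused from Attention (e.g. qk_norm="rms_norm").
        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # Rotary embeddings, if the block passes them in.
        if image_rotary_emb is not None:
            query = apply_rotary_emb(query, image_rotary_emb)
            key = apply_rotary_emb(key, image_rotary_emb)

        hidden_states = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask)
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)

        # Output projection (linear + dropout) also lives on the Attention module.
        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)
        return hidden_states
```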
I made the change and updated both the conversion script and the checkpoints on the hub.
def __call__(
    self,
    prompt: Union[str, List[str]] = None,
    height: Optional[int] = None,
We support passing prompt embeddings too, in case users want to supply them precomputed:
prompt_embeds: Optional[torch.FloatTensor] = None,
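For example, usage with precomputed embeddings could then look roughly like this; the repo id and the `encode_prompt` signature (inferred from how it's used later in this pipeline) are assumptions.

```python
# Hypothetical usage once prompt_embeds / negative_prompt_embeds are exposed;
# the repo id and encode_prompt signature are assumptions.
import torch
from diffusers import PhotonPipeline

pipe = PhotonPipeline.from_pretrained("Photoroom/photon-512", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# Precompute embeddings once, e.g. to reuse them across several generations.
prompt_embeds, prompt_mask, negative_prompt_embeds, negative_mask = pipe.encode_prompt(
    "a watercolor fox in a snowy forest", device="cuda"
)
image = pipe(
    prompt_embeds=prompt_embeds,
    negative_prompt_embeds=negative_prompt_embeds,
    num_inference_steps=28,
).images[0]
```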
default_sample_size = getattr(self.config, "default_sample_size", DEFAULT_RESOLUTION)
height = height or default_sample_size
width = width or default_sample_size
Prefer this pattern:
height = height or self.default_sample_size * self.vae_scale_factor
I did it this way because the model works with two different VAEs that have different scale factors.
Is it OK not to make it depend on self.vae_scale_factor? Otherwise it is hard to define a default value.
Oh good point! I think we could make a small utility function in the pipeline class that determines the default resolution given the VAE that's loaded into it? WDYT?
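Something along these lines could work as that utility; the attribute names, scale factors, and fallback below are assumptions (DC-AE compresses spatially by 32x, the Flux VAE by 8x).

```python
# Hypothetical sketch of a per-VAE default-resolution helper, meant as a
# PhotonPipeline method; attribute names and the fallback are assumptions.
from diffusers import AutoencoderDC, AutoencoderKL

DEFAULT_RESOLUTION = 512  # assumed fallback


def _get_default_resolution(self):
    # Latent sample size times the spatial compression of whichever VAE is loaded.
    if isinstance(self.vae, AutoencoderDC):
        return self.transformer.config.sample_size * 32
    if isinstance(self.vae, AutoencoderKL):
        return self.transformer.config.sample_size * 8
    return DEFAULT_RESOLUTION
```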
Sure, way cleaner! I did it.
Thanks for the docs, remember to add it to the toctree as well!
### Manual Component Loading

You can also load components individually:
Would be good to demonstrate why you're loading components individually; for example, it could be for quantization. E.g. change "You can also load components individually:" to something like "Load components individually to ..."
I did an example with quantization as you suggested.
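Something like the following could serve as that example, loading the transformer in 4-bit with bitsandbytes before assembling the pipeline; the repo id is a placeholder.

```python
# Hypothetical quantization example; the repo id is a placeholder.
import torch
from diffusers import BitsAndBytesConfig, PhotonPipeline, PhotonTransformer2DModel

quant_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16)
transformer = PhotonTransformer2DModel.from_pretrained(
    "Photoroom/photon-512",
    subfolder="transformer",
    quantization_config=quant_config,
    torch_dtype=torch.bfloat16,
)
pipe = PhotonPipeline.from_pretrained(
    "Photoroom/photon-512",
    transformer=transformer,
    torch_dtype=torch.bfloat16,
)
pipe.enable_model_cpu_offload()
```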
pipe.to("cuda") | ||
``` | ||
|
||
## VAE Variants |
The VAE section can be removed since it's already mentioned in the first paragraph.
The VAE type is automatically determined from the checkpoint's `model_index.json` configuration.

## Generation Parameters
This section can also be removed since it's safe to assume a user has some background knowledge of using a text-to-image pipeline.
return self.out_layer(self.silu(self.in_layer(x)))


class QKNorm(torch.nn.Module):
Consider reusing the QK norm implementation in `Attention`; I believe setting `qk_norm="rms_norm"` should be equivalent:
diffusers/src/diffusers/models/attention_processor.py, lines 205 to 207 in 8abc7ae:
elif qk_norm == "rms_norm":
    self.norm_q = RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
    self.norm_k = RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
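Concretely, the block could instantiate `Attention` with `qk_norm="rms_norm"` instead of a separate QKNorm module; the dimensions below are placeholders.

```python
# Hypothetical sketch: placeholder dims, showing how Attention's built-in
# RMSNorm QK norms replace a custom QKNorm module.
from diffusers.models.attention_processor import Attention

attn = Attention(
    query_dim=1024,      # placeholder hidden size
    heads=16,            # placeholder head count
    dim_head=64,
    bias=False,
    qk_norm="rms_norm",  # reuses Attention's norm_q / norm_k (RMSNorm)
    eps=1e-6,
    # processor=PhotonAttnProcessor2_0() from this PR can be passed here as well.
)
```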
Yes it worked thanks!
# img qkv
self.img_pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.img_qkv_proj = nn.Linear(hidden_size, hidden_size * 3, bias=False)
I think using the layers from the `Attention` instance (`self.attention`) rather than defining them here would be more idiomatic in diffusers. See also https://github.com/huggingface/diffusers/pull/12456/files#r2434379626.
def forward(
    self,
    img: Tensor,
I think it would be clearer (in the diffusers context) to adopt the usual naming scheme from `FluxTransformerBlock`, `WanTransformerBlock`, etc.:
diffusers/src/diffusers/models/transformers/transformer_flux.py, lines 437 to 444 in 8abc7ae:
def forward(
    self,
    hidden_states: torch.Tensor,
    encoder_hidden_states: torch.Tensor,
    temb: torch.Tensor,
    image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    joint_attention_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
so something like `img` --> `hidden_states`, `txt` --> `encoder_hidden_states`, `vec` --> `temb`, `pe` --> `image_rotary_emb`, etc.
Definitely, my bad not spotting it earlier.
attn_shift, attn_scale, attn_gate = mod_attn
mlp_shift, mlp_scale, mlp_gate = mod_mlp

# Inline attention forward
I would suggest putting all of the attention implementation logic in `PhotonAttnProcessor2_0` (see https://github.com/huggingface/diffusers/pull/12456/files#r2434379626).
Yes, way cleaner.
self._guidance_scale = guidance_scale

# 2. Encode input prompt
text_embeddings, cross_attn_mask, uncond_text_embeddings, uncond_cross_attn_mask = self.encode_prompt(
nit: I think it would be a little cleaner if `text_embeddings` were named `prompt_embeds` and `uncond_text_embeddings` were named `negative_prompt_embeds` here.
Definitely, and the line below was unnecessary.
self.text_encoder = text_encoder
self.tokenizer = tokenizer
Setting these attributes explicitly shouldn't be necessary since the `register_modules` call below should handle that.
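For illustration, a stripped-down `__init__` relying only on `register_modules` might look like this; the component list is assumed from the rest of the PR.

```python
# Hypothetical sketch: register_modules assigns each component onto `self`
# and records it in the pipeline config, so manual assignments are redundant.
from diffusers import DiffusionPipeline


class PhotonPipeline(DiffusionPipeline):
    def __init__(self, transformer, scheduler, vae, text_encoder, tokenizer):
        super().__init__()
        self.register_modules(
            transformer=transformer,
            scheduler=scheduler,
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
        )
```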
I removed them thanks!
Thanks for the PR! Left some comments :).
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
This commit adds support for the Photon image generation model:
Some examples below with the 512 model fine-tuned on the Alchemist dataset and distilled with PAG.
What does this PR do?
Fixes # (issue)
Before submitting
Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.