Skip to content

allow models to run with a user-provided dtype map instead of a single dtype #10301

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Apr 2, 2025

Conversation

hlky
Copy link
Contributor

@hlky hlky commented Dec 19, 2024

What does this PR do?

Example

import torch
from diffusers import HunyuanVideoPipeline

model_id = "tencent/HunyuanVideo"
pipe = HunyuanVideoPipeline.from_pretrained(model_id, torch_dtype={'transformer': torch.bfloat16, 'default': torch.float16}, revision="refs/pr/18")
pipe.transformer.dtype, pipe.vae.dtype
(torch.bfloat16, torch.float16)

default is used as a default dtype for components that are not specified, otherwise the current default of torch.float32 is used.

Haven't looked at from_pipe case yet and we'll need to add tests but ready for a first review in case there's something missing because it's simpler than expected.

Fixes #10108

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

cc @DN6 @sayakpaul @yiyixuxu

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Copy link
Member

@sayakpaul sayakpaul left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! Do we not have to handle the typecasts? I think for sharded checkpoints, we might have to.

Comment on lines 557 to 561
sub_model_dtype = (
torch_dtype.get(name, torch_dtype.get("_", torch.float32))
if isinstance(torch_dtype, dict)
else torch_dtype
)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I feel like _ might be a bit unintuitive. Better to expose full dtype maps or in case partial ones are provided we default to torch.float32 for the rest of the components.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could be default? Considering how it will work for integrations, instead of say {'transformer': torch.bfloat16, 'text_encoder': torch.float16, 'text_encoder_2': torch.float16, 'text_encoder_3': torch.float16} for SD3 and {'transformer': torch.bfloat16, 'text_encoder': torch.float16, 'text_encoder_2': torch.float16} for Flux. Not a big issue because components can be got from cls._get_signature_types().

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah no strong opinions.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For now it's renamed to default to be clearer, we can remove later if its not needed.

@hlky
Copy link
Contributor Author

hlky commented Dec 19, 2024

Thanks for the review @sayakpaul. Will look into sharded checkpoints.

@hlky
Copy link
Contributor Author

hlky commented Dec 20, 2024

HunyuanVideo is sharded so I think it's ok.

@hlky hlky added the roadmap Add to current release roadmap label Jan 9, 2025
Copy link
Contributor

github-actions bot commented Feb 2, 2025

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@github-actions github-actions bot added the stale Issues that haven't received updates label Feb 2, 2025
@hlky hlky added wip and removed stale Issues that haven't received updates labels Feb 2, 2025
Copy link
Collaborator

@DN6 DN6 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would add a test to PipelineTesterMixin too.

f"Expected `{list(passed_class_obj.keys())}`, got extra `torch_dtype` keys `{extra_keys_dtype}`."
)
if len(extra_keys_obj) > 0:
logger.warning(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think we need this warning. I think the expectation of passed class objects is that their dtype is already set and if it isn't it happens at the model level where a dtype=None results in FP32 default.

@DN6
Copy link
Collaborator

DN6 commented Apr 1, 2025

@hlky I think we can merge this is if we just add a test for it to PipelineTesterMixin. cc: @yiyixuxu to take a look as well.

Comment on lines +717 to +720
try:
safetensors.torch.save_file(shard, filepath, metadata={"format": "pt"})
except RuntimeError:
safetensors.torch.save_model(model_to_save, filepath, metadata={"format": "pt"})
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we know why this is erroring out for this test? With the current fix, for a sharded checkpoint, we might end up saving the entire model multiple times no?

Copy link
Contributor Author

@hlky hlky Apr 2, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

From safetensors, it doesn't allow saving shared tensors without using save_model. Looks like this is why we're using safe_serialization=False in other tests. If it's an issue like this then the similar issue exists without, as in, we couldn't save a sharded checkpoint that has shared tensors with safetensors, IMO a safetensors problem - it should not be so opinionated about what can or cannot be saved with save_file, shared tensors are always minimal and duplicating them would make little difference to the overall size, the documentation on this matter also does not seem to align with our own findings - it mentions that buffers are consumed once we use get_tensor however we have seen that memory is held during the context of safe_open.

https://github.com/huggingface/safetensors/blob/7d5af853631628137a79341ddc5611d18a17f3fe/bindings/python/py_src/safetensors/torch.py#L481-L494

with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname, safe_serialization=False)

Copy link
Collaborator

@DN6 DN6 Apr 2, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm I think in this case it would be better to just skip the test for Unidiffuser (it has very low usage) than change the saving logic for all pipelines.

Copy link
Collaborator

@DN6 DN6 Apr 2, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Took a quick look. Issue is with the text_decoder component that is a transformer PretrainedModel wrapped in a ModelMixin class. So the transformers logic for saving shared tensors is never invoked
https://github.com/huggingface/transformers/blob/ed95493ce05688447d15d9a82d2d70695290ecff/src/transformers/modeling_utils.py#L3464-L3479

We can skip the test and deprecate this pipeline later 👍🏽

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

#11194 - reverts the change and the test passes with safe_serialization=False.

@hlky hlky merged commit d8c617c into huggingface:main Apr 2, 2025
28 of 29 checks passed
@github-project-automation github-project-automation bot moved this from In Progress to Done in Diffusers Roadmap 0.34 Apr 2, 2025
This was referenced Apr 2, 2025
jonluca added a commit to weights-ai/diffusers that referenced this pull request Apr 3, 2025
* Raise warning and round down if Wan num_frames is not 4k + 1 (huggingface#11167)

* update

* raise warning and round to nearest multiple of scale factor

* [Docs] Fix environment variables in `installation.md` (huggingface#11179)

* Add `latents_mean` and `latents_std` to `SDXLLongPromptWeightingPipeline` (huggingface#11034)

* Bug fix in LTXImageToVideoPipeline.prepare_latents() when latents is already set (huggingface#10918)

* Bug fix in ltx

* Assume packed latents.

---------

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
Co-authored-by: YiYi Xu <yixu310@gmail.com>

* [tests] no hard-coded cuda  (huggingface#11186)

no cuda only

* [WIP] Add Wan Video2Video (huggingface#11053)

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* map BACKEND_RESET_MAX_MEMORY_ALLOCATED to reset_peak_memory_stats on XPU (huggingface#11191)

Signed-off-by: YAO Matrix <matrix.yao@intel.com>

* fix autocast (huggingface#11190)

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix: for checking mandatory and optional pipeline components (huggingface#11189)

fix: optional componentes verification on load

* remove unnecessary call to `F.pad` (huggingface#10620)

* rewrite memory count without implicitly using dimensions by @ic-synth

* replace F.pad by built-in padding in Conv3D

* in-place sums to reduce memory allocations

* fixed trailing whitespace

* file reformatted

* in-place sums

* simpler in-place expressions

* removed in-place sum, may affect backward propagation logic

* removed in-place sum, may affect backward propagation logic

* removed in-place sum, may affect backward propagation logic

* reverted change

* allow models to run with a user-provided dtype map instead of a single dtype (huggingface#10301)

* allow models to run with a user-provided dtype map instead of a single dtype

* make style

* Add warning, change `_` to `default`

* make style

* add test

* handle shared tensors

* remove warning

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* [tests] HunyuanDiTControlNetPipeline inference precision issue on XPU (huggingface#11197)

* add xpu part

* fix more cases

* remove some cases

* no canny

* format fix

* Revert `save_model` in ModelMixin save_pretrained and use safe_serialization=False in test (huggingface#11196)

* [docs] `torch_dtype` map (huggingface#11194)

* Fix enable_sequential_cpu_offload in CogView4Pipeline (huggingface#11195)

* Fix enable_sequential_cpu_offload in CogView4Pipeline

* make fix-copies

* SchedulerMixin from_pretrained and ConfigMixin Self type annotation (huggingface#11192)

* Update import_utils.py (huggingface#10329)

added onnxruntime-vitisai for custom build onnxruntime pkg

* Add CacheMixin to Wan and LTX Transformers (huggingface#11187)

* update

* update

* update

* feat: [Community Pipeline] - FaithDiff Stable Diffusion XL Pipeline (huggingface#11188)

* feat: [Community Pipeline] - FaithDiff Stable Diffusion XL Pipeline for Image SR.

* added pipeline

* [Model Card] standardize advanced diffusion training sdxl lora (huggingface#7615)

* model card gen code

* push modelcard creation

* remove optional from params

* add import

* add use_dora check

* correct lora var use in tags

* make style && make quality

---------

Co-authored-by: Aryan <aryan@huggingface.co>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* Change KolorsPipeline LoRA Loader to StableDiffusion (huggingface#11198)

Change LoRA Loader to StableDiffusion

Replace the SDXL LoRA Loader Mixin inheritance with the StableDiffusion one

* Update Style Bot workflow (huggingface#11202)

update style bot workflow

---------

Signed-off-by: YAO Matrix <matrix.yao@intel.com>
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
Co-authored-by: Aryan <aryan@huggingface.co>
Co-authored-by: Mark <remarkablemark@users.noreply.github.com>
Co-authored-by: hlky <hlky@hlky.ac>
Co-authored-by: kakukakujirori <63725741+kakukakujirori@users.noreply.github.com>
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
Co-authored-by: Fanli Lin <fanli.lin@intel.com>
Co-authored-by: Yao Matrix <matrix.yao@intel.com>
Co-authored-by: jiqing-feng <jiqing.feng@intel.com>
Co-authored-by: Eliseu Silva <elismasilva@gmail.com>
Co-authored-by: Bruno Magalhaes <bruno.magalhaes@synthesia.io>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: lakshay sharma <31830611+Lakshaysharma048@users.noreply.github.com>
Co-authored-by: Abhipsha Das <ad6489@nyu.edu>
Co-authored-by: Basile Lewandowski <basile.lewan@gmail.com>
Co-authored-by: célina <hanouticelina@gmail.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
roadmap Add to current release roadmap wip
Projects
Archived in project
Development

Successfully merging this pull request may close these issues.

[pipelines] allow models to run with a user-provided dtype map instead of a single dtype
4 participants