allow models to run with a user-provided dtype map instead of a single dtype #10301

Merged
merged 13 commits on Apr 2, 2025
5 changes: 4 additions & 1 deletion src/diffusers/models/modeling_utils.py
@@ -714,7 +714,10 @@ def save_pretrained(
            if safe_serialization:
                # At some point we will need to deal better with save_function (used for TPU and other distributed
                # joyfulness), but for now this enough.
-               safetensors.torch.save_file(shard, filepath, metadata={"format": "pt"})
+               try:
+                   safetensors.torch.save_file(shard, filepath, metadata={"format": "pt"})
+               except RuntimeError:
+                   safetensors.torch.save_model(model_to_save, filepath, metadata={"format": "pt"})
Comment on lines +717 to +720

Collaborator

Do we know why this is erroring out for this test? With the current fix, for a sharded checkpoint, we might end up saving the entire model multiple times, no?

Contributor Author

@hlky hlky Apr 2, 2025

safetensors does not allow saving shared tensors without using save_model; it looks like this is why we use safe_serialization=False in other tests. If this is an issue here, the same issue exists regardless: we couldn't save a sharded checkpoint that has shared tensors with safetensors at all. IMO this is a safetensors problem - it should not be so opinionated about what can or cannot be saved with save_file; shared tensors are always minimal, and duplicating them would make little difference to the overall size. The documentation on this matter also does not seem to align with our own findings - it mentions that buffers are consumed once we use get_tensor, however we have seen that memory is held during the context of safe_open.

https://github.com/huggingface/safetensors/blob/7d5af853631628137a79341ddc5611d18a17f3fe/bindings/python/py_src/safetensors/torch.py#L481-L494

with tempfile.TemporaryDirectory() as tmpdirname:
    model.save_pretrained(tmpdirname, safe_serialization=False)
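For illustration, a minimal sketch of the failure mode being discussed (the TinyDecoder module below is hypothetical; only save_file and save_model come from safetensors): save_file refuses a state dict whose entries share storage, while save_model deduplicates shared tensors before writing.

```python
import torch.nn as nn
from safetensors.torch import save_file, save_model


class TinyDecoder(nn.Module):
    # Hypothetical module with tied weights, similar in spirit to Unidiffuser's text_decoder.
    def __init__(self, vocab=16, dim=8):
        super().__init__()
        self.embed = nn.Embedding(vocab, dim)
        self.lm_head = nn.Linear(dim, vocab, bias=False)
        self.lm_head.weight = self.embed.weight  # weight tying -> shared storage


model = TinyDecoder()

try:
    # Raises RuntimeError: save_file rejects tensors that share memory.
    save_file(model.state_dict(), "decoder.safetensors", metadata={"format": "pt"})
except RuntimeError as err:
    print(f"save_file rejected shared tensors: {err}")
    # Works: save_model drops the duplicate names before writing.
    save_model(model, "decoder.safetensors", metadata={"format": "pt"})
```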

Collaborator

@DN6 DN6 Apr 2, 2025

Hmm, I think in this case it would be better to just skip the test for Unidiffuser (it has very low usage) than to change the saving logic for all pipelines.

Collaborator

@DN6 DN6 Apr 2, 2025

Took a quick look. The issue is with the text_decoder component, which is a transformers PreTrainedModel wrapped in a ModelMixin class, so the transformers logic for saving shared tensors is never invoked:
https://github.com/huggingface/transformers/blob/ed95493ce05688447d15d9a82d2d70695290ecff/src/transformers/modeling_utils.py#L3464-L3479

We can skip the test and deprecate this pipeline later 👍🏽
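As a rough sketch of what that transformers code path does conceptually (the helper below is illustrative, not the actual implementation, and it ignores partially overlapping views): keep one name per underlying storage and drop the aliases before handing the tensors to safetensors.

```python
import torch


def drop_shared_aliases(state_dict):
    # Illustrative only: keep the first name seen for each storage so that
    # tied weights (e.g. an lm_head aliasing an embedding) are written once.
    seen_storages = set()
    deduped = {}
    for name, tensor in state_dict.items():
        ptr = tensor.untyped_storage().data_ptr()
        if ptr in seen_storages:
            continue
        seen_storages.add(ptr)
        deduped[name] = tensor
    return deduped
```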

Contributor Author

#11194 reverts the change, and the test passes with safe_serialization=False.

            else:
                torch.save(shard, filepath)

14 changes: 12 additions & 2 deletions src/diffusers/pipelines/pipeline_loading_utils.py
@@ -592,6 +592,11 @@ def _get_final_device_map(device_map, pipeline_class, passed_class_obj, init_dic
            loaded_sub_model = passed_class_obj[name]

        else:
+           sub_model_dtype = (
+               torch_dtype.get(name, torch_dtype.get("default", torch.float32))
+               if isinstance(torch_dtype, dict)
+               else torch_dtype
+           )
            loaded_sub_model = _load_empty_model(
                library_name=library_name,
                class_name=class_name,
@@ -600,7 +605,7 @@ def _get_final_device_map(device_map, pipeline_class, passed_class_obj, init_dic
                is_pipeline_module=is_pipeline_module,
                pipeline_class=pipeline_class,
                name=name,
-               torch_dtype=torch_dtype,
+               torch_dtype=sub_model_dtype,
                cached_folder=kwargs.get("cached_folder", None),
                force_download=kwargs.get("force_download", None),
                proxies=kwargs.get("proxies", None),
@@ -616,7 +621,12 @@ def _get_final_device_map(device_map, pipeline_class, passed_class_obj, init_dic
    # Obtain a sorted dictionary for mapping the model-level components
    # to their sizes.
    module_sizes = {
-       module_name: compute_module_sizes(module, dtype=torch_dtype)[""]
+       module_name: compute_module_sizes(
+           module,
+           dtype=torch_dtype.get(module_name, torch_dtype.get("default", torch.float32))
+           if isinstance(torch_dtype, dict)
+           else torch_dtype,
+       )[""]
        for module_name, module in init_empty_modules.items()
        if isinstance(module, torch.nn.Module)
    }
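The lookup added in these hunks resolves a per-component dtype with a three-step fallback: an entry keyed by the component name, then a "default" entry, then torch.float32. A standalone sketch of that rule (the helper is illustrative and not part of diffusers):

```python
import torch


def resolve_component_dtype(torch_dtype, name):
    """Illustrative only: mirrors the inline lookup used in the diff above."""
    if isinstance(torch_dtype, dict):
        return torch_dtype.get(name, torch_dtype.get("default", torch.float32))
    return torch_dtype


assert resolve_component_dtype({"transformer": torch.bfloat16}, "transformer") == torch.bfloat16
assert resolve_component_dtype({"transformer": torch.bfloat16}, "vae") == torch.float32
assert resolve_component_dtype({"default": torch.float16}, "vae") == torch.float16
assert resolve_component_dtype(torch.float16, "unet") == torch.float16
```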
16 changes: 12 additions & 4 deletions src/diffusers/pipelines/pipeline_utils.py
@@ -552,9 +552,12 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
                      saved using
                      [`~DiffusionPipeline.save_pretrained`].
                    - A path to a *directory* (for example `./my_pipeline_directory/`) containing a dduf file
-           torch_dtype (`str` or `torch.dtype`, *optional*):
+           torch_dtype (`str` or `torch.dtype` or `dict[str, Union[str, torch.dtype]]`, *optional*):
                Override the default `torch.dtype` and load the model with another dtype. If "auto" is passed, the
-               dtype is automatically derived from the model's weights.
+               dtype is automatically derived from the model's weights. To load submodels with different dtype pass a
+               `dict` (for example `{'transformer': torch.bfloat16, 'vae': torch.float16}`). Set the default dtype for
+               unspecified components with `default` (for example `{'transformer': torch.bfloat16, 'default':
+               torch.float16}`). If a component is not specified and no default is set, `torch.float32` is used.
            custom_pipeline (`str`, *optional*):

                <Tip warning={true}>
@@ -703,7 +706,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
        use_onnx = kwargs.pop("use_onnx", None)
        load_connected_pipeline = kwargs.pop("load_connected_pipeline", False)

-       if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
+       if torch_dtype is not None and not isinstance(torch_dtype, dict) and not isinstance(torch_dtype, torch.dtype):
            torch_dtype = torch.float32
            logger.warning(
                f"Passed `torch_dtype` {torch_dtype} is not a `torch.dtype`. Defaulting to `torch.float32`."
@@ -950,14 +953,19 @@ def load_module(name, value):
                loaded_sub_model = passed_class_obj[name]
            else:
                # load sub model
+               sub_model_dtype = (
+                   torch_dtype.get(name, torch_dtype.get("default", torch.float32))
+                   if isinstance(torch_dtype, dict)
+                   else torch_dtype
+               )
                loaded_sub_model = load_sub_model(
                    library_name=library_name,
                    class_name=class_name,
                    importable_classes=importable_classes,
                    pipelines=pipelines,
                    is_pipeline_module=is_pipeline_module,
                    pipeline_class=pipeline_class,
-                   torch_dtype=torch_dtype,
+                   torch_dtype=sub_model_dtype,
                    provider=provider,
                    sess_options=sess_options,
                    device_map=current_device_map,
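Putting the updated docstring above into practice, a usage sketch of the dict form of torch_dtype (the checkpoint id is just an example; any pipeline works the same way, and unspecified components fall back to the "default" entry):

```python
import torch
from diffusers import DiffusionPipeline

# Per-component dtypes: the UNet loads in bfloat16, everything else falls back to "default".
pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype={"unet": torch.bfloat16, "default": torch.float16},
)

print(pipe.unet.dtype)          # torch.bfloat16
print(pipe.vae.dtype)           # torch.float16
print(pipe.text_encoder.dtype)  # torch.float16
```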
23 changes: 23 additions & 0 deletions tests/pipelines/test_pipelines_common.py
@@ -2283,6 +2283,29 @@ def run_forward(pipe):
        self.assertTrue(np.allclose(output_without_group_offloading, output_with_group_offloading1, atol=1e-4))
        self.assertTrue(np.allclose(output_without_group_offloading, output_with_group_offloading2, atol=1e-4))

+   def test_torch_dtype_dict(self):
+       components = self.get_dummy_components()
+       if not components:
+           self.skipTest("No dummy components defined.")
+
+       pipe = self.pipeline_class(**components)
+
+       specified_key = next(iter(components.keys()))
+
+       with tempfile.TemporaryDirectory(ignore_cleanup_errors=True) as tmpdirname:
+           pipe.save_pretrained(tmpdirname)
+           torch_dtype_dict = {specified_key: torch.bfloat16, "default": torch.float16}
+           loaded_pipe = self.pipeline_class.from_pretrained(tmpdirname, torch_dtype=torch_dtype_dict)
+
+       for name, component in loaded_pipe.components.items():
+           if isinstance(component, torch.nn.Module) and hasattr(component, "dtype"):
+               expected_dtype = torch_dtype_dict.get(name, torch_dtype_dict.get("default", torch.float32))
+               self.assertEqual(
+                   component.dtype,
+                   expected_dtype,
+                   f"Component '{name}' has dtype {component.dtype} but expected {expected_dtype}",
+               )


@is_staging_test
class PipelinePushToHubTester(unittest.TestCase):