diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py
index 19ac868cdae0..814547d82be4 100644
--- a/src/diffusers/models/modeling_utils.py
+++ b/src/diffusers/models/modeling_utils.py
@@ -714,7 +714,10 @@ def save_pretrained(
             if safe_serialization:
                 # At some point we will need to deal better with save_function (used for TPU and other distributed
                 # joyfulness), but for now this enough.
-                safetensors.torch.save_file(shard, filepath, metadata={"format": "pt"})
+                try:
+                    safetensors.torch.save_file(shard, filepath, metadata={"format": "pt"})
+                except RuntimeError:
+                    safetensors.torch.save_model(model_to_save, filepath, metadata={"format": "pt"})
             else:
                 torch.save(shard, filepath)

diff --git a/src/diffusers/pipelines/pipeline_loading_utils.py b/src/diffusers/pipelines/pipeline_loading_utils.py
index 07da8b5e2e2e..f5b430564ca1 100644
--- a/src/diffusers/pipelines/pipeline_loading_utils.py
+++ b/src/diffusers/pipelines/pipeline_loading_utils.py
@@ -592,6 +592,11 @@ def _get_final_device_map(device_map, pipeline_class, passed_class_obj, init_dic
             loaded_sub_model = passed_class_obj[name]
         else:
+            sub_model_dtype = (
+                torch_dtype.get(name, torch_dtype.get("default", torch.float32))
+                if isinstance(torch_dtype, dict)
+                else torch_dtype
+            )
             loaded_sub_model = _load_empty_model(
                 library_name=library_name,
                 class_name=class_name,
@@ -600,7 +605,7 @@ def _get_final_device_map(device_map, pipeline_class, passed_class_obj, init_dic
                 is_pipeline_module=is_pipeline_module,
                 pipeline_class=pipeline_class,
                 name=name,
-                torch_dtype=torch_dtype,
+                torch_dtype=sub_model_dtype,
                 cached_folder=kwargs.get("cached_folder", None),
                 force_download=kwargs.get("force_download", None),
                 proxies=kwargs.get("proxies", None),
@@ -616,7 +621,12 @@ def _get_final_device_map(device_map, pipeline_class, passed_class_obj, init_dic
     # Obtain a sorted dictionary for mapping the model-level components
     # to their sizes.
     module_sizes = {
-        module_name: compute_module_sizes(module, dtype=torch_dtype)[""]
+        module_name: compute_module_sizes(
+            module,
+            dtype=torch_dtype.get(module_name, torch_dtype.get("default", torch.float32))
+            if isinstance(torch_dtype, dict)
+            else torch_dtype,
+        )[""]
         for module_name, module in init_empty_modules.items()
         if isinstance(module, torch.nn.Module)
     }
diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py
index 6a508b130c9d..0df4b477e1b9 100644
--- a/src/diffusers/pipelines/pipeline_utils.py
+++ b/src/diffusers/pipelines/pipeline_utils.py
@@ -552,9 +552,12 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
                       saved using [`~DiffusionPipeline.save_pretrained`].
                     - A path to a *directory* (for example `./my_pipeline_directory/`) containing a dduf file
-            torch_dtype (`str` or `torch.dtype`, *optional*):
+            torch_dtype (`str` or `torch.dtype` or `dict[str, Union[str, torch.dtype]]`, *optional*):
                 Override the default `torch.dtype` and load the model with another dtype. If "auto" is passed, the
-                dtype is automatically derived from the model's weights.
+                dtype is automatically derived from the model's weights. To load submodels with different dtype pass a
+                `dict` (for example `{'transformer': torch.bfloat16, 'vae': torch.float16}`). Set the default dtype for
+                unspecified components with `default` (for example `{'transformer': torch.bfloat16, 'default':
+                torch.float16}`). If a component is not specified and no default is set, `torch.float32` is used.
             custom_pipeline (`str`, *optional*):
@@ -703,7 +706,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
         use_onnx = kwargs.pop("use_onnx", None)
         load_connected_pipeline = kwargs.pop("load_connected_pipeline", False)

-        if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
+        if torch_dtype is not None and not isinstance(torch_dtype, dict) and not isinstance(torch_dtype, torch.dtype):
             torch_dtype = torch.float32
             logger.warning(
                 f"Passed `torch_dtype` {torch_dtype} is not a `torch.dtype`. Defaulting to `torch.float32`."
@@ -950,6 +953,11 @@ def load_module(name, value):
                 loaded_sub_model = passed_class_obj[name]
             else:
                 # load sub model
+                sub_model_dtype = (
+                    torch_dtype.get(name, torch_dtype.get("default", torch.float32))
+                    if isinstance(torch_dtype, dict)
+                    else torch_dtype
+                )
                 loaded_sub_model = load_sub_model(
                     library_name=library_name,
                     class_name=class_name,
@@ -957,7 +965,7 @@ def load_module(name, value):
                     pipelines=pipelines,
                     is_pipeline_module=is_pipeline_module,
                     pipeline_class=pipeline_class,
-                    torch_dtype=torch_dtype,
+                    torch_dtype=sub_model_dtype,
                     provider=provider,
                     sess_options=sess_options,
                     device_map=current_device_map,
diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py
index d069def66ecf..cc5008e37292 100644
--- a/tests/pipelines/test_pipelines_common.py
+++ b/tests/pipelines/test_pipelines_common.py
@@ -2283,6 +2283,29 @@ def run_forward(pipe):
         self.assertTrue(np.allclose(output_without_group_offloading, output_with_group_offloading1, atol=1e-4))
         self.assertTrue(np.allclose(output_without_group_offloading, output_with_group_offloading2, atol=1e-4))

+    def test_torch_dtype_dict(self):
+        components = self.get_dummy_components()
+        if not components:
+            self.skipTest("No dummy components defined.")
+
+        pipe = self.pipeline_class(**components)
+
+        specified_key = next(iter(components.keys()))
+
+        with tempfile.TemporaryDirectory(ignore_cleanup_errors=True) as tmpdirname:
+            pipe.save_pretrained(tmpdirname)
+            torch_dtype_dict = {specified_key: torch.bfloat16, "default": torch.float16}
+            loaded_pipe = self.pipeline_class.from_pretrained(tmpdirname, torch_dtype=torch_dtype_dict)
+
+        for name, component in loaded_pipe.components.items():
+            if isinstance(component, torch.nn.Module) and hasattr(component, "dtype"):
+                expected_dtype = torch_dtype_dict.get(name, torch_dtype_dict.get("default", torch.float32))
+                self.assertEqual(
+                    component.dtype,
+                    expected_dtype,
+                    f"Component '{name}' has dtype {component.dtype} but expected {expected_dtype}",
+                )
+

 @is_staging_test
 class PipelinePushToHubTester(unittest.TestCase):
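For reference, a minimal usage sketch of the per-component `torch_dtype` mapping introduced above. The checkpoint id and component names (`transformer`, `vae`) are illustrative only; any pipeline repo with matching component names would behave the same way.

```python
import torch

from diffusers import DiffusionPipeline

# Components named in the dict get their own dtype; everything else falls back to
# the "default" entry (or torch.float32 if no default is given).
pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",  # illustrative checkpoint
    torch_dtype={"transformer": torch.bfloat16, "default": torch.float16},
)

print(pipe.transformer.dtype)  # torch.bfloat16 (explicitly requested)
print(pipe.vae.dtype)          # torch.float16 (falls back to "default")
```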
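The `try`/`except` added to `save_pretrained` covers shards that `safetensors.torch.save_file` refuses, most notably state dicts containing tensors that share memory: `save_file` raises a `RuntimeError` for shared tensors, while `safetensors.torch.save_model` deduplicates them before writing. A standalone sketch of that behaviour, with a made-up tied-weight module for illustration:

```python
import torch
import safetensors.torch


class TiedModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.embed = torch.nn.Linear(4, 4, bias=False)
        self.head = torch.nn.Linear(4, 4, bias=False)
        self.head.weight = self.embed.weight  # the two modules now share one tensor


model = TiedModel()

try:
    # Fails: save_file rejects state dicts whose tensors share memory.
    safetensors.torch.save_file(model.state_dict(), "tied.safetensors", metadata={"format": "pt"})
except RuntimeError:
    # Works: save_model removes the duplicated storage before serializing.
    safetensors.torch.save_model(model, "tied.safetensors", metadata={"format": "pt"})
```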