Fix `torch_dtype` in Kolors text encoder with `transformers` v4.49 (huggingface#10816)

* Fix `torch_dtype` in Kolors text encoder with `transformers` v4.49

* Default torch_dtype and warning
hlky authored Feb 24, 2025
1 parent 9c7e205 commit 6f74ef5
Showing 8 changed files with 43 additions and 9 deletions.
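The recurring change across the loader entry points below is a defensive default: `torch_dtype` now falls back to `torch.float32`, and any value that is not a `torch.dtype` triggers a warning plus the same fallback. A minimal standalone sketch of the pattern (the helper name `resolve_torch_dtype` is illustrative, not part of the commit):

import torch

def resolve_torch_dtype(torch_dtype):
    # Anything that is not a torch.dtype (e.g. the string "float16")
    # falls back to float32, mirroring the guard added in this commit.
    if not isinstance(torch_dtype, torch.dtype):
        print(
            f"Passed `torch_dtype` {torch_dtype} is not a `torch.dtype`. "
            "Defaulting to `torch.float32`."
        )
        torch_dtype = torch.float32
    return torch_dtype

assert resolve_torch_dtype("float16") == torch.float32        # string is rejected
assert resolve_torch_dtype(torch.bfloat16) == torch.bfloat16  # dtype passes through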
6 changes: 5 additions & 1 deletion examples/community/checkpoint_merger.py
@@ -92,9 +92,13 @@ def merge(self, pretrained_model_name_or_path_list: List[Union[str, os.PathLike]], **kwargs):
 token = kwargs.pop("token", None)
 variant = kwargs.pop("variant", None)
 revision = kwargs.pop("revision", None)
-torch_dtype = kwargs.pop("torch_dtype", None)
+torch_dtype = kwargs.pop("torch_dtype", torch.float32)
 device_map = kwargs.pop("device_map", None)

+if not isinstance(torch_dtype, torch.dtype):
+    print(f"Passed `torch_dtype` {torch_dtype} is not a `torch.dtype`. Defaulting to `torch.float32`.")
+    torch_dtype = torch.float32
+
 alpha = kwargs.pop("alpha", 0.5)
 interp = kwargs.pop("interp", None)
8 changes: 7 additions & 1 deletion src/diffusers/loaders/single_file.py
@@ -360,11 +360,17 @@ def from_single_file(cls, pretrained_model_link_or_path, **kwargs) -> Self:
 cache_dir = kwargs.pop("cache_dir", None)
 local_files_only = kwargs.pop("local_files_only", False)
 revision = kwargs.pop("revision", None)
-torch_dtype = kwargs.pop("torch_dtype", None)
+torch_dtype = kwargs.pop("torch_dtype", torch.float32)
 disable_mmap = kwargs.pop("disable_mmap", False)

 is_legacy_loading = False

+if not isinstance(torch_dtype, torch.dtype):
+    logger.warning(
+        f"Passed `torch_dtype` {torch_dtype} is not a `torch.dtype`. Defaulting to `torch.float32`."
+    )
+    torch_dtype = torch.float32
+
 # We shouldn't allow configuring individual model components through a Pipeline creation method
 # These model kwargs should be deprecated
 scaling_factor = kwargs.get("scaling_factor", None)
8 changes: 7 additions & 1 deletion src/diffusers/loaders/single_file_model.py
@@ -240,11 +240,17 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] = None, **kwargs):
 subfolder = kwargs.pop("subfolder", None)
 revision = kwargs.pop("revision", None)
 config_revision = kwargs.pop("config_revision", None)
-torch_dtype = kwargs.pop("torch_dtype", None)
+torch_dtype = kwargs.pop("torch_dtype", torch.float32)
 quantization_config = kwargs.pop("quantization_config", None)
 device = kwargs.pop("device", None)
 disable_mmap = kwargs.pop("disable_mmap", False)

+if not isinstance(torch_dtype, torch.dtype):
+    logger.warning(
+        f"Passed `torch_dtype` {torch_dtype} is not a `torch.dtype`. Defaulting to `torch.float32`."
+    )
+    torch_dtype = torch.float32
+
 if isinstance(pretrained_model_link_or_path_or_dict, dict):
     checkpoint = pretrained_model_link_or_path_or_dict
 else:
8 changes: 7 additions & 1 deletion src/diffusers/models/modeling_utils.py
@@ -866,7 +866,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
 local_files_only = kwargs.pop("local_files_only", None)
 token = kwargs.pop("token", None)
 revision = kwargs.pop("revision", None)
-torch_dtype = kwargs.pop("torch_dtype", None)
+torch_dtype = kwargs.pop("torch_dtype", torch.float32)
 subfolder = kwargs.pop("subfolder", None)
 device_map = kwargs.pop("device_map", None)
 max_memory = kwargs.pop("max_memory", None)
@@ -879,6 +879,12 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
 dduf_entries: Optional[Dict[str, DDUFEntry]] = kwargs.pop("dduf_entries", None)
 disable_mmap = kwargs.pop("disable_mmap", False)

+if not isinstance(torch_dtype, torch.dtype):
+    logger.warning(
+        f"Passed `torch_dtype` {torch_dtype} is not a `torch.dtype`. Defaulting to `torch.float32`."
+    )
+    torch_dtype = torch.float32
+
 allow_pickle = False
 if use_safetensors is None:
     use_safetensors = True
10 changes: 8 additions & 2 deletions src/diffusers/pipelines/pipeline_utils.py
@@ -685,7 +685,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
 token = kwargs.pop("token", None)
 revision = kwargs.pop("revision", None)
 from_flax = kwargs.pop("from_flax", False)
-torch_dtype = kwargs.pop("torch_dtype", None)
+torch_dtype = kwargs.pop("torch_dtype", torch.float32)
 custom_pipeline = kwargs.pop("custom_pipeline", None)
 custom_revision = kwargs.pop("custom_revision", None)
 provider = kwargs.pop("provider", None)
@@ -702,6 +702,12 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
 use_onnx = kwargs.pop("use_onnx", None)
 load_connected_pipeline = kwargs.pop("load_connected_pipeline", False)

+if not isinstance(torch_dtype, torch.dtype):
+    logger.warning(
+        f"Passed `torch_dtype` {torch_dtype} is not a `torch.dtype`. Defaulting to `torch.float32`."
+    )
+    torch_dtype = torch.float32
+
 if low_cpu_mem_usage and not is_accelerate_available():
     low_cpu_mem_usage = False
     logger.warning(
@@ -1826,7 +1832,7 @@ def from_pipe(cls, pipeline, **kwargs):
"""

original_config = dict(pipeline.config)
torch_dtype = kwargs.pop("torch_dtype", None)
torch_dtype = kwargs.pop("torch_dtype", torch.float32)

# derive the pipeline class to instantiate
custom_pipeline = kwargs.pop("custom_pipeline", None)
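For pipeline users, the net effect is that an invalid `torch_dtype` no longer propagates silently into loading. A hedged usage sketch (the tiny test repository id is illustrative, assumed to exist on the Hub; it is not part of this commit):

import torch
from diffusers import DiffusionPipeline

# A real torch.dtype is honored exactly as before.
pipe = DiffusionPipeline.from_pretrained(
    "hf-internal-testing/tiny-stable-diffusion-pipe", torch_dtype=torch.float16
)

# A string is not a torch.dtype: after this commit the loader warns
# "Passed `torch_dtype` float16 is not a `torch.dtype`. Defaulting to
# `torch.float32`." and loads the weights in float32 instead of failing later.
pipe = DiffusionPipeline.from_pretrained(
    "hf-internal-testing/tiny-stable-diffusion-pipe", torch_dtype="float16"
)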
4 changes: 3 additions & 1 deletion tests/pipelines/kolors/test_kolors.py
@@ -89,7 +89,9 @@ def get_dummy_components(self, time_cond_proj_dim=None):
     sample_size=128,
 )
 torch.manual_seed(0)
-text_encoder = ChatGLMModel.from_pretrained("hf-internal-testing/tiny-random-chatglm3-6b")
+text_encoder = ChatGLMModel.from_pretrained(
+    "hf-internal-testing/tiny-random-chatglm3-6b", torch_dtype=torch.bfloat16
+)
 tokenizer = ChatGLMTokenizer.from_pretrained("hf-internal-testing/tiny-random-chatglm3-6b")

 components = {
4 changes: 3 additions & 1 deletion tests/pipelines/kolors/test_kolors_img2img.py
@@ -93,7 +93,9 @@ def get_dummy_components(self, time_cond_proj_dim=None):
     sample_size=128,
 )
 torch.manual_seed(0)
-text_encoder = ChatGLMModel.from_pretrained("hf-internal-testing/tiny-random-chatglm3-6b")
+text_encoder = ChatGLMModel.from_pretrained(
+    "hf-internal-testing/tiny-random-chatglm3-6b", torch_dtype=torch.bfloat16
+)
 tokenizer = ChatGLMTokenizer.from_pretrained("hf-internal-testing/tiny-random-chatglm3-6b")

 components = {
4 changes: 3 additions & 1 deletion tests/pipelines/pag/test_pag_kolors.py
@@ -98,7 +98,9 @@ def get_dummy_components(self, time_cond_proj_dim=None):
     sample_size=128,
 )
 torch.manual_seed(0)
-text_encoder = ChatGLMModel.from_pretrained("hf-internal-testing/tiny-random-chatglm3-6b")
+text_encoder = ChatGLMModel.from_pretrained(
+    "hf-internal-testing/tiny-random-chatglm3-6b", torch_dtype=torch.bfloat16
+)
 tokenizer = ChatGLMTokenizer.from_pretrained("hf-internal-testing/tiny-random-chatglm3-6b")

 components = {
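All three Kolors test fixtures now pin the text encoder dtype explicitly rather than relying on `transformers` to infer it from the checkpoint, which is the behavior that shifted in v4.49. A minimal sketch of the updated fixture (the import path is assumed from the diffusers source layout):

import torch
from diffusers.pipelines.kolors.text_encoder import ChatGLMModel

# Request bfloat16 explicitly; with transformers v4.49 the checkpoint dtype
# is no longer picked up implicitly, so the tests pass it directly.
text_encoder = ChatGLMModel.from_pretrained(
    "hf-internal-testing/tiny-random-chatglm3-6b", torch_dtype=torch.bfloat16
)
assert text_encoder.dtype == torch.bfloat16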
