Fix torch_dtype in Kolors text encoder with transformers v4.49 #10816

Merged · 3 commits · Feb 24, 2025
examples/community/checkpoint_merger.py (5 additions, 1 deletion)

@@ -92,9 +92,13 @@ def merge(self, pretrained_model_name_or_path_list: List[Union[str, os.PathLike]
        token = kwargs.pop("token", None)
        variant = kwargs.pop("variant", None)
        revision = kwargs.pop("revision", None)
-       torch_dtype = kwargs.pop("torch_dtype", None)
+       torch_dtype = kwargs.pop("torch_dtype", torch.float32)
        device_map = kwargs.pop("device_map", None)

+       if not isinstance(torch_dtype, torch.dtype):
+           print(f"Passed `torch_dtype` {torch_dtype} is not a `torch.dtype`. Defaulting to `torch.float32`.")
+           torch_dtype = torch.float32

        alpha = kwargs.pop("alpha", 0.5)
        interp = kwargs.pop("interp", None)

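The guard added here recurs, with `logger.warning` instead of `print`, in every loader below. As a standalone sketch (the helper name `resolve_torch_dtype` is illustrative, not a diffusers API):

    import torch

    def resolve_torch_dtype(torch_dtype):
        # Mirrors the check this PR adds: anything that is not an actual
        # torch.dtype (e.g. the string "float16", or None) is replaced,
        # with a warning, by float32.
        if not isinstance(torch_dtype, torch.dtype):
            print(f"Passed `torch_dtype` {torch_dtype} is not a `torch.dtype`. Defaulting to `torch.float32`.")
            torch_dtype = torch.float32
        return torch_dtype

    print(resolve_torch_dtype("float16"))       # warns, returns torch.float32
    print(resolve_torch_dtype(torch.bfloat16))  # returns torch.bfloat16 unchanged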
src/diffusers/loaders/single_file.py (7 additions, 1 deletion)

@@ -360,11 +360,17 @@ def from_single_file(cls, pretrained_model_link_or_path, **kwargs) -> Self:
        cache_dir = kwargs.pop("cache_dir", None)
        local_files_only = kwargs.pop("local_files_only", False)
        revision = kwargs.pop("revision", None)
-       torch_dtype = kwargs.pop("torch_dtype", None)
+       torch_dtype = kwargs.pop("torch_dtype", torch.float32)
        disable_mmap = kwargs.pop("disable_mmap", False)

        is_legacy_loading = False

+       if not isinstance(torch_dtype, torch.dtype):
+           logger.warning(
+               f"Passed `torch_dtype` {torch_dtype} is not a `torch.dtype`. Defaulting to `torch.float32`."
+           )
+           torch_dtype = torch.float32

        # We shouldn't allow configuring individual model components through a Pipeline creation method
        # These model kwargs should be deprecated
        scaling_factor = kwargs.get("scaling_factor", None)
src/diffusers/loaders/single_file_model.py (7 additions, 1 deletion)

@@ -240,11 +240,17 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
        subfolder = kwargs.pop("subfolder", None)
        revision = kwargs.pop("revision", None)
        config_revision = kwargs.pop("config_revision", None)
-       torch_dtype = kwargs.pop("torch_dtype", None)
+       torch_dtype = kwargs.pop("torch_dtype", torch.float32)
        quantization_config = kwargs.pop("quantization_config", None)
        device = kwargs.pop("device", None)
        disable_mmap = kwargs.pop("disable_mmap", False)

+       if not isinstance(torch_dtype, torch.dtype):
+           logger.warning(
+               f"Passed `torch_dtype` {torch_dtype} is not a `torch.dtype`. Defaulting to `torch.float32`."
+           )
+           torch_dtype = torch.float32

        if isinstance(pretrained_model_link_or_path_or_dict, dict):
            checkpoint = pretrained_model_link_or_path_or_dict
        else:
src/diffusers/models/modeling_utils.py (7 additions, 1 deletion)

@@ -866,7 +866,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
        local_files_only = kwargs.pop("local_files_only", None)
        token = kwargs.pop("token", None)
        revision = kwargs.pop("revision", None)
-       torch_dtype = kwargs.pop("torch_dtype", None)
+       torch_dtype = kwargs.pop("torch_dtype", torch.float32)
        subfolder = kwargs.pop("subfolder", None)
        device_map = kwargs.pop("device_map", None)
        max_memory = kwargs.pop("max_memory", None)
@@ -879,6 +879,12 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
        dduf_entries: Optional[Dict[str, DDUFEntry]] = kwargs.pop("dduf_entries", None)
        disable_mmap = kwargs.pop("disable_mmap", False)

+       if not isinstance(torch_dtype, torch.dtype):
+           logger.warning(
+               f"Passed `torch_dtype` {torch_dtype} is not a `torch.dtype`. Defaulting to `torch.float32`."
+           )
+           torch_dtype = torch.float32

        allow_pickle = False
        if use_safetensors is None:
            use_safetensors = True
src/diffusers/pipelines/pipeline_utils.py (8 additions, 2 deletions)

@@ -684,7 +684,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
        token = kwargs.pop("token", None)
        revision = kwargs.pop("revision", None)
        from_flax = kwargs.pop("from_flax", False)
-       torch_dtype = kwargs.pop("torch_dtype", None)
+       torch_dtype = kwargs.pop("torch_dtype", torch.float32)
        custom_pipeline = kwargs.pop("custom_pipeline", None)
        custom_revision = kwargs.pop("custom_revision", None)
        provider = kwargs.pop("provider", None)
@@ -701,6 +701,12 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
        use_onnx = kwargs.pop("use_onnx", None)
        load_connected_pipeline = kwargs.pop("load_connected_pipeline", False)

+       if not isinstance(torch_dtype, torch.dtype):
+           logger.warning(
+               f"Passed `torch_dtype` {torch_dtype} is not a `torch.dtype`. Defaulting to `torch.float32`."
+           )
+           torch_dtype = torch.float32

        if low_cpu_mem_usage and not is_accelerate_available():
            low_cpu_mem_usage = False
            logger.warning(
@@ -1829,7 +1835,7 @@ def from_pipe(cls, pipeline, **kwargs):
"""

original_config = dict(pipeline.config)
torch_dtype = kwargs.pop("torch_dtype", None)
torch_dtype = kwargs.pop("torch_dtype", torch.float32)

# derive the pipeline class to instantiate
custom_pipeline = kwargs.pop("custom_pipeline", None)
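From the caller's side, the pipeline-level guard means a malformed `torch_dtype` now degrades gracefully. A hedged sketch of the resulting behavior (`some-org/some-pipeline` is a placeholder repo id, so the snippet is illustrative rather than runnable as-is):

    import torch
    from diffusers import DiffusionPipeline

    # Passing a string instead of a torch.dtype used to slip through unchecked;
    # with this PR the loader logs a warning and falls back to torch.float32.
    pipe = DiffusionPipeline.from_pretrained("some-org/some-pipeline", torch_dtype="float16")
    assert pipe.dtype == torch.float32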
tests/pipelines/kolors/test_kolors.py (3 additions, 1 deletion)

@@ -89,7 +89,9 @@ def get_dummy_components(self, time_cond_proj_dim=None):
            sample_size=128,
        )
        torch.manual_seed(0)
-       text_encoder = ChatGLMModel.from_pretrained("hf-internal-testing/tiny-random-chatglm3-6b")
+       text_encoder = ChatGLMModel.from_pretrained(
+           "hf-internal-testing/tiny-random-chatglm3-6b", torch_dtype=torch.bfloat16
+       )
        tokenizer = ChatGLMTokenizer.from_pretrained("hf-internal-testing/tiny-random-chatglm3-6b")

        components = {
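The Kolors tests previously relied on the checkpoint's configured dtype to yield bfloat16 weights; with transformers v4.49 that no longer appears to happen implicitly, so the dtype is requested explicitly. A minimal sketch of the updated load (import path assumed from diffusers' Kolors module):

    import torch
    from diffusers.pipelines.kolors import ChatGLMModel

    text_encoder = ChatGLMModel.from_pretrained(
        "hf-internal-testing/tiny-random-chatglm3-6b", torch_dtype=torch.bfloat16
    )
    assert text_encoder.dtype == torch.bfloat16  # explicit, independent of transformers' default

The same one-line change is applied to the img2img and PAG test variants below.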
tests/pipelines/kolors/test_kolors_img2img.py (3 additions, 1 deletion)

@@ -93,7 +93,9 @@ def get_dummy_components(self, time_cond_proj_dim=None):
            sample_size=128,
        )
        torch.manual_seed(0)
-       text_encoder = ChatGLMModel.from_pretrained("hf-internal-testing/tiny-random-chatglm3-6b")
+       text_encoder = ChatGLMModel.from_pretrained(
+           "hf-internal-testing/tiny-random-chatglm3-6b", torch_dtype=torch.bfloat16
+       )
        tokenizer = ChatGLMTokenizer.from_pretrained("hf-internal-testing/tiny-random-chatglm3-6b")

        components = {
tests/pipelines/pag/test_pag_kolors.py (3 additions, 1 deletion)

@@ -98,7 +98,9 @@ def get_dummy_components(self, time_cond_proj_dim=None):
            sample_size=128,
        )
        torch.manual_seed(0)
-       text_encoder = ChatGLMModel.from_pretrained("hf-internal-testing/tiny-random-chatglm3-6b")
+       text_encoder = ChatGLMModel.from_pretrained(
+           "hf-internal-testing/tiny-random-chatglm3-6b", torch_dtype=torch.bfloat16
+       )
        tokenizer = ChatGLMTokenizer.from_pretrained("hf-internal-testing/tiny-random-chatglm3-6b")

        components = {