Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
hlky committed Dec 13, 2024
1 parent 5679067 commit 24d79a3
Showing 1 changed file with 23 additions and 20 deletions.
43 changes: 23 additions & 20 deletions src/diffusers/pipelines/pipeline_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
from ..models.attention_processor import FusedAttnProcessor2_0
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, ModelMixin
from ..quantizers.bitsandbytes.utils import _check_bnb_status
from ..schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
from ..schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME, SchedulerMixin
from ..utils import (
CONFIG_NAME,
DEPRECATED_REVISION_ARGS,
Expand Down Expand Up @@ -834,28 +834,31 @@ def load_module(name, value):
return True

init_dict = {k: v for k, v in init_dict.items() if load_module(k, v)}
scheduler_types = None
if "scheduler" in expected_types:
scheduler_types = []
for scheduler_type in expected_types["scheduler"]:
if isinstance(scheduler_type, enum.EnumMeta):
scheduler_types.extend(list(scheduler_type))
else:
scheduler_types.extend([str(scheduler_type)])
scheduler_types = [str(scheduler).split(".")[-1].strip("'>") for scheduler in scheduler_types]

for key, (_, expected_class_name) in zip(init_dict.keys(), init_dict.values()):
for key in init_dict.keys():
if key not in passed_class_obj:
continue
class_name = passed_class_obj[key].__class__.__name__
class_name = class_name[4:] if class_name.startswith("Flax") else class_name
expected_class_name = (
expected_class_name[4:] if expected_class_name.startswith("Flax") else expected_class_name
)
if key == "scheduler" and scheduler_types is not None and class_name not in scheduler_types:
raise ValueError(f"Expected {scheduler_types} for {key}, got {class_name}.")
elif key != "scheduler" and class_name != expected_class_name:
raise ValueError(f"Expected {expected_class_name} for {key}, got {class_name}.")

class_obj = passed_class_obj[key]
_expected_class_types = []
for expected_type in expected_types[key]:
if isinstance(expected_type, enum.EnumMeta):
_expected_class_types.extend(expected_type.__members__.keys())
else:
_expected_class_types.append(expected_type.__name__)

_is_valid_type = class_obj.__class__.__name__ in _expected_class_types
if isinstance(class_obj, SchedulerMixin) and not _is_valid_type:
_requires_flow_match = any("FlowMatch" in class_type for class_type in _expected_class_types)
_is_flow_match = "FlowMatch" in class_obj.__class__.__name__
if _requires_flow_match and not _is_flow_match:
raise ValueError(f"Expected FlowMatch scheduler, got {class_obj.__class__.__name__}.")
elif not _requires_flow_match and _is_flow_match:
raise ValueError(f"Expected non-FlowMatch scheduler, got {class_obj.__class__.__name__}.")
elif not _is_valid_type:
raise ValueError(
f"Expected types for {key}: {_expected_class_types}, got {class_obj.__class__.__name__}."
)

# Special case: safety_checker must be loaded separately when using `from_flax`
if from_flax and "safety_checker" in init_dict and "safety_checker" not in passed_class_obj:
Expand Down

0 comments on commit 24d79a3

Please sign in to comment.